272 lines
5.9 KiB
Go
272 lines
5.9 KiB
Go
package client
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"time"

	"github.com/charmbracelet/log"
	"github.com/google/uuid"
)
|
|
|
|
// ResponseFormat selects the completion output format sent under
// "response_format" (RunAgent uses Type "text").
type ResponseFormat struct {
	// Type is the format name, e.g. "text".
	Type string `json:"type"`
	// JsonSchema optionally carries a schema for schema-constrained output.
	// NOTE(review): the "omitzero" tag option is only honored by Go 1.24+
	// encoding/json — confirm the module's Go version.
	JsonSchema any `json:"json_schema,omitzero"`
}
|
|
|
|
// AgentRequestBody is the chat-completions request payload. It is
// serialized by the custom MarshalJSON below, which flattens Chat into a
// "messages" array; the json tags here document intent but the marshaler
// has the final say over the wire format.
type AgentRequestBody struct {
	Model          string         `json:"model"`
	Temperature    float64        `json:"temperature"`
	ResponseFormat ResponseFormat `json:"response_format"`

	// Tools is the pre-parsed JSON tool description array; nil omits it.
	Tools *any `json:"tools,omitempty"`
	ToolChoice *string `json:"tool_choice,omitempty"`

	// RandomSeed, when set, requests deterministic sampling.
	RandomSeed *int `json:"random_seed,omitempty"`

	// EndToolCall names the tool whose invocation terminates the tool
	// loop (see Process). Local bookkeeping only — never sent to the API.
	EndToolCall string `json:"-"`

	// Chat owns the conversation history; marshaled as "messages".
	Chat *Chat `json:"messages"`
}
|
|
|
|
func (req AgentRequestBody) MarshalJSON() ([]byte, error) {
|
|
return json.Marshal(&struct {
|
|
Model string `json:"model"`
|
|
Temperature float64 `json:"temperature"`
|
|
ResponseFormat ResponseFormat `json:"response_format"`
|
|
|
|
Tools *any `json:"tools,omitempty"`
|
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
|
Messages []ChatMessage `json:"messages"`
|
|
}{
|
|
Model: req.Model,
|
|
Temperature: req.Temperature,
|
|
ResponseFormat: req.ResponseFormat,
|
|
|
|
Tools: req.Tools,
|
|
ToolChoice: req.ToolChoice,
|
|
|
|
Messages: req.Chat.Messages,
|
|
})
|
|
}
|
|
|
|
// ResponseChoice is one completion alternative returned by the API.
type ResponseChoice struct {
	Index int `json:"index"`
	// Message is the assistant message, possibly carrying tool calls.
	Message ChatAiMessage `json:"message"`
	// FinishReason is the API's stop cause; "stop" ends the tool loop.
	FinishReason string `json:"finish_reason"`
}
|
|
|
|
// AgentResponse is the decoded chat-completions response envelope.
// Request rejects any response that does not contain exactly one choice.
type AgentResponse struct {
	Id      string           `json:"id"`
	Object  string           `json:"object"`
	Choices []ResponseChoice `json:"choices"`
	// Created is the server-side creation time as a Unix timestamp.
	Created int `json:"created"`
}
|
|
|
|
// AgentClient talks to a chat-completions endpoint and dispatches the
// model's tool calls to registered handlers. Construct it with
// CreateAgentClient.
type AgentClient struct {
	// url is the completions endpoint; apiKey authenticates requests.
	url    string
	apiKey string
	// responseFormat appears unused by the code in view — candidate for
	// removal, TODO confirm against the rest of the package.
	responseFormat string

	// ToolHandler maps tool names to their handlers.
	ToolHandler ToolsHandlers

	Log *log.Logger

	// Reply holds the raw arguments of the last "reply" tool call.
	Reply string

	// Do performs the HTTP request; injectable so tests can stub the
	// network.
	Do func(req *http.Request) (*http.Response, error)

	Options CreateAgentClientOptions
}
|
|
|
|
const OPENAI_API_KEY = "REAL_OPEN_AI_KEY"
|
|
|
|
// CreateAgentClientOptions configures a new AgentClient.
type CreateAgentClientOptions struct {
	Log *log.Logger
	// SystemPrompt seeds the conversation's system message.
	SystemPrompt string
	// JsonTools is a JSON-encoded tool description array, parsed in
	// RunAgent.
	JsonTools string
	// EndToolCall names the tool whose invocation ends the tool loop.
	EndToolCall string
	// Query is an optional user question attached to the image message.
	Query *string
}
|
|
|
|
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
|
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
|
|
|
if len(apiKey) == 0 {
|
|
panic("No api key")
|
|
}
|
|
|
|
return AgentClient{
|
|
apiKey: apiKey,
|
|
url: "https://api.openai.com/v1/chat/completions",
|
|
Do: func(req *http.Request) (*http.Response, error) {
|
|
client := &http.Client{}
|
|
return client.Do(req)
|
|
},
|
|
|
|
Log: options.Log,
|
|
|
|
ToolHandler: ToolsHandlers{
|
|
handlers: map[string]ToolHandler{},
|
|
},
|
|
|
|
Options: options,
|
|
}
|
|
}
|
|
|
|
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
|
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body))
|
|
if err != nil {
|
|
return req, err
|
|
}
|
|
|
|
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
|
req.Header.Add("Content-Type", "application/json")
|
|
|
|
return req, nil
|
|
}
|
|
|
|
func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
|
|
jsonAiRequest, err := json.Marshal(req)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
httpRequest, err := client.getRequest(jsonAiRequest)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
resp, err := client.Do(httpRequest)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
response, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
agentResponse := AgentResponse{}
|
|
err = json.Unmarshal(response, &agentResponse)
|
|
|
|
if err != nil {
|
|
return AgentResponse{}, err
|
|
}
|
|
|
|
if len(agentResponse.Choices) != 1 {
|
|
client.Log.Errorf("Received more than 1 choice from AI \n %s\n", string(response))
|
|
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
|
}
|
|
|
|
msg := agentResponse.Choices[0].Message
|
|
req.Chat.AddAiResponse(msg)
|
|
|
|
return agentResponse, nil
|
|
}
|
|
|
|
func (client *AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
|
for {
|
|
response, err := client.Request(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if response.Choices[0].FinishReason == "stop" {
|
|
client.Log.Debug("Agent is finished")
|
|
return nil
|
|
}
|
|
|
|
err = client.Process(info, req)
|
|
|
|
if err != nil {
|
|
|
|
if err == FinishedCall {
|
|
client.Log.Debug("Agent is finished")
|
|
}
|
|
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
var FinishedCall = errors.New("Last tool tool was called")
|
|
|
|
func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
|
var err error
|
|
|
|
message, err := req.Chat.GetLatest()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
aiMessage, ok := message.(ChatAiMessage)
|
|
if !ok {
|
|
return errors.New("Latest message isnt an AI message")
|
|
}
|
|
|
|
if aiMessage.ToolCalls == nil {
|
|
// Not an error, we just dont have any tool calls to process.
|
|
return nil
|
|
}
|
|
|
|
for _, toolCall := range *aiMessage.ToolCalls {
|
|
if toolCall.Function.Name == req.EndToolCall {
|
|
return FinishedCall
|
|
}
|
|
|
|
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
|
|
|
if toolCall.Function.Name == "reply" {
|
|
client.Reply = toolCall.Function.Arguments
|
|
}
|
|
|
|
client.Log.Debug("Tool call", "name", toolCall.Function.Name, "arguments", toolCall.Function.Arguments, "response", toolResponse.Content)
|
|
|
|
req.Chat.AddToolResponse(toolResponse)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (client *AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
|
var tools any
|
|
err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
toolChoice := "auto"
|
|
seed := 42
|
|
|
|
request := AgentRequestBody{
|
|
Tools: &tools,
|
|
ToolChoice: &toolChoice,
|
|
Model: "gpt-4.1-mini",
|
|
RandomSeed: &seed,
|
|
Temperature: 0.3,
|
|
EndToolCall: client.Options.EndToolCall,
|
|
ResponseFormat: ResponseFormat{
|
|
Type: "text",
|
|
},
|
|
Chat: &Chat{
|
|
Messages: make([]ChatMessage, 0),
|
|
},
|
|
}
|
|
|
|
request.Chat.AddSystem(client.Options.SystemPrompt)
|
|
request.Chat.AddImage(imageName, imageData, client.Options.Query)
|
|
|
|
toolHandlerInfo := ToolHandlerInfo{
|
|
ImageId: imageId,
|
|
ImageName: imageName,
|
|
UserId: userId,
|
|
Image: &imageData,
|
|
}
|
|
|
|
return client.ToolLoop(toolHandlerInfo, &request)
|
|
}
|