2025-05-04 10:19:39 +01:00

272 lines
5.9 KiB
Go

package client
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"os"
"github.com/charmbracelet/log"
"github.com/google/uuid"
)
// ResponseFormat is the "response_format" object sent with a chat completion
// request.
//
// Type selects the output format (RunAgent uses "text"). JsonSchema is an
// optional schema payload; its `omitzero` option leaves it out of the encoded
// JSON while it holds its zero value (nil), on Go versions that support that
// tag option.
type ResponseFormat struct {
	Type       string `json:"type"`
	JsonSchema any    `json:"json_schema,omitzero"`
}
// AgentRequestBody holds everything needed for one chat completion request
// plus some client-side bookkeeping.
//
// The type has a custom MarshalJSON method, so the JSON that is actually sent
// is decided there rather than by the struct tags below. EndToolCall names
// the tool whose invocation ends the agent loop (Process compares tool-call
// names against it) and is never serialized. Chat holds the conversation; the
// request is built from its Messages.
type AgentRequestBody struct {
	Model          string         `json:"model"`
	Temperature    float64        `json:"temperature"`
	ResponseFormat ResponseFormat `json:"response_format"`
	Tools          *any           `json:"tools,omitempty"`
	ToolChoice     *string        `json:"tool_choice,omitempty"`
	RandomSeed     *int           `json:"random_seed,omitempty"`
	EndToolCall    string         `json:"-"`
	Chat           *Chat          `json:"messages"`
}
// MarshalJSON encodes the request in the shape the chat completions endpoint
// expects.
//
// The custom encoder exists because Chat is a *Chat while the API wants a bare
// "messages" array, so the messages are lifted out of the Chat. EndToolCall is
// never encoded. RandomSeed is sent as "seed", the parameter name used by the
// OpenAI chat completions API, and is omitted when nil; previously it was
// dropped entirely, so a seed set by the caller had no effect.
//
// A nil Chat is encoded as "messages": null instead of causing a nil-pointer
// panic.
func (req AgentRequestBody) MarshalJSON() ([]byte, error) {
	var messages []ChatMessage
	if req.Chat != nil {
		messages = req.Chat.Messages
	}
	return json.Marshal(&struct {
		Model          string         `json:"model"`
		Temperature    float64        `json:"temperature"`
		ResponseFormat ResponseFormat `json:"response_format"`
		Tools          *any           `json:"tools,omitempty"`
		ToolChoice     *string        `json:"tool_choice,omitempty"`
		Seed           *int           `json:"seed,omitempty"`
		Messages       []ChatMessage  `json:"messages"`
	}{
		Model:          req.Model,
		Temperature:    req.Temperature,
		ResponseFormat: req.ResponseFormat,
		Tools:          req.Tools,
		ToolChoice:     req.ToolChoice,
		Seed:           req.RandomSeed,
		Messages:       messages,
	})
}
// ResponseChoice is one element of the "choices" array in a chat completion
// response. FinishReason carries the API's finish_reason string; ToolLoop
// treats "stop" as the end of the conversation.
type ResponseChoice struct {
	Index        int           `json:"index"`
	Message      ChatAiMessage `json:"message"`
	FinishReason string        `json:"finish_reason"`
}
// AgentResponse is the decoded body of a chat completion response. Only the
// fields declared here are decoded; json.Unmarshal ignores any other keys in
// the body.
type AgentResponse struct {
	Id      string           `json:"id"`
	Object  string           `json:"object"`
	Choices []ResponseChoice `json:"choices"`
	Created int              `json:"created"`
}
// AgentClient talks to the OpenAI chat completions endpoint and runs a
// tool-calling loop on top of it. In this file its unexported url and apiKey
// fields are only set by CreateAgentClient, so build clients with that.
//
// Exported fields:
//   - ToolHandler dispatches the model's tool calls (used by Process).
//   - Log receives debug and error messages.
//   - Reply holds the raw arguments of the most recent "reply" tool call
//     seen by Process.
//   - Do performs the HTTP round trip; CreateAgentClient installs a default
//     and it may be replaced, e.g. by a stub.
//   - Options keeps the options the client was created with; RunAgent reads
//     its prompt, tool, end-tool and query settings.
//
// NOTE(review): responseFormat is not read anywhere in this file — confirm
// whether it is used elsewhere in the package before removing it.
type AgentClient struct {
	url            string
	apiKey         string
	responseFormat string
	ToolHandler    ToolsHandlers
	Log            *log.Logger
	Reply          string
	Do             func(req *http.Request) (*http.Response, error)
	Options        CreateAgentClientOptions
}
const (
	// OPENAI_API_KEY is the NAME of the environment variable that holds the
	// OpenAI API key (CreateAgentClient passes it to os.Getenv); it is not
	// the key itself.
	OPENAI_API_KEY = "REAL_OPEN_AI_KEY"
)
// CreateAgentClientOptions configures CreateAgentClient. The client keeps a
// copy in AgentClient.Options, which RunAgent reads on every run.
type CreateAgentClientOptions struct {
	// Log is the logger installed on the client.
	Log *log.Logger
	// SystemPrompt is added to the chat via Chat.AddSystem at the start of
	// each RunAgent call.
	SystemPrompt string
	// JsonTools is JSON text that RunAgent decodes and sends as the request's
	// tools.
	JsonTools string
	// EndToolCall names the tool whose call makes Process stop with
	// FinishedCall.
	EndToolCall string
	// Query is passed, together with the image, to Chat.AddImage.
	Query *string
}
// CreateAgentClient builds an AgentClient for the OpenAI chat completions
// endpoint.
//
// The API key is read from the environment variable named by OPENAI_API_KEY.
// The function panics if that variable is unset or empty.
//
// If options.Log is nil, the charmbracelet default logger is used. Without
// this default, the client's logging calls would dereference a nil
// *log.Logger. The defaulted logger is also stored in the client's Options so
// both views agree.
//
// The client starts with an empty tool-handler registry.
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
	apiKey := os.Getenv(OPENAI_API_KEY)
	if apiKey == "" {
		panic("No api key")
	}
	if options.Log == nil {
		options.Log = log.Default()
	}
	// One http.Client shared by every request this AgentClient sends.
	// NOTE(review): no Timeout is set, so a stalled connection blocks the
	// agent indefinitely; consider a generous timeout suitable for LLM calls.
	httpClient := &http.Client{}
	return AgentClient{
		apiKey: apiKey,
		url:    "https://api.openai.com/v1/chat/completions",
		Do:     httpClient.Do,
		Log:    options.Log,
		ToolHandler: ToolsHandlers{
			handlers: map[string]ToolHandler{},
		},
		Options: options,
	}
}
// getRequest wraps body in a POST request to client.url carrying the bearer
// API key and a JSON content type. If http.NewRequest fails, its results are
// returned as-is.
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
	httpReq, reqErr := http.NewRequest(http.MethodPost, client.url, bytes.NewBuffer(body))
	if reqErr == nil {
		httpReq.Header.Add("Authorization", "Bearer "+client.apiKey)
		httpReq.Header.Add("Content-Type", "application/json")
	}
	return httpReq, reqErr
}
// Request sends req to the chat completions endpoint and decodes the reply.
//
// On success, the message of the single returned choice is appended to
// req.Chat, so the next Request continues the same conversation.
//
// It returns an error when:
//   - the request cannot be encoded, built or sent;
//   - the server answers with a non-2xx status (the error carries the status
//     line and the raw body);
//   - the body is not valid JSON;
//   - the response does not contain exactly one choice.
func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return AgentResponse{}, err
	}
	httpRequest, err := client.getRequest(payload)
	if err != nil {
		return AgentResponse{}, err
	}
	resp, err := client.Do(httpRequest)
	if err != nil {
		return AgentResponse{}, err
	}
	// Always close the body so the connection is released or reused.
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return AgentResponse{}, err
	}
	// Error responses (bad key, rate limit, ...) contain no choices. Report
	// them directly instead of as a choice-count problem.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return AgentResponse{}, errors.New("AI request failed with status " + resp.Status + ": " + string(body))
	}

	agentResponse := AgentResponse{}
	if err = json.Unmarshal(body, &agentResponse); err != nil {
		return AgentResponse{}, err
	}
	if len(agentResponse.Choices) != 1 {
		client.Log.Errorf("Expected exactly 1 choice from AI, received %d\n %s\n", len(agentResponse.Choices), string(body))
		return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
	}
	req.Chat.AddAiResponse(agentResponse.Choices[0].Message)
	return agentResponse, nil
}
// ToolLoop alternates between Request and Process until the conversation
// ends.
//
// Return values:
//   - nil when the model answers with finish_reason "stop".
//   - FinishedCall when Process reaches the tool named by req.EndToolCall.
//     This is logged as "Agent is finished"; detect it with errors.Is.
//   - Any other error from Request or Process, unchanged.
//
// Indexing Choices[0] is safe because Request only succeeds with exactly one
// choice.
//
// NOTE(review): there is no iteration limit, so a model that never stops and
// never calls the end tool keeps this loop running.
func (client *AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
	for {
		response, err := client.Request(req)
		if err != nil {
			return err
		}
		if response.Choices[0].FinishReason == "stop" {
			client.Log.Debug("Agent is finished")
			return nil
		}
		if err = client.Process(info, req); err != nil {
			// errors.Is keeps working if Process ever wraps the sentinel.
			if errors.Is(err, FinishedCall) {
				client.Log.Debug("Agent is finished")
			}
			return err
		}
	}
}
// FinishedCall is the sentinel error Process returns when the model calls the
// tool named by AgentRequestBody.EndToolCall. ToolLoop logs it as
// "Agent is finished". ToolLoop and RunAgent pass it through to their
// callers, who should test for it with errors.Is(err, FinishedCall).
var FinishedCall = errors.New("last tool was called")
// Process executes the tool calls carried by the most recent message in
// req.Chat.
//
// Behavior:
//   - Errors from Chat.GetLatest are returned unchanged.
//   - If the latest message is not a ChatAiMessage, an error is returned.
//   - If it has no tool calls (nil ToolCalls), Process does nothing and
//     returns nil.
//
// Tool calls are handled in order. Each one is dispatched to
// client.ToolHandler and its response is appended to req.Chat; the raw
// arguments of a "reply" call are also stored in client.Reply.
//
// A call to the tool named req.EndToolCall stops processing immediately and
// returns FinishedCall. That call and any later calls are not dispatched.
func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
	latest, err := req.Chat.GetLatest()
	if err != nil {
		return err
	}
	aiMessage, isAI := latest.(ChatAiMessage)
	if !isAI {
		return errors.New("Latest message isnt an AI message")
	}
	// No tool calls is a normal outcome, not an error.
	if aiMessage.ToolCalls == nil {
		return nil
	}
	calls := *aiMessage.ToolCalls
	for i := range calls {
		call := calls[i]
		if call.Function.Name == req.EndToolCall {
			return FinishedCall
		}
		result := client.ToolHandler.Handle(info, call)
		if call.Function.Name == "reply" {
			client.Reply = call.Function.Arguments
		}
		client.Log.Debug("Tool call", "name", call.Function.Name, "arguments", call.Function.Arguments, "response", result.Content)
		req.Chat.AddToolResponse(result)
	}
	return nil
}
// RunAgent runs a complete tool-calling conversation about one image.
//
// It builds a request with these settings:
//   - model "gpt-4.1-mini", temperature 0.3, RandomSeed 42;
//   - tool_choice "auto" and a "text" response format;
//   - tools decoded from client.Options.JsonTools.
//
// It then seeds the chat with Options.SystemPrompt and the image (plus the
// optional Options.Query). The user/image identifiers and image bytes reach
// the tool handlers through ToolHandlerInfo.
//
// It returns the JSON decoding error if Options.JsonTools is not valid JSON
// (including when it is empty). Otherwise it returns whatever ToolLoop
// returns, which includes FinishedCall when the end tool is called.
func (client *AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
	var tools any
	if err := json.Unmarshal([]byte(client.Options.JsonTools), &tools); err != nil {
		// A malformed tool definition is a configuration error the caller can
		// handle; this method already returns error, so don't crash the process.
		return err
	}
	toolChoice := "auto"
	seed := 42
	request := AgentRequestBody{
		Tools:       &tools,
		ToolChoice:  &toolChoice,
		Model:       "gpt-4.1-mini",
		RandomSeed:  &seed,
		Temperature: 0.3,
		EndToolCall: client.Options.EndToolCall,
		ResponseFormat: ResponseFormat{
			Type: "text",
		},
		Chat: &Chat{
			Messages: make([]ChatMessage, 0),
		},
	}
	request.Chat.AddSystem(client.Options.SystemPrompt)
	request.Chat.AddImage(imageName, imageData, client.Options.Query)
	toolHandlerInfo := ToolHandlerInfo{
		ImageId:   imageId,
		ImageName: imageName,
		UserId:    userId,
		Image:     &imageData,
	}
	return client.ToolLoop(toolHandlerInfo, &request)
}