diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index d1eff3b..18af351 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -23,9 +23,32 @@ type AgentRequestBody struct { Tools *any `json:"tools,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"` + EndToolCall string `json:"-"` + Chat *Chat `json:"messages"` } +func (req AgentRequestBody) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + Model string `json:"model"` + Temperature float64 `json:"temperature"` + ResponseFormat ResponseFormat `json:"response_format"` + + Tools *any `json:"tools,omitempty"` + ToolChoice *string `json:"tool_choice,omitempty"` + Messages []ChatMessage `json:"messages"` + }{ + Model: req.Model, + Temperature: req.Temperature, + ResponseFormat: req.ResponseFormat, + + Tools: req.Tools, + ToolChoice: req.ToolChoice, + + Messages: req.Chat.Messages, + }) +} + type ResponseChoice struct { Index int `json:"index"` Message ChatAiMessage `json:"message"` @@ -123,6 +146,22 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) return agentResponse, nil } +func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error { + for { + err := client.Process(info, req) + if err != nil { + return err + } + + _, err = client.Request(req) + if err != nil { + return err + } + } +} + +var FinishedCall = errors.New("Last tool tool was called") + func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error { var err error @@ -142,6 +181,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e } for _, toolCall := range *aiMessage.ToolCalls { + if toolCall.Function.Name == req.EndToolCall { + return FinishedCall + } + toolResponse := client.ToolHandler.Handle(info, toolCall) req.Chat.AddToolResponse(toolResponse) diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 53bc017..22672f6 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -158,6 +158,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID ToolChoice: &toolChoice, Model: "pixtral-12b-2409", Temperature: 0.3, + EndToolCall: "finish", ResponseFormat: client.ResponseFormat{ Type: "text", }, @@ -179,7 +180,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID UserId: userId, } - return agent.client.Process(toolHandlerInfo, &request) + return agent.client.ToolLoop(toolHandlerInfo, &request) } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {