fix(tool-calls): ToolLoop

This commit is contained in:
2025-04-09 15:15:31 +01:00
parent f169fd2ba2
commit 1a503c8320
2 changed files with 45 additions and 1 deletions

View File

@ -23,9 +23,32 @@ type AgentRequestBody struct {
Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
EndToolCall string `json:"-"`
Chat *Chat `json:"messages"`
}
func (req AgentRequestBody) MarshalJSON() ([]byte, error) {
return json.Marshal(&struct {
Model string `json:"model"`
Temperature float64 `json:"temperature"`
ResponseFormat ResponseFormat `json:"response_format"`
Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
Messages []ChatMessage `json:"messages"`
}{
Model: req.Model,
Temperature: req.Temperature,
ResponseFormat: req.ResponseFormat,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Messages: req.Chat.Messages,
})
}
type ResponseChoice struct {
Index int `json:"index"`
Message ChatAiMessage `json:"message"`
@ -123,6 +146,22 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
return agentResponse, nil
}
func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
for {
err := client.Process(info, req)
if err != nil {
return err
}
_, err = client.Request(req)
if err != nil {
return err
}
}
}
var FinishedCall = errors.New("Last tool tool was called")
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
var err error
@ -142,6 +181,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
}
for _, toolCall := range *aiMessage.ToolCalls {
if toolCall.Function.Name == req.EndToolCall {
return FinishedCall
}
toolResponse := client.ToolHandler.Handle(info, toolCall)
req.Chat.AddToolResponse(toolResponse)

View File

@ -158,6 +158,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
ToolChoice: &toolChoice,
Model: "pixtral-12b-2409",
Temperature: 0.3,
EndToolCall: "finish",
ResponseFormat: client.ResponseFormat{
Type: "text",
},
@ -179,7 +180,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
UserId: userId,
}
return agent.client.Process(toolHandlerInfo, &request)
return agent.client.ToolLoop(toolHandlerInfo, &request)
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {