fix(tool-calls): ToolLoop
This commit is contained in:
@ -23,9 +23,32 @@ type AgentRequestBody struct {
|
||||
Tools *any `json:"tools,omitempty"`
|
||||
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||
|
||||
EndToolCall string `json:"-"`
|
||||
|
||||
Chat *Chat `json:"messages"`
|
||||
}
|
||||
|
||||
func (req AgentRequestBody) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&struct {
|
||||
Model string `json:"model"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
ResponseFormat ResponseFormat `json:"response_format"`
|
||||
|
||||
Tools *any `json:"tools,omitempty"`
|
||||
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
}{
|
||||
Model: req.Model,
|
||||
Temperature: req.Temperature,
|
||||
ResponseFormat: req.ResponseFormat,
|
||||
|
||||
Tools: req.Tools,
|
||||
ToolChoice: req.ToolChoice,
|
||||
|
||||
Messages: req.Chat.Messages,
|
||||
})
|
||||
}
|
||||
|
||||
type ResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message ChatAiMessage `json:"message"`
|
||||
@ -123,6 +146,22 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
|
||||
return agentResponse, nil
|
||||
}
|
||||
|
||||
func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||
for {
|
||||
err := client.Process(info, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = client.Request(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var FinishedCall = errors.New("Last tool tool was called")
|
||||
|
||||
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||
var err error
|
||||
|
||||
@ -142,6 +181,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
|
||||
}
|
||||
|
||||
for _, toolCall := range *aiMessage.ToolCalls {
|
||||
if toolCall.Function.Name == req.EndToolCall {
|
||||
return FinishedCall
|
||||
}
|
||||
|
||||
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
||||
|
||||
req.Chat.AddToolResponse(toolResponse)
|
||||
|
@ -158,6 +158,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
ToolChoice: &toolChoice,
|
||||
Model: "pixtral-12b-2409",
|
||||
Temperature: 0.3,
|
||||
EndToolCall: "finish",
|
||||
ResponseFormat: client.ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
@ -179,7 +180,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
return agent.client.Process(toolHandlerInfo, &request)
|
||||
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
||||
}
|
||||
|
||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
||||
|
Reference in New Issue
Block a user