fix(tool-calls): ToolLoop
This commit is contained in:
@ -23,9 +23,32 @@ type AgentRequestBody struct {
|
|||||||
Tools *any `json:"tools,omitempty"`
|
Tools *any `json:"tools,omitempty"`
|
||||||
ToolChoice *string `json:"tool_choice,omitempty"`
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||||
|
|
||||||
|
EndToolCall string `json:"-"`
|
||||||
|
|
||||||
Chat *Chat `json:"messages"`
|
Chat *Chat `json:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (req AgentRequestBody) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(&struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
ResponseFormat ResponseFormat `json:"response_format"`
|
||||||
|
|
||||||
|
Tools *any `json:"tools,omitempty"`
|
||||||
|
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||||
|
Messages []ChatMessage `json:"messages"`
|
||||||
|
}{
|
||||||
|
Model: req.Model,
|
||||||
|
Temperature: req.Temperature,
|
||||||
|
ResponseFormat: req.ResponseFormat,
|
||||||
|
|
||||||
|
Tools: req.Tools,
|
||||||
|
ToolChoice: req.ToolChoice,
|
||||||
|
|
||||||
|
Messages: req.Chat.Messages,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type ResponseChoice struct {
|
type ResponseChoice struct {
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
Message ChatAiMessage `json:"message"`
|
Message ChatAiMessage `json:"message"`
|
||||||
@ -123,6 +146,22 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
|
|||||||
return agentResponse, nil
|
return agentResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||||
|
for {
|
||||||
|
err := client.Process(info, req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Request(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var FinishedCall = errors.New("Last tool tool was called")
|
||||||
|
|
||||||
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@ -142,6 +181,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, toolCall := range *aiMessage.ToolCalls {
|
for _, toolCall := range *aiMessage.ToolCalls {
|
||||||
|
if toolCall.Function.Name == req.EndToolCall {
|
||||||
|
return FinishedCall
|
||||||
|
}
|
||||||
|
|
||||||
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
||||||
|
|
||||||
req.Chat.AddToolResponse(toolResponse)
|
req.Chat.AddToolResponse(toolResponse)
|
||||||
|
@ -158,6 +158,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
|||||||
ToolChoice: &toolChoice,
|
ToolChoice: &toolChoice,
|
||||||
Model: "pixtral-12b-2409",
|
Model: "pixtral-12b-2409",
|
||||||
Temperature: 0.3,
|
Temperature: 0.3,
|
||||||
|
EndToolCall: "finish",
|
||||||
ResponseFormat: client.ResponseFormat{
|
ResponseFormat: client.ResponseFormat{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
},
|
},
|
||||||
@ -179,7 +180,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
|||||||
UserId: userId,
|
UserId: userId,
|
||||||
}
|
}
|
||||||
|
|
||||||
return agent.client.Process(toolHandlerInfo, &request)
|
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
||||||
|
Reference in New Issue
Block a user