diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 8c0d977..0e259ec 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -20,26 +20,38 @@ type ToolsHandlers struct { handlers *map[string]ToolHandler } -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) (AgentTextMessage, error) { - fnName := toolCallMessage.ToolCalls[0].Function.Name - arguments := toolCallMessage.ToolCalls[0].Function.Arguments +var NoToolCallError = errors.New("An assistant tool call with no tool calls was provided.") - fnHandler, exists := (*handler.handlers)[fnName] - if !exists { - return AgentTextMessage{}, errors.New("Could not find tool with this name.") +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) { + if len(toolCallMessage.ToolCalls) == 0 { + return []AgentTextMessage{}, NoToolCallError } - res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) - if err != nil { - return AgentTextMessage{}, err + responses := make([]AgentTextMessage, len(toolCallMessage.ToolCalls)) + + for i, toolCall := range toolCallMessage.ToolCalls { + fnName := toolCall.Function.Name + arguments := toolCall.Function.Arguments + + fnHandler, exists := (*handler.handlers)[fnName] + if !exists { + return []AgentTextMessage{}, errors.New("Could not find tool with this name.") + } + + res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) + if err != nil { + return []AgentTextMessage{}, err + } + + responses[i] = AgentTextMessage{ + Role: "tool", + Name: fnName, + Content: res, + ToolCallId: toolCall.Id, + } } - return AgentTextMessage{ - Role: "tool", - Name: fnName, - Content: res, - ToolCallId: toolCallMessage.ToolCalls[0].Id, - }, nil + return responses, nil } func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index 9b17ce0..35423ff 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -4,21 +4,30 @@ import ( "testing" "github.com/google/uuid" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestSingleToolCall(t *testing.T) { - assert := assert.New(t) +type ToolTestSuite struct { + suite.Suite - tools := ToolsHandlers{ + handler ToolsHandlers +} + +func (suite *ToolTestSuite) SetupTest() { + suite.handler = ToolsHandlers{ handlers: &map[string]ToolHandler{}, } - tools.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { + suite.handler.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { return true, nil }) +} - response, err := tools.Handle( +func (suite *ToolTestSuite) TestSingleToolCall() { + assert := suite.Assert() + require := suite.Require() + + response, err := suite.handler.Handle( ToolHandlerInfo{ UserId: uuid.Nil, ImageId: uuid.Nil, @@ -36,12 +45,83 @@ func TestSingleToolCall(t *testing.T) { }}, }) - if assert.NoError(err, "Tool call shouldnt return an error") { - assert.EqualValues(response, AgentTextMessage{ + require.NoError(err, "Tool call shouldnt return an error") + + assert.EqualValues(response, []AgentTextMessage{{ + Role: "tool", + Content: "true", + ToolCallId: "1", + Name: "a", + }}) +} + +func (suite *ToolTestSuite) TestEmptyCall() { + require := suite.Require() + + _, err := suite.handler.Handle( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{}, + }) + + require.ErrorIs(err, NoToolCallError) +} + +func (suite *ToolTestSuite) TestMultipleToolCalls() { + assert := suite.Assert() + require := suite.Require() + + response, err := suite.handler.Handle( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "a", + Arguments: "", + }, + }, + { + Index: 1, + Id: "2", + Function: FunctionCall{ + Name: "a", + Arguments: "", + }, + }, + }, + }) + + require.NoError(err, "Tool call shouldnt return an error") + + assert.EqualValues(response, []AgentTextMessage{ + { Role: "tool", Content: "true", ToolCallId: "1", Name: "a", - }) - } + }, + { + Role: "tool", + Content: "true", + ToolCallId: "2", + Name: "a", + }, + }) +} + +func TestToolSuite(t *testing.T) { + suite.Run(t, &ToolTestSuite{}) }