From a1ce96d2e324a537a1373b5aa6f29d6bf2c43537 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 5 Apr 2025 14:35:54 +0100 Subject: [PATCH] test(tools): starting test suite for tools --- backend/agents/client/client.go | 7 ++++- backend/agents/client/tools.go | 30 ++++++------------ backend/agents/client/tools_test.go | 47 +++++++++++++++++++++++++++++ backend/go.mod | 2 +- backend/go.sum | 2 ++ 5 files changed, 65 insertions(+), 23 deletions(-) create mode 100644 backend/agents/client/tools_test.go diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 464dba2..a49c86d 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -266,7 +266,12 @@ func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody var err error for { - err = client.ToolHandler.Handle(info, &request) + toolCall, ok := request.Messages[len(request.Messages)-1].(AgentAssistantToolCall) + if !ok { + return errors.New("Latest message isnt a tool call. TODO") + } + + _, err = client.ToolHandler.Handle(info, toolCall) if err != nil { break } diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 38dbb1f..8c0d977 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -3,7 +3,6 @@ package client import ( "encoding/json" "errors" - "log" "github.com/google/uuid" ) @@ -21,37 +20,26 @@ type ToolsHandlers struct { handlers *map[string]ToolHandler } -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) error { - agentMessage := request.Messages[len(request.Messages)-1] - - toolCall, ok := agentMessage.(AgentAssistantToolCall) - if !ok { - return errors.New("Latest message was not a tool call.") - } - - fnName := toolCall.ToolCalls[0].Function.Name - arguments := toolCall.ToolCalls[0].Function.Arguments +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) (AgentTextMessage, error) { + fnName := toolCallMessage.ToolCalls[0].Function.Name + arguments := toolCallMessage.ToolCalls[0].Function.Arguments fnHandler, exists := (*handler.handlers)[fnName] if !exists { - return errors.New("Could not find tool with this name.") + return AgentTextMessage{}, errors.New("Could not find tool with this name.") } - log.Printf("Calling: %s\n", fnName) - res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0]) + res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) if err != nil { - return err + return AgentTextMessage{}, err } - log.Println(res) - request.AddText(AgentTextMessage{ + return AgentTextMessage{ Role: "tool", Name: fnName, Content: res, - ToolCallId: toolCall.ToolCalls[0].Id, - }) - - return nil + ToolCallId: toolCallMessage.ToolCalls[0].Id, + }, nil } func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go new file mode 100644 index 0000000..9b17ce0 --- /dev/null +++ b/backend/agents/client/tools_test.go @@ -0,0 +1,47 @@ +package client + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestSingleToolCall(t *testing.T) { + assert := assert.New(t) + + tools := ToolsHandlers{ + handlers: &map[string]ToolHandler{}, + } + + tools.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { + return true, nil + }) + + response, err := tools.Handle( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{{ + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "a", + Arguments: "", + }, + }}, + }) + + if assert.NoError(err, "Tool call shouldnt return an error") { + assert.EqualValues(response, AgentTextMessage{ + Role: "tool", + Content: "true", + ToolCallId: "1", + Name: "a", + }) + } +} diff --git a/backend/go.mod b/backend/go.mod index fcdc38a..df64294 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -10,6 +10,6 @@ require ( github.com/joho/godotenv v1.5.1 // indirect github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.9.0 // indirect + github.com/stretchr/testify v1.10.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index ecd9db0..b982fe0 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -14,6 +14,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=