test(tools): starting test suite for tools

This commit is contained in:
2025-04-05 14:35:54 +01:00
parent 03e7803467
commit a1ce96d2e3
5 changed files with 65 additions and 23 deletions

View File

@ -266,7 +266,12 @@ func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody
var err error
for {
err = client.ToolHandler.Handle(info, &request)
toolCall, ok := request.Messages[len(request.Messages)-1].(AgentAssistantToolCall)
if !ok {
return errors.New("Latest message isnt a tool call. TODO")
}
_, err = client.ToolHandler.Handle(info, toolCall)
if err != nil {
break
}

View File

@ -3,7 +3,6 @@ package client
import (
"encoding/json"
"errors"
"log"
"github.com/google/uuid"
)
@ -21,37 +20,26 @@ type ToolsHandlers struct {
handlers *map[string]ToolHandler
}
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) error {
agentMessage := request.Messages[len(request.Messages)-1]
toolCall, ok := agentMessage.(AgentAssistantToolCall)
if !ok {
return errors.New("Latest message was not a tool call.")
}
fnName := toolCall.ToolCalls[0].Function.Name
arguments := toolCall.ToolCalls[0].Function.Arguments
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) (AgentTextMessage, error) {
fnName := toolCallMessage.ToolCalls[0].Function.Name
arguments := toolCallMessage.ToolCalls[0].Function.Arguments
fnHandler, exists := (*handler.handlers)[fnName]
if !exists {
return errors.New("Could not find tool with this name.")
return AgentTextMessage{}, errors.New("Could not find tool with this name.")
}
log.Printf("Calling: %s\n", fnName)
res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0])
res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0])
if err != nil {
return err
return AgentTextMessage{}, err
}
log.Println(res)
request.AddText(AgentTextMessage{
return AgentTextMessage{
Role: "tool",
Name: fnName,
Content: res,
ToolCallId: toolCall.ToolCalls[0].Id,
})
return nil
ToolCallId: toolCallMessage.ToolCalls[0].Id,
}, nil
}
func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {

View File

@ -0,0 +1,47 @@
package client
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestSingleToolCall(t *testing.T) {
assert := assert.New(t)
tools := ToolsHandlers{
handlers: &map[string]ToolHandler{},
}
tools.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
return true, nil
})
response, err := tools.Handle(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{{
Index: 0,
Id: "1",
Function: FunctionCall{
Name: "a",
Arguments: "",
},
}},
})
if assert.NoError(err, "Tool call shouldnt return an error") {
assert.EqualValues(response, AgentTextMessage{
Role: "tool",
Content: "true",
ToolCallId: "1",
Name: "a",
})
}
}

View File

@ -10,6 +10,6 @@ require (
github.com/joho/godotenv v1.5.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/stretchr/testify v1.10.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -14,6 +14,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=