test(tools): more robust multiple tool call handling
This commit is contained in:
@ -20,26 +20,38 @@ type ToolsHandlers struct {
|
|||||||
handlers *map[string]ToolHandler
|
handlers *map[string]ToolHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) (AgentTextMessage, error) {
|
var NoToolCallError = errors.New("An assistant tool call with no tool calls was provided.")
|
||||||
fnName := toolCallMessage.ToolCalls[0].Function.Name
|
|
||||||
arguments := toolCallMessage.ToolCalls[0].Function.Arguments
|
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) {
|
||||||
|
if len(toolCallMessage.ToolCalls) == 0 {
|
||||||
|
return []AgentTextMessage{}, NoToolCallError
|
||||||
|
}
|
||||||
|
|
||||||
|
responses := make([]AgentTextMessage, len(toolCallMessage.ToolCalls))
|
||||||
|
|
||||||
|
for i, toolCall := range toolCallMessage.ToolCalls {
|
||||||
|
fnName := toolCall.Function.Name
|
||||||
|
arguments := toolCall.Function.Arguments
|
||||||
|
|
||||||
fnHandler, exists := (*handler.handlers)[fnName]
|
fnHandler, exists := (*handler.handlers)[fnName]
|
||||||
if !exists {
|
if !exists {
|
||||||
return AgentTextMessage{}, errors.New("Could not find tool with this name.")
|
return []AgentTextMessage{}, errors.New("Could not find tool with this name.")
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0])
|
res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AgentTextMessage{}, err
|
return []AgentTextMessage{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return AgentTextMessage{
|
responses[i] = AgentTextMessage{
|
||||||
Role: "tool",
|
Role: "tool",
|
||||||
Name: fnName,
|
Name: fnName,
|
||||||
Content: res,
|
Content: res,
|
||||||
ToolCallId: toolCallMessage.ToolCalls[0].Id,
|
ToolCallId: toolCall.Id,
|
||||||
}, nil
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return responses, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
|
func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
|
||||||
|
@ -4,21 +4,30 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/suite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSingleToolCall(t *testing.T) {
|
type ToolTestSuite struct {
|
||||||
assert := assert.New(t)
|
suite.Suite
|
||||||
|
|
||||||
tools := ToolsHandlers{
|
handler ToolsHandlers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *ToolTestSuite) SetupTest() {
|
||||||
|
suite.handler = ToolsHandlers{
|
||||||
handlers: &map[string]ToolHandler{},
|
handlers: &map[string]ToolHandler{},
|
||||||
}
|
}
|
||||||
|
|
||||||
tools.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
|
suite.handler.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
response, err := tools.Handle(
|
func (suite *ToolTestSuite) TestSingleToolCall() {
|
||||||
|
assert := suite.Assert()
|
||||||
|
require := suite.Require()
|
||||||
|
|
||||||
|
response, err := suite.handler.Handle(
|
||||||
ToolHandlerInfo{
|
ToolHandlerInfo{
|
||||||
UserId: uuid.Nil,
|
UserId: uuid.Nil,
|
||||||
ImageId: uuid.Nil,
|
ImageId: uuid.Nil,
|
||||||
@ -36,12 +45,83 @@ func TestSingleToolCall(t *testing.T) {
|
|||||||
}},
|
}},
|
||||||
})
|
})
|
||||||
|
|
||||||
if assert.NoError(err, "Tool call shouldnt return an error") {
|
require.NoError(err, "Tool call shouldnt return an error")
|
||||||
assert.EqualValues(response, AgentTextMessage{
|
|
||||||
|
assert.EqualValues(response, []AgentTextMessage{{
|
||||||
Role: "tool",
|
Role: "tool",
|
||||||
Content: "true",
|
Content: "true",
|
||||||
ToolCallId: "1",
|
ToolCallId: "1",
|
||||||
Name: "a",
|
Name: "a",
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *ToolTestSuite) TestEmptyCall() {
|
||||||
|
require := suite.Require()
|
||||||
|
|
||||||
|
_, err := suite.handler.Handle(
|
||||||
|
ToolHandlerInfo{
|
||||||
|
UserId: uuid.Nil,
|
||||||
|
ImageId: uuid.Nil,
|
||||||
|
},
|
||||||
|
AgentAssistantToolCall{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: []ToolCall{},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.ErrorIs(err, NoToolCallError)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *ToolTestSuite) TestMultipleToolCalls() {
|
||||||
|
assert := suite.Assert()
|
||||||
|
require := suite.Require()
|
||||||
|
|
||||||
|
response, err := suite.handler.Handle(
|
||||||
|
ToolHandlerInfo{
|
||||||
|
UserId: uuid.Nil,
|
||||||
|
ImageId: uuid.Nil,
|
||||||
|
},
|
||||||
|
AgentAssistantToolCall{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: []ToolCall{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Id: "1",
|
||||||
|
Function: FunctionCall{
|
||||||
|
Name: "a",
|
||||||
|
Arguments: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Index: 1,
|
||||||
|
Id: "2",
|
||||||
|
Function: FunctionCall{
|
||||||
|
Name: "a",
|
||||||
|
Arguments: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(err, "Tool call shouldnt return an error")
|
||||||
|
|
||||||
|
assert.EqualValues(response, []AgentTextMessage{
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "true",
|
||||||
|
ToolCallId: "1",
|
||||||
|
Name: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "true",
|
||||||
|
ToolCallId: "2",
|
||||||
|
Name: "a",
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToolSuite(t *testing.T) {
|
||||||
|
suite.Run(t, &ToolTestSuite{})
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user