Haystack/backend/agents/client/tools_test.go
2025-04-12 14:43:01 +01:00

192 lines
3.5 KiB
Go

package client
import (
"errors"
"os"
"testing"
"github.com/charmbracelet/log"
"github.com/google/uuid"
"github.com/stretchr/testify/suite"
)
// ToolTestSuite groups the tool-handling tests. It exercises ToolsHandlers
// directly (Handle) and indirectly through AgentClient.Process.
type ToolTestSuite struct {
	suite.Suite

	// client is the agent under test; SetupTest installs handler on it as
	// its ToolHandler and gives it a stdout logger.
	client AgentClient
	// handler is the tool registry rebuilt by SetupTest before every test.
	handler ToolsHandlers
}
// SetupTest runs before each test. It builds a fresh handler registry with two
// tools and wires it, together with a stdout logger, into the suite's client:
//   - "a" returns its raw argument string unchanged as the tool result.
//   - "error" always fails with the error "I will always error".
func (suite *ToolTestSuite) SetupTest() {
	echo := func(_ ToolHandlerInfo, args string, _ ToolCall) (any, error) {
		return args, nil
	}
	alwaysFail := func(_ ToolHandlerInfo, _ string, _ ToolCall) (any, error) {
		return false, errors.New("I will always error")
	}

	suite.handler = ToolsHandlers{handlers: make(map[string]ToolHandler)}
	suite.handler.AddTool("a", echo)
	suite.handler.AddTool("error", alwaysFail)

	suite.client.Log = log.New(os.Stdout)
	// The client receives a copy of the handler value; the tools map inside
	// it is shared, so the registrations above are visible to the client.
	suite.client.ToolHandler = suite.handler
}
// TestSingleToolCall calls Handle directly with one call to the echo tool "a".
// It checks that the result comes back as a "tool" role response carrying the
// call's ID and tool name, with the echoed argument quoted in Content.
func (suite *ToolTestSuite) TestSingleToolCall() {
	require := suite.Require()

	response := suite.handler.Handle(
		ToolHandlerInfo{
			UserId:  uuid.Nil,
			ImageId: uuid.Nil,
		},
		ToolCall{
			Index: 0,
			Id:    "1",
			Function: FunctionCall{
				Name:      "a",
				Arguments: "return",
			},
		})

	// testify's argument order is (expected, actual); keeping it correct makes
	// the "expected"/"actual" labels in failure output accurate.
	require.EqualValues(ChatUserToolResponse{
		Role:       "tool",
		Content:    "\"return\"",
		ToolCallId: "1",
		Name:       "a",
	}, response)
}
// TestMultipleToolCalls passes an assistant message with two calls to the echo
// tool "a" through Process. It checks that Process appends one "tool" response
// per call, in call order, after the original assistant message.
func (suite *ToolTestSuite) TestMultipleToolCalls() {
	assert := suite.Assert()
	require := suite.Require()

	chat := Chat{
		Messages: []ChatMessage{ChatAiMessage{
			Role:    "assistant",
			Content: "",
			ToolCalls: &[]ToolCall{
				{
					Index: 0,
					Id:    "1",
					Function: FunctionCall{
						Name:      "a",
						Arguments: "first-call",
					},
				},
				{
					Index: 1,
					Id:    "2",
					Function: FunctionCall{
						Name:      "a",
						Arguments: "second-call",
					},
				},
			},
		}},
	}

	err := suite.client.Process(
		ToolHandlerInfo{
			UserId:  uuid.Nil,
			ImageId: uuid.Nil,
		},
		&AgentRequestBody{
			Chat: &chat,
		})
	require.NoError(err, "Tool call shouldnt return an error")

	// Check the count first so a missing or extra response fails with a
	// clear message instead of a slice diff.
	require.Len(chat.Messages, 3, "expected the assistant message plus one response per tool call")

	// testify's argument order is (expected, actual).
	assert.EqualValues([]ChatMessage{
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "\"first-call\"",
			ToolCallId: "1",
			Name:       "a",
		},
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "\"second-call\"",
			ToolCallId: "2",
			Name:       "a",
		},
	}, chat.Messages[1:])
}
// TestMultipleToolCallsWithErrors passes three calls through Process: one to a
// tool that always errors, one to an unregistered tool name, and one to the
// echo tool "a". It checks that Process itself still returns no error and that
// each failure is reported as the Content of that call's "tool" response.
func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
	assert := suite.Assert()
	require := suite.Require()

	chat := Chat{
		Messages: []ChatMessage{ChatAiMessage{
			Role:    "assistant",
			Content: "",
			ToolCalls: &[]ToolCall{
				{
					Index: 0,
					Id:    "1",
					Function: FunctionCall{
						Name:      "error",
						Arguments: "",
					},
				},
				{
					Index: 1,
					Id:    "2",
					Function: FunctionCall{
						// Deliberately not registered in SetupTest.
						Name:      "non-existant",
						Arguments: "",
					},
				},
				{
					Index: 2,
					Id:    "3",
					Function: FunctionCall{
						Name:      "a",
						Arguments: "no-error",
					},
				},
			},
		}},
	}

	err := suite.client.Process(
		ToolHandlerInfo{
			UserId:  uuid.Nil,
			ImageId: uuid.Nil,
		},
		&AgentRequestBody{
			Chat: &chat,
		})
	require.NoError(err, "Tool call shouldnt return an error")

	// Check the count first so a missing or extra response fails with a
	// clear message instead of a slice diff.
	require.Len(chat.Messages, 4, "expected the assistant message plus one response per tool call")

	// testify's argument order is (expected, actual).
	assert.EqualValues([]ChatMessage{
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "I will always error",
			ToolCallId: "1",
			Name:       "error",
		},
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "This tool does not exist",
			ToolCallId: "2",
			Name:       "non-existant",
		},
		ChatUserToolResponse{
			Role:       "tool",
			Content:    "\"no-error\"",
			ToolCallId: "3",
			Name:       "a",
		},
	}, chat.Messages[1:])
}
func TestToolSuite(t *testing.T) {
suite.Run(t, &ToolTestSuite{
client: AgentClient{},
})
}