192 lines
3.5 KiB
Go
192 lines
3.5 KiB
Go
package client
|
|
|
|
import (
|
|
"errors"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/charmbracelet/log"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/suite"
|
|
)
|
|
|
|
type ToolTestSuite struct {
|
|
suite.Suite
|
|
|
|
handler ToolsHandlers
|
|
client AgentClient
|
|
}
|
|
|
|
// SetupTest runs before each test in the suite. It builds a fresh tool
// registry containing two fake tools and installs it, together with a
// stdout logger, on the suite's client:
//   - "a" returns its raw argument string unchanged;
//   - "error" always fails with a fixed error.
func (suite *ToolTestSuite) SetupTest() {
	registry := ToolsHandlers{handlers: make(map[string]ToolHandler)}

	// Echo tool: hands the raw argument string straight back as its result.
	registry.AddTool("a", func(_ ToolHandlerInfo, args string, _ ToolCall) (any, error) {
		return args, nil
	})

	// Failing tool: lets tests observe how handler errors are reported.
	registry.AddTool("error", func(_ ToolHandlerInfo, _ string, _ ToolCall) (any, error) {
		return false, errors.New("I will always error")
	})

	suite.handler = registry

	suite.client.Log = log.New(os.Stdout)
	// The client receives a copy of the registry value; the handlers map
	// inside it is shared with suite.handler.
	suite.client.ToolHandler = suite.handler
}
|
|
|
|
func (suite *ToolTestSuite) TestSingleToolCall() {
|
|
require := suite.Require()
|
|
|
|
response := suite.handler.Handle(
|
|
ToolHandlerInfo{
|
|
UserId: uuid.Nil,
|
|
ImageId: uuid.Nil,
|
|
},
|
|
ToolCall{
|
|
Index: 0,
|
|
Id: "1",
|
|
Function: FunctionCall{
|
|
Name: "a",
|
|
Arguments: "return",
|
|
},
|
|
})
|
|
|
|
require.EqualValues(response, ChatUserToolResponse{
|
|
Role: "tool",
|
|
Content: "\"return\"",
|
|
ToolCallId: "1",
|
|
Name: "a",
|
|
})
|
|
}
|
|
|
|
func (suite *ToolTestSuite) TestMultipleToolCalls() {
|
|
assert := suite.Assert()
|
|
require := suite.Require()
|
|
|
|
chat := Chat{
|
|
Messages: []ChatMessage{ChatAiMessage{
|
|
Role: "assistant",
|
|
Content: "",
|
|
ToolCalls: &[]ToolCall{
|
|
{
|
|
Index: 0,
|
|
Id: "1",
|
|
Function: FunctionCall{
|
|
Name: "a",
|
|
Arguments: "first-call",
|
|
},
|
|
},
|
|
{
|
|
Index: 1,
|
|
Id: "2",
|
|
Function: FunctionCall{
|
|
Name: "a",
|
|
Arguments: "second-call",
|
|
},
|
|
},
|
|
},
|
|
}},
|
|
}
|
|
|
|
err := suite.client.Process(
|
|
ToolHandlerInfo{
|
|
UserId: uuid.Nil,
|
|
ImageId: uuid.Nil,
|
|
},
|
|
&AgentRequestBody{
|
|
Chat: &chat,
|
|
})
|
|
|
|
require.NoError(err, "Tool call shouldnt return an error")
|
|
assert.EqualValues(chat.Messages[1:], []ChatMessage{
|
|
ChatUserToolResponse{
|
|
Role: "tool",
|
|
Content: "\"first-call\"",
|
|
ToolCallId: "1",
|
|
Name: "a",
|
|
},
|
|
ChatUserToolResponse{
|
|
Role: "tool",
|
|
Content: "\"second-call\"",
|
|
ToolCallId: "2",
|
|
Name: "a",
|
|
},
|
|
})
|
|
}
|
|
|
|
func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
|
|
assert := suite.Assert()
|
|
require := suite.Require()
|
|
|
|
chat := Chat{
|
|
Messages: []ChatMessage{ChatAiMessage{
|
|
Role: "assistant",
|
|
Content: "",
|
|
ToolCalls: &[]ToolCall{
|
|
{
|
|
Index: 0,
|
|
Id: "1",
|
|
Function: FunctionCall{
|
|
Name: "error",
|
|
Arguments: "",
|
|
},
|
|
},
|
|
{
|
|
Index: 1,
|
|
Id: "2",
|
|
Function: FunctionCall{
|
|
Name: "non-existant",
|
|
Arguments: "",
|
|
},
|
|
},
|
|
{
|
|
Index: 2,
|
|
Id: "3",
|
|
Function: FunctionCall{
|
|
Name: "a",
|
|
Arguments: "no-error",
|
|
},
|
|
},
|
|
},
|
|
}},
|
|
}
|
|
|
|
err := suite.client.Process(
|
|
ToolHandlerInfo{
|
|
UserId: uuid.Nil,
|
|
ImageId: uuid.Nil,
|
|
},
|
|
&AgentRequestBody{
|
|
Chat: &chat,
|
|
})
|
|
|
|
require.NoError(err, "Tool call shouldnt return an error")
|
|
|
|
assert.EqualValues(chat.Messages[1:], []ChatMessage{
|
|
ChatUserToolResponse{
|
|
Role: "tool",
|
|
Content: "I will always error",
|
|
ToolCallId: "1",
|
|
Name: "error",
|
|
},
|
|
ChatUserToolResponse{
|
|
Role: "tool",
|
|
Content: "This tool does not exist",
|
|
ToolCallId: "2",
|
|
Name: "non-existant",
|
|
},
|
|
ChatUserToolResponse{
|
|
Role: "tool",
|
|
Content: "\"no-error\"",
|
|
ToolCallId: "3",
|
|
Name: "a",
|
|
},
|
|
})
|
|
}
|
|
|
|
// TestToolSuite is the `go test` entry point that runs every Test* method of
// ToolTestSuite. The suite starts with a zero-value AgentClient, which
// SetupTest configures before each test.
func TestToolSuite(t *testing.T) {
	suite.Run(t, new(ToolTestSuite))
}
|