82 lines
1.8 KiB
Go
82 lines
1.8 KiB
Go
package client
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type ToolHandlerInfo struct {
|
|
UserId uuid.UUID
|
|
ImageId uuid.UUID
|
|
}
|
|
|
|
type ToolHandler struct {
|
|
Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error)
|
|
}
|
|
|
|
type ToolsHandlers struct {
|
|
handlers map[string]ToolHandler
|
|
}
|
|
|
|
// ErrNoToolCall is the sentinel error ToolsHandlers.Handle returns when the
// assistant tool-call message it receives contains no tool calls. Test for
// it with errors.Is.
//
// The message text is kept verbatim for compatibility with existing output.
var ErrNoToolCall = errors.New("An assistant tool call with no tool calls was provided.")

// NoToolCallError is the original name of ErrNoToolCall. It holds the same
// error value, so comparisons against either name behave identically.
//
// Deprecated: use ErrNoToolCall, which follows the Go ErrXxx naming convention.
var NoToolCallError = ErrNoToolCall
|
|
|
|
// NonExistentTool is the content ToolsHandlers.Handle puts in the response
// message for a tool call whose function name has no registered handler.
const NonExistentTool = "This tool does not exist"

// NonExistantTool is the original, misspelled name of NonExistentTool and
// has the same value.
//
// Deprecated: use NonExistentTool.
const NonExistantTool = NonExistentTool
|
|
|
|
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) {
|
|
if len(toolCallMessage.ToolCalls) == 0 {
|
|
return []AgentTextMessage{}, NoToolCallError
|
|
}
|
|
|
|
responses := make([]AgentTextMessage, len(toolCallMessage.ToolCalls))
|
|
|
|
for i, toolCall := range toolCallMessage.ToolCalls {
|
|
fnName := toolCall.Function.Name
|
|
arguments := toolCall.Function.Arguments
|
|
|
|
responseMessage := AgentTextMessage{
|
|
Role: "tool",
|
|
Name: fnName,
|
|
ToolCallId: toolCall.Id,
|
|
}
|
|
|
|
fnHandler, exists := handler.handlers[fnName]
|
|
if !exists {
|
|
responseMessage.Content = NonExistantTool
|
|
responses[i] = responseMessage
|
|
continue
|
|
}
|
|
|
|
res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0])
|
|
|
|
if err != nil {
|
|
responseMessage.Content = err.Error()
|
|
} else {
|
|
responseMessage.Content = res
|
|
}
|
|
|
|
responses[i] = responseMessage
|
|
}
|
|
|
|
return responses, nil
|
|
}
|
|
|
|
func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
|
|
handler.handlers[name] = ToolHandler{
|
|
Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
|
|
res, err := fn(info, args, call)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
marshalledRes, err := json.Marshal(res)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return string(marshalledRes), nil
|
|
},
|
|
}
|
|
}
|