I guess some repeated code doesn't hurt anyone if it keeps things simpler. Trying to be fancy with the interfaces didn't work so well.
76 lines
1.6 KiB
Go
76 lines
1.6 KiB
Go
package client
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"log"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type ToolHandlerInfo struct {
|
|
UserId uuid.UUID
|
|
ImageId uuid.UUID
|
|
}
|
|
|
|
type ToolHandler struct {
|
|
Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error)
|
|
}
|
|
|
|
type ToolsHandlers struct {
|
|
handlers *map[string]ToolHandler
|
|
}
|
|
|
|
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) error {
|
|
agentMessage := request.Messages[len(request.Messages)-1]
|
|
|
|
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
|
if !ok {
|
|
return errors.New("Latest message was not a tool call.")
|
|
}
|
|
|
|
fnName := toolCall.ToolCalls[0].Function.Name
|
|
arguments := toolCall.ToolCalls[0].Function.Arguments
|
|
|
|
log.Println(handler.handlers)
|
|
|
|
fnHandler, exists := (*handler.handlers)[fnName]
|
|
if !exists {
|
|
return errors.New("Could not find tool with this name.")
|
|
}
|
|
|
|
log.Printf("Calling: %s\n", fnName)
|
|
res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Println(res)
|
|
|
|
request.AddText(AgentTextMessage{
|
|
Role: "tool",
|
|
Name: fnName,
|
|
Content: res,
|
|
ToolCallId: toolCall.ToolCalls[0].Id,
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
|
|
(*handler.handlers)[name] = ToolHandler{
|
|
Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) {
|
|
res, err := fn(info, args, call)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
marshalledRes, err := json.Marshal(res)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return string(marshalledRes), nil
|
|
},
|
|
}
|
|
}
|