feat(tool-calling) Big refactor on how tool calling is handled

these commits are too big
This commit is contained in:
2025-03-22 20:46:26 +00:00
parent 7debe6bab2
commit 6f938a34e3
8 changed files with 294 additions and 147 deletions

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"log"
"reflect"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/models"
@@ -21,6 +22,10 @@ Only create an event if you see an event on the image, not all locations have an
It is possible that there is no location or event on an image.
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
Do not create an event if you don't see any dates, or a name indicating an event.
Always reuse existing locations from listLocations . Do not create duplicates.
`
// TODO: this should be read directly from a file on load.
@@ -30,7 +35,7 @@ const TOOLS = `
"type": "function",
"function": {
"name": "createLocation",
"description": "Creates a location",
"description": "Creates a location. No not use if you think an existing location is suitable!",
"parameters": {
"type": "object",
"properties": {
@@ -63,12 +68,13 @@ const TOOLS = `
"type": "function",
"function": {
"name": "attachImageLocation",
"description": "Add a location to an image",
"description": "Add a location to an image. You must use UUID.",
"parameters": {
"type": "object",
"properties": {
"locationId": {
"type": "string"
"type": "string",
"description": "UUID of an existing location, you can use listLocations to get values, or use the return value of createLocation"
}
},
"required": ["locationId"]
@@ -117,8 +123,12 @@ type EventLocationAgent struct {
eventModel models.EventModel
locationModel models.LocationModel
toolHandler ToolsHandlers
}
type ListLocationArguments struct{}
type CreateLocationArguments struct {
Name string `json:"name"`
Address *string `json:"address,omitempty"`
@@ -129,97 +139,6 @@ type AttachImageLocationArguments struct {
LocationId string `json:"locationId"`
}
func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, imageId uuid.UUID, request *AgentRequestBody) error {
latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]
toolCall, ok := latestMessage.(AgentAssistantToolCall)
if !ok {
return errors.New("Latest message is not a tool_call")
}
for _, call := range toolCall.ToolCalls {
if call.Function.Name == "listLocations" {
log.Println("Function call: listLocations")
locations, err := agent.locationModel.List(context.Background(), userId)
if err != nil {
return err
}
jsonLocations, err := json.Marshal(locations)
if err != nil {
return err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "listLocations",
Content: string(jsonLocations),
ToolCallId: call.Id,
})
} else if call.Function.Name == "createLocation" {
log.Println("Function call: createLocation")
locationArguments := CreateLocationArguments{}
err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments)
if err != nil {
return err
}
locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{
Name: locationArguments.Name,
Address: locationArguments.Address,
Coordinates: locationArguments.Coordinates,
}})
if err != nil {
return err
}
createdLocation := locations[0]
jsonCreatedLocation, err := json.Marshal(createdLocation)
if err != nil {
return err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "createLocation",
Content: string(jsonCreatedLocation),
ToolCallId: call.Id,
})
} else if call.Function.Name == "attachImageLocation" {
log.Println("Function call: attachImageLocation")
attachLocationArguments := AttachImageLocationArguments{}
err := json.Unmarshal([]byte(call.Function.Arguments), &attachLocationArguments)
if err != nil {
return err
}
_, err = agent.locationModel.SaveToImage(context.Background(), imageId, uuid.MustParse(attachLocationArguments.LocationId))
if err != nil {
return err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "attachImageLocation",
Content: "OK",
ToolCallId: call.Id,
})
} else if call.Function.Name == "finish" {
log.Println("Finished!")
return errors.New("hmmm, this isnt actually an error but hey")
} else {
return errors.New("Unknown tool_call")
}
}
return nil
}
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
var tools any
err := json.Unmarshal([]byte(TOOLS), &tools)
@@ -241,8 +160,6 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
return err
}
log.Println(request)
request.AddImage(imageName, imageData)
_, err = agent.client.Request(&request)
@@ -250,7 +167,13 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
return err
}
err = agent.HandleToolCall(userId, imageId, &request)
toolHandlerInfo := ToolHandlerInfo{
imageId: imageId,
userId: userId,
}
_, err = agent.toolHandler.Handle(toolHandlerInfo, &request)
for err == nil {
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
@@ -261,21 +184,135 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
log.Println(response)
err = agent.HandleToolCall(userId, imageId, &request)
a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request)
err = innerErr
log.Println(a)
log.Println("--------------------------")
}
return err
}
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
agentMessage := request.Messages[len(request.Messages)-1]
toolCall, ok := agentMessage.(AgentAssistantToolCall)
if !ok {
return "", errors.New("Latest message was not a tool call.")
}
fnName := toolCall.ToolCalls[0].Function.Name
arguments := toolCall.ToolCalls[0].Function.Arguments
if fnName == "finish" {
return "", errors.New("This is the end! Maybe we just return a boolean.")
}
fn, exists := handler.Handlers[fnName]
if !exists {
return "", errors.New("Could not find tool with this name.")
}
// holy jesus what the fuck.
parseMethod := reflect.ValueOf(fn).Field(1)
if !parseMethod.IsValid() {
return "", errors.New("Parse method not found")
}
parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)})
if !parsedArgs[1].IsNil() {
return "", parsedArgs[1].Interface().(error)
}
log.Printf("Calling: %s\n", fnName)
fnMethod := reflect.ValueOf(fn).Field(2)
if !fnMethod.IsValid() {
return "", errors.New("Fn method not found")
}
response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])})
if !response[1].IsNil() {
return "", response[1].Interface().(error)
}
stringResponse, err := json.Marshal(response[0].Interface())
if err != nil {
return "", err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "createLocation",
Content: string(stringResponse),
ToolCallId: toolCall.ToolCalls[0].Id,
})
return string(stringResponse), nil
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
client, err := CreateAgentClient(eventLocationPrompt)
if err != nil {
return EventLocationAgent{}, err
}
return EventLocationAgent{
agent := EventLocationAgent{
client: client,
locationModel: locationModel,
eventModel: eventModel,
}, nil
}
toolHandler := ToolsHandlers{
Handlers: make(map[string]ToolHandlerInterface),
}
toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
FunctionName: "listLocations",
Parse: func(stringArgs string) (ListLocationArguments, error) {
args := ListLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
},
Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) {
return agent.locationModel.List(context.Background(), info.userId)
},
}
toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
FunctionName: "createLocation",
Parse: func(stringArgs string) (CreateLocationArguments, error) {
args := CreateLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
},
Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) {
return agent.locationModel.Save(context.Background(), info.userId, model.Locations{
Name: args.Name,
Address: args.Address,
Coordinates: args.Coordinates,
})
},
}
toolHandler.Handlers["attachImageLocation"] = ToolHandler[AttachImageLocationArguments, model.ImageLocations]{
FunctionName: "attachImageLocation",
Parse: func(stringArgs string) (AttachImageLocationArguments, error) {
args := AttachImageLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
},
Fn: func(info ToolHandlerInfo, args AttachImageLocationArguments, call ToolCall) (model.ImageLocations, error) {
return agent.locationModel.SaveToImage(context.Background(), info.imageId, uuid.MustParse(args.LocationId))
},
}
agent.toolHandler = toolHandler
return agent, nil
}

View File

@@ -0,0 +1,26 @@
package agents
import "github.com/google/uuid"
type ToolHandlerInfo struct {
userId uuid.UUID
imageId uuid.UUID
}
type ToolHandler[TArgs any, TResp any] struct {
FunctionName string
Parse func(args string) (TArgs, error)
Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error)
}
type ToolHandlerInterface interface {
GetFunctionName() string
}
func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string {
return handler.FunctionName
}
type ToolsHandlers struct {
Handlers map[string]ToolHandlerInterface
}