feat(tool-calling): big refactor of how tool calling is handled
these commits are too big
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"reflect"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
"screenmark/screenmark/models"
|
||||
|
||||
@@ -21,6 +22,10 @@ Only create an event if you see an event on the image, not all locations have an
|
||||
It is possible that there is no location or event on an image.
|
||||
|
||||
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
|
||||
|
||||
Do not create an event if you don't see any dates, or a name indicating an event.
|
||||
|
||||
Always reuse existing locations from listLocations. Do not create duplicates.
|
||||
`
|
||||
|
||||
// TODO: this should be read directly from a file on load.
|
||||
@@ -30,7 +35,7 @@ const TOOLS = `
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "createLocation",
|
||||
"description": "Creates a location",
|
||||
"description": "Creates a location. Do not use if you think an existing location is suitable!",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -63,12 +68,13 @@ const TOOLS = `
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "attachImageLocation",
|
||||
"description": "Add a location to an image",
|
||||
"description": "Add a location to an image. You must use UUID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"locationId": {
|
||||
"type": "string"
|
||||
"type": "string",
|
||||
"description": "UUID of an existing location, you can use listLocations to get values, or use the return value of createLocation"
|
||||
}
|
||||
},
|
||||
"required": ["locationId"]
|
||||
@@ -117,8 +123,12 @@ type EventLocationAgent struct {
|
||||
|
||||
eventModel models.EventModel
|
||||
locationModel models.LocationModel
|
||||
|
||||
toolHandler ToolsHandlers
|
||||
}
|
||||
|
||||
// ListLocationArguments is the argument set of the listLocations tool; it has
// no fields.
type ListLocationArguments struct{}
|
||||
|
||||
type CreateLocationArguments struct {
|
||||
Name string `json:"name"`
|
||||
Address *string `json:"address,omitempty"`
|
||||
@@ -129,97 +139,6 @@ type AttachImageLocationArguments struct {
|
||||
LocationId string `json:"locationId"`
|
||||
}
|
||||
|
||||
func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, imageId uuid.UUID, request *AgentRequestBody) error {
|
||||
latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]
|
||||
|
||||
toolCall, ok := latestMessage.(AgentAssistantToolCall)
|
||||
|
||||
if !ok {
|
||||
return errors.New("Latest message is not a tool_call")
|
||||
}
|
||||
|
||||
for _, call := range toolCall.ToolCalls {
|
||||
if call.Function.Name == "listLocations" {
|
||||
log.Println("Function call: listLocations")
|
||||
|
||||
locations, err := agent.locationModel.List(context.Background(), userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonLocations, err := json.Marshal(locations)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.AddText(AgentTextMessage{
|
||||
Role: "tool",
|
||||
Name: "listLocations",
|
||||
Content: string(jsonLocations),
|
||||
ToolCallId: call.Id,
|
||||
})
|
||||
} else if call.Function.Name == "createLocation" {
|
||||
log.Println("Function call: createLocation")
|
||||
|
||||
locationArguments := CreateLocationArguments{}
|
||||
err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{
|
||||
Name: locationArguments.Name,
|
||||
Address: locationArguments.Address,
|
||||
Coordinates: locationArguments.Coordinates,
|
||||
}})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
createdLocation := locations[0]
|
||||
jsonCreatedLocation, err := json.Marshal(createdLocation)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.AddText(AgentTextMessage{
|
||||
Role: "tool",
|
||||
Name: "createLocation",
|
||||
Content: string(jsonCreatedLocation),
|
||||
ToolCallId: call.Id,
|
||||
})
|
||||
} else if call.Function.Name == "attachImageLocation" {
|
||||
log.Println("Function call: attachImageLocation")
|
||||
|
||||
attachLocationArguments := AttachImageLocationArguments{}
|
||||
err := json.Unmarshal([]byte(call.Function.Arguments), &attachLocationArguments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = agent.locationModel.SaveToImage(context.Background(), imageId, uuid.MustParse(attachLocationArguments.LocationId))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.AddText(AgentTextMessage{
|
||||
Role: "tool",
|
||||
Name: "attachImageLocation",
|
||||
Content: "OK",
|
||||
ToolCallId: call.Id,
|
||||
})
|
||||
} else if call.Function.Name == "finish" {
|
||||
log.Println("Finished!")
|
||||
return errors.New("hmmm, this isnt actually an error but hey")
|
||||
} else {
|
||||
return errors.New("Unknown tool_call")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
var tools any
|
||||
err := json.Unmarshal([]byte(TOOLS), &tools)
|
||||
@@ -241,8 +160,6 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println(request)
|
||||
|
||||
request.AddImage(imageName, imageData)
|
||||
|
||||
_, err = agent.client.Request(&request)
|
||||
@@ -250,7 +167,13 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
return err
|
||||
}
|
||||
|
||||
err = agent.HandleToolCall(userId, imageId, &request)
|
||||
toolHandlerInfo := ToolHandlerInfo{
|
||||
imageId: imageId,
|
||||
userId: userId,
|
||||
}
|
||||
|
||||
_, err = agent.toolHandler.Handle(toolHandlerInfo, &request)
|
||||
|
||||
for err == nil {
|
||||
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
||||
|
||||
@@ -261,21 +184,135 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
|
||||
log.Println(response)
|
||||
|
||||
err = agent.HandleToolCall(userId, imageId, &request)
|
||||
a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request)
|
||||
|
||||
err = innerErr
|
||||
|
||||
log.Println(a)
|
||||
log.Println("--------------------------")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
|
||||
agentMessage := request.Messages[len(request.Messages)-1]
|
||||
|
||||
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
||||
if !ok {
|
||||
return "", errors.New("Latest message was not a tool call.")
|
||||
}
|
||||
|
||||
fnName := toolCall.ToolCalls[0].Function.Name
|
||||
arguments := toolCall.ToolCalls[0].Function.Arguments
|
||||
|
||||
if fnName == "finish" {
|
||||
return "", errors.New("This is the end! Maybe we just return a boolean.")
|
||||
}
|
||||
|
||||
fn, exists := handler.Handlers[fnName]
|
||||
if !exists {
|
||||
return "", errors.New("Could not find tool with this name.")
|
||||
}
|
||||
|
||||
// holy jesus what the fuck.
|
||||
parseMethod := reflect.ValueOf(fn).Field(1)
|
||||
if !parseMethod.IsValid() {
|
||||
return "", errors.New("Parse method not found")
|
||||
}
|
||||
|
||||
parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)})
|
||||
if !parsedArgs[1].IsNil() {
|
||||
return "", parsedArgs[1].Interface().(error)
|
||||
}
|
||||
|
||||
log.Printf("Calling: %s\n", fnName)
|
||||
|
||||
fnMethod := reflect.ValueOf(fn).Field(2)
|
||||
if !fnMethod.IsValid() {
|
||||
return "", errors.New("Fn method not found")
|
||||
}
|
||||
|
||||
response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])})
|
||||
if !response[1].IsNil() {
|
||||
return "", response[1].Interface().(error)
|
||||
}
|
||||
|
||||
stringResponse, err := json.Marshal(response[0].Interface())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
request.AddText(AgentTextMessage{
|
||||
Role: "tool",
|
||||
Name: "createLocation",
|
||||
Content: string(stringResponse),
|
||||
ToolCallId: toolCall.ToolCalls[0].Id,
|
||||
})
|
||||
|
||||
return string(stringResponse), nil
|
||||
}
|
||||
|
||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
||||
client, err := CreateAgentClient(eventLocationPrompt)
|
||||
if err != nil {
|
||||
return EventLocationAgent{}, err
|
||||
}
|
||||
|
||||
return EventLocationAgent{
|
||||
agent := EventLocationAgent{
|
||||
client: client,
|
||||
locationModel: locationModel,
|
||||
eventModel: eventModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
toolHandler := ToolsHandlers{
|
||||
Handlers: make(map[string]ToolHandlerInterface),
|
||||
}
|
||||
|
||||
toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
|
||||
FunctionName: "listLocations",
|
||||
Parse: func(stringArgs string) (ListLocationArguments, error) {
|
||||
args := ListLocationArguments{}
|
||||
err := json.Unmarshal([]byte(stringArgs), &args)
|
||||
|
||||
return args, err
|
||||
},
|
||||
Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) {
|
||||
return agent.locationModel.List(context.Background(), info.userId)
|
||||
},
|
||||
}
|
||||
|
||||
toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
|
||||
FunctionName: "createLocation",
|
||||
Parse: func(stringArgs string) (CreateLocationArguments, error) {
|
||||
args := CreateLocationArguments{}
|
||||
err := json.Unmarshal([]byte(stringArgs), &args)
|
||||
|
||||
return args, err
|
||||
},
|
||||
Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) {
|
||||
return agent.locationModel.Save(context.Background(), info.userId, model.Locations{
|
||||
Name: args.Name,
|
||||
Address: args.Address,
|
||||
Coordinates: args.Coordinates,
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
toolHandler.Handlers["attachImageLocation"] = ToolHandler[AttachImageLocationArguments, model.ImageLocations]{
|
||||
FunctionName: "attachImageLocation",
|
||||
Parse: func(stringArgs string) (AttachImageLocationArguments, error) {
|
||||
args := AttachImageLocationArguments{}
|
||||
err := json.Unmarshal([]byte(stringArgs), &args)
|
||||
|
||||
return args, err
|
||||
},
|
||||
Fn: func(info ToolHandlerInfo, args AttachImageLocationArguments, call ToolCall) (model.ImageLocations, error) {
|
||||
return agent.locationModel.SaveToImage(context.Background(), info.imageId, uuid.MustParse(args.LocationId))
|
||||
},
|
||||
}
|
||||
|
||||
agent.toolHandler = toolHandler
|
||||
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
26
backend/agents/tools_handler.go
Normal file
26
backend/agents/tools_handler.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package agents
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
// ToolHandlerInfo carries the per-request identifiers that are handed to every
// tool handler function alongside the parsed arguments.
type ToolHandlerInfo struct {
|
||||
// userId is the user the tool call is executed for.
userId uuid.UUID
|
||||
// imageId is the image the tool call relates to.
imageId uuid.UUID
|
||||
}
|
||||
|
||||
// ToolHandler describes one callable tool: its name, how to decode its
// argument string into TArgs, and the function that produces a TResp.
//
// ToolsHandlers.Handle reaches Parse and Fn through reflection, so renaming or
// reordering these fields can break dispatch.
type ToolHandler[TArgs any, TResp any] struct {
|
||||
// FunctionName is the tool name returned by GetFunctionName.
FunctionName string
|
||||
// Parse converts the raw argument string of a tool call into TArgs.
Parse func(args string) (TArgs, error)
|
||||
// Fn runs the tool with the request info, the parsed arguments and the
// originating tool call.
Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error)
|
||||
}
|
||||
|
||||
// ToolHandlerInterface is the type-parameter-free view of a ToolHandler, which
// lets handlers with different argument and response types share one map.
type ToolHandlerInterface interface {
|
||||
// GetFunctionName returns the name of the tool the handler serves.
GetFunctionName() string
|
||||
}
|
||||
|
||||
func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string {
|
||||
return handler.FunctionName
|
||||
}
|
||||
|
||||
// ToolsHandlers is a registry of tool handlers used to dispatch tool calls.
type ToolsHandlers struct {
|
||||
// Handlers maps a tool's function name to its handler.
Handlers map[string]ToolHandlerInterface
|
||||
}
|
||||
Reference in New Issue
Block a user