feat(tool-calling) Big refactor on how tool calling is handled

these commits are too big
This commit is contained in:
2025-03-22 20:46:26 +00:00
parent 13e5ed9f9e
commit 410df01b4d
8 changed files with 294 additions and 147 deletions

View File

@ -0,0 +1,18 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"github.com/google/uuid"
)
type UserLocations struct {
ID uuid.UUID `sql:"primary_key"`
LocationID uuid.UUID
UserID uuid.UUID
}

View File

@ -20,6 +20,7 @@ func UseSchema(schema string) {
Locations = Locations.FromSchema(schema)
UserImages = UserImages.FromSchema(schema)
UserImagesToProcess = UserImagesToProcess.FromSchema(schema)
UserLocations = UserLocations.FromSchema(schema)
UserTags = UserTags.FromSchema(schema)
Users = Users.FromSchema(schema)
}

View File

@ -0,0 +1,81 @@
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var UserLocations = newUserLocationsTable("haystack", "user_locations", "")
type userLocationsTable struct {
postgres.Table
// Columns
ID postgres.ColumnString
LocationID postgres.ColumnString
UserID postgres.ColumnString
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
}
type UserLocationsTable struct {
userLocationsTable
EXCLUDED userLocationsTable
}
// AS creates new UserLocationsTable with assigned alias
func (a UserLocationsTable) AS(alias string) *UserLocationsTable {
return newUserLocationsTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new UserLocationsTable with assigned schema name
func (a UserLocationsTable) FromSchema(schemaName string) *UserLocationsTable {
return newUserLocationsTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new UserLocationsTable with assigned table prefix
func (a UserLocationsTable) WithPrefix(prefix string) *UserLocationsTable {
return newUserLocationsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new UserLocationsTable with assigned table suffix
func (a UserLocationsTable) WithSuffix(suffix string) *UserLocationsTable {
return newUserLocationsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newUserLocationsTable(schemaName, tableName, alias string) *UserLocationsTable {
return &UserLocationsTable{
userLocationsTable: newUserLocationsTableImpl(schemaName, tableName, alias),
EXCLUDED: newUserLocationsTableImpl("", "excluded", ""),
}
}
func newUserLocationsTableImpl(schemaName, tableName, alias string) userLocationsTable {
var (
IDColumn = postgres.StringColumn("id")
LocationIDColumn = postgres.StringColumn("location_id")
UserIDColumn = postgres.StringColumn("user_id")
allColumns = postgres.ColumnList{IDColumn, LocationIDColumn, UserIDColumn}
mutableColumns = postgres.ColumnList{LocationIDColumn, UserIDColumn}
)
return userLocationsTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
ID: IDColumn,
LocationID: LocationIDColumn,
UserID: UserIDColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"log"
"reflect"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/models"
@ -21,6 +22,10 @@ Only create an event if you see an event on the image, not all locations have an
It is possible that there is no location or event on an image.
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
Do not create an event if you don't see any dates, or a name indicating an event.
Always reuse existing locations from listLocations . Do not create duplicates.
`
// TODO: this should be read directly from a file on load.
@ -30,7 +35,7 @@ const TOOLS = `
"type": "function",
"function": {
"name": "createLocation",
"description": "Creates a location",
"description": "Creates a location. No not use if you think an existing location is suitable!",
"parameters": {
"type": "object",
"properties": {
@ -63,12 +68,13 @@ const TOOLS = `
"type": "function",
"function": {
"name": "attachImageLocation",
"description": "Add a location to an image",
"description": "Add a location to an image. You must use UUID.",
"parameters": {
"type": "object",
"properties": {
"locationId": {
"type": "string"
"type": "string",
"description": "UUID of an existing location, you can use listLocations to get values, or use the return value of createLocation"
}
},
"required": ["locationId"]
@ -117,8 +123,12 @@ type EventLocationAgent struct {
eventModel models.EventModel
locationModel models.LocationModel
toolHandler ToolsHandlers
}
type ListLocationArguments struct{}
type CreateLocationArguments struct {
Name string `json:"name"`
Address *string `json:"address,omitempty"`
@ -129,97 +139,6 @@ type AttachImageLocationArguments struct {
LocationId string `json:"locationId"`
}
func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, imageId uuid.UUID, request *AgentRequestBody) error {
latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]
toolCall, ok := latestMessage.(AgentAssistantToolCall)
if !ok {
return errors.New("Latest message is not a tool_call")
}
for _, call := range toolCall.ToolCalls {
if call.Function.Name == "listLocations" {
log.Println("Function call: listLocations")
locations, err := agent.locationModel.List(context.Background(), userId)
if err != nil {
return err
}
jsonLocations, err := json.Marshal(locations)
if err != nil {
return err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "listLocations",
Content: string(jsonLocations),
ToolCallId: call.Id,
})
} else if call.Function.Name == "createLocation" {
log.Println("Function call: createLocation")
locationArguments := CreateLocationArguments{}
err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments)
if err != nil {
return err
}
locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{
Name: locationArguments.Name,
Address: locationArguments.Address,
Coordinates: locationArguments.Coordinates,
}})
if err != nil {
return err
}
createdLocation := locations[0]
jsonCreatedLocation, err := json.Marshal(createdLocation)
if err != nil {
return err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "createLocation",
Content: string(jsonCreatedLocation),
ToolCallId: call.Id,
})
} else if call.Function.Name == "attachImageLocation" {
log.Println("Function call: attachImageLocation")
attachLocationArguments := AttachImageLocationArguments{}
err := json.Unmarshal([]byte(call.Function.Arguments), &attachLocationArguments)
if err != nil {
return err
}
_, err = agent.locationModel.SaveToImage(context.Background(), imageId, uuid.MustParse(attachLocationArguments.LocationId))
if err != nil {
return err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "attachImageLocation",
Content: "OK",
ToolCallId: call.Id,
})
} else if call.Function.Name == "finish" {
log.Println("Finished!")
return errors.New("hmmm, this isnt actually an error but hey")
} else {
return errors.New("Unknown tool_call")
}
}
return nil
}
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
var tools any
err := json.Unmarshal([]byte(TOOLS), &tools)
@ -241,8 +160,6 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
return err
}
log.Println(request)
request.AddImage(imageName, imageData)
_, err = agent.client.Request(&request)
@ -250,7 +167,13 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
return err
}
err = agent.HandleToolCall(userId, imageId, &request)
toolHandlerInfo := ToolHandlerInfo{
imageId: imageId,
userId: userId,
}
_, err = agent.toolHandler.Handle(toolHandlerInfo, &request)
for err == nil {
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
@ -261,21 +184,135 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
log.Println(response)
err = agent.HandleToolCall(userId, imageId, &request)
a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request)
err = innerErr
log.Println(a)
log.Println("--------------------------")
}
return err
}
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
agentMessage := request.Messages[len(request.Messages)-1]
toolCall, ok := agentMessage.(AgentAssistantToolCall)
if !ok {
return "", errors.New("Latest message was not a tool call.")
}
fnName := toolCall.ToolCalls[0].Function.Name
arguments := toolCall.ToolCalls[0].Function.Arguments
if fnName == "finish" {
return "", errors.New("This is the end! Maybe we just return a boolean.")
}
fn, exists := handler.Handlers[fnName]
if !exists {
return "", errors.New("Could not find tool with this name.")
}
// holy jesus what the fuck.
parseMethod := reflect.ValueOf(fn).Field(1)
if !parseMethod.IsValid() {
return "", errors.New("Parse method not found")
}
parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)})
if !parsedArgs[1].IsNil() {
return "", parsedArgs[1].Interface().(error)
}
log.Printf("Calling: %s\n", fnName)
fnMethod := reflect.ValueOf(fn).Field(2)
if !fnMethod.IsValid() {
return "", errors.New("Fn method not found")
}
response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])})
if !response[1].IsNil() {
return "", response[1].Interface().(error)
}
stringResponse, err := json.Marshal(response[0].Interface())
if err != nil {
return "", err
}
request.AddText(AgentTextMessage{
Role: "tool",
Name: "createLocation",
Content: string(stringResponse),
ToolCallId: toolCall.ToolCalls[0].Id,
})
return string(stringResponse), nil
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
client, err := CreateAgentClient(eventLocationPrompt)
if err != nil {
return EventLocationAgent{}, err
}
return EventLocationAgent{
agent := EventLocationAgent{
client: client,
locationModel: locationModel,
eventModel: eventModel,
}, nil
}
toolHandler := ToolsHandlers{
Handlers: make(map[string]ToolHandlerInterface),
}
toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
FunctionName: "listLocations",
Parse: func(stringArgs string) (ListLocationArguments, error) {
args := ListLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
},
Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) {
return agent.locationModel.List(context.Background(), info.userId)
},
}
toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
FunctionName: "createLocation",
Parse: func(stringArgs string) (CreateLocationArguments, error) {
args := CreateLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
},
Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) {
return agent.locationModel.Save(context.Background(), info.userId, model.Locations{
Name: args.Name,
Address: args.Address,
Coordinates: args.Coordinates,
})
},
}
toolHandler.Handlers["attachImageLocation"] = ToolHandler[AttachImageLocationArguments, model.ImageLocations]{
FunctionName: "attachImageLocation",
Parse: func(stringArgs string) (AttachImageLocationArguments, error) {
args := AttachImageLocationArguments{}
err := json.Unmarshal([]byte(stringArgs), &args)
return args, err
},
Fn: func(info ToolHandlerInfo, args AttachImageLocationArguments, call ToolCall) (model.ImageLocations, error) {
return agent.locationModel.SaveToImage(context.Background(), info.imageId, uuid.MustParse(args.LocationId))
},
}
agent.toolHandler = toolHandler
return agent, nil
}

View File

@ -0,0 +1,26 @@
package agents
import "github.com/google/uuid"
type ToolHandlerInfo struct {
userId uuid.UUID
imageId uuid.UUID
}
type ToolHandler[TArgs any, TResp any] struct {
FunctionName string
Parse func(args string) (TArgs, error)
Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error)
}
type ToolHandlerInterface interface {
GetFunctionName() string
}
func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string {
return handler.FunctionName
}
type ToolsHandlers struct {
Handlers map[string]ToolHandlerInterface
}

View File

@ -120,6 +120,13 @@ func main() {
return
}
userImage, err := imageModel.FinishProcessing(ctx, image.ID)
if err != nil {
log.Println("Failed to FinishProcessing")
log.Println(err)
return
}
log.Println("Calling locationAgent!")
err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
@ -137,13 +144,6 @@ func main() {
log.Println("-----")
log.Println(imageInfo)
userImage, err := imageModel.FinishProcessing(ctx, image.ID)
if err != nil {
log.Println("Failed to FinishProcessing")
log.Println(err)
return
}
log.Println(userImage)
err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags)

View File

@ -3,11 +3,12 @@ package models
import (
"context"
"database/sql"
. "github.com/go-jet/jet/v2/postgres"
"log"
"screenmark/screenmark/.gen/haystack/haystack/model"
. "screenmark/screenmark/.gen/haystack/haystack/table"
. "github.com/go-jet/jet/v2/postgres"
"github.com/google/uuid"
)
@ -15,60 +16,37 @@ type LocationModel struct {
dbPool *sql.DB
}
// This looks stupid
func getValues(location model.Locations) []any {
arr := make([]any, 0)
if location.Address != nil {
arr = append(arr, *location.Address)
} else {
arr = append(arr, nil)
}
if location.Coordinates != nil {
arr = append(arr, *location.Coordinates)
} else {
arr = append(arr, nil)
}
if location.Description != nil {
arr = append(arr, *location.Description)
} else {
arr = append(arr, nil)
}
return arr
}
func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) {
listLocationsStmt := SELECT(Locations.AllColumns, ImageLocations.AllColumns, UserImages.AllColumns).
listLocationsStmt := SELECT(Locations.AllColumns).
FROM(
Locations.
INNER_JOIN(ImageLocations, ImageLocations.LocationID.EQ(Locations.ID)).
INNER_JOIN(UserImages, UserImages.ImageID.EQ(ImageLocations.ImageID)),
).WHERE(UserImages.UserID.EQ(UUID(userId)))
INNER_JOIN(UserLocations, UserLocations.LocationID.EQ(Locations.ID)),
).
WHERE(UserLocations.UserID.EQ(UUID(userId)))
locations := []model.Locations{}
err := listLocationsStmt.QueryContext(ctx, m.dbPool, &locations)
return locations, err
}
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) ([]model.Locations, error) {
func (m LocationModel) Save(ctx context.Context, userId uuid.UUID, location model.Locations) (model.Locations, error) {
insertLocationStmt := Locations.
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description).
VALUES(location.Name, location.Address, location.Coordinates, location.Description).
RETURNING(Locations.AllColumns)
for _, location := range locations {
insertLocationStmt = insertLocationStmt.VALUES(location.Name, getValues(location)...)
insertedLocation := model.Locations{}
err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation)
if err != nil {
return model.Locations{}, err
}
insertLocationStmt = insertLocationStmt.RETURNING(Locations.AllColumns)
insertUserLocationStmt := UserLocations.
INSERT(UserLocations.UserID, UserLocations.LocationID).
VALUES(userId, insertedLocation.ID)
log.Println(insertLocationStmt.DebugSql())
insertedLocation := []model.Locations{}
err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation)
_, err = insertUserLocationStmt.ExecContext(ctx, m.dbPool)
return insertedLocation, err
}

View File

@ -64,6 +64,12 @@ CREATE TABLE haystack.image_locations (
image_id UUID NOT NULL REFERENCES haystack.image (id)
);
CREATE TABLE haystack.user_locations (
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
location_id UUID NOT NULL REFERENCES haystack.locations (id),
user_id UUID NOT NULL REFERENCES haystack.users (id)
);
CREATE TABLE haystack.events (
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),