From 13e5ed9f9e95fd7815c9fc407783056e514c1860 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 22 Mar 2025 17:47:02 +0000 Subject: [PATCH] feat(locations): allowing AI to attach it to the image --- backend/agents/event_location_agent.go | 49 +++++++++++++++++++++++--- backend/main.go | 9 +---- backend/models/locations.go | 21 +++-------- backend/tools.json | 16 +++++++++ 4 files changed, 66 insertions(+), 29 deletions(-) diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 8c0c5d0..8ea94df 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -17,7 +17,6 @@ Your job is to check if an image has an event or a location and use the correct If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location. Only create an event if you see an event on the image, not all locations have an associated event. -DO NOT CREATE EVENTS It is possible that there is no location or event on an image. @@ -60,6 +59,22 @@ const TOOLS = ` } } }, + { + "type": "function", + "function": { + "name": "attachImageLocation", + "description": "Add a location to an image", + "parameters": { + "type": "object", + "properties": { + "locationId": { + "type": "string" + } + }, + "required": ["locationId"] + } + } + }, { "type": "function", "function": { @@ -110,7 +125,11 @@ type CreateLocationArguments struct { Coordinates *string `json:"coordinates,omitempty"` } -func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, request *AgentRequestBody) error { +type AttachImageLocationArguments struct { + LocationId string `json:"locationId"` +} + +func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, imageId uuid.UUID, request *AgentRequestBody) error { latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1] toolCall, ok := latestMessage.(AgentAssistantToolCall) @@ -170,6 +189,26 @@ func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, request *AgentR Content: string(jsonCreatedLocation), ToolCallId: call.Id, }) + } else if call.Function.Name == "attachImageLocation" { + log.Println("Function call: attachImageLocation") + + attachLocationArguments := AttachImageLocationArguments{} + err := json.Unmarshal([]byte(call.Function.Arguments), &attachLocationArguments) + if err != nil { + return err + } + + _, err = agent.locationModel.SaveToImage(context.Background(), imageId, uuid.MustParse(attachLocationArguments.LocationId)) + if err != nil { + return err + } + + request.AddText(AgentTextMessage{ + Role: "tool", + Name: "attachImageLocation", + Content: "OK", + ToolCallId: call.Id, + }) } else if call.Function.Name == "finish" { log.Println("Finished!") return errors.New("hmmm, this isnt actually an error but hey") @@ -181,7 +220,7 @@ func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, request *AgentR return nil } -func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error { +func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { var tools any err := json.Unmarshal([]byte(TOOLS), &tools) @@ -211,7 +250,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, return err } - err = agent.HandleToolCall(userId, &request) + err = agent.HandleToolCall(userId, imageId, &request) for err == nil { log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) @@ -222,7 +261,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, log.Println(response) - err = agent.HandleToolCall(userId, &request) + err = agent.HandleToolCall(userId, imageId, &request) } return err diff --git a/backend/main.go b/backend/main.go index 434658d..edc572e 100644 --- a/backend/main.go +++ b/backend/main.go @@ -121,7 +121,7 @@ func main() { } log.Println("Calling locationAgent!") - err = locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image) + err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) log.Println(err) @@ -167,13 +167,6 @@ func main() { return } - err = locationModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Locations) - if err != nil { - log.Println("Failed to save location") - log.Println(err) - return - } - err = eventModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Events) if err != nil { log.Println("Failed to save events") diff --git a/backend/models/locations.go b/backend/models/locations.go index b18acda..7e05d54 100644 --- a/backend/models/locations.go +++ b/backend/models/locations.go @@ -73,26 +73,15 @@ func (m LocationModel) Save(ctx context.Context, locations []model.Locations) ([ return insertedLocation, err } -func (m LocationModel) SaveToImage(ctx context.Context, imageId uuid.UUID, locations []model.Locations) error { - if len(locations) == 0 { - return nil - } - - location, err := m.Save(ctx, locations) - - if err != nil { - return err - } - - // TODO: doesnt work if array is more than 1. BIG TODO - +func (m LocationModel) SaveToImage(ctx context.Context, imageId uuid.UUID, locationId uuid.UUID) (model.ImageLocations, error) { insertImageLocationStmt := ImageLocations. INSERT(ImageLocations.ImageID, ImageLocations.LocationID). - VALUES(imageId, location[0].ID) + VALUES(imageId, locationId) - _, err = insertImageLocationStmt.ExecContext(ctx, m.dbPool) + imageLocation := model.ImageLocations{} + _, err := insertImageLocationStmt.ExecContext(ctx, m.dbPool) - return err + return imageLocation, err } func NewLocationModel(db *sql.DB) LocationModel { diff --git a/backend/tools.json b/backend/tools.json index d957b21..cdd0ec5 100644 --- a/backend/tools.json +++ b/backend/tools.json @@ -55,6 +55,22 @@ } } }, + { + "type": "function", + "function": { + "name": "attachImageLocation", + "description": "Add a location to an image", + "parameters": { + "type": "object", + "properties": { + "locationId": { + "type": "string" + } + }, + "required": ["locationId"] + } + } + }, { "type": "function", "function": {