From b57968b9383d22c1df5d9a85c0985afd09ac442f Mon Sep 17 00:00:00 2001
From: John Costa
Date: Sun, 13 Apr 2025 15:02:32 +0100
Subject: [PATCH] feat(location): agent to create locations

---
Review notes (text between "---" and the first diff is not part of the commit):
- Line breaks were restored from a whitespace-collapsed copy of this patch.
  Indentation, especially of context lines, is a best-effort reconstruction;
  apply with `git apply --ignore-whitespace` if a hunk is rejected.
- linkLocation now returns the SaveToImage error instead of discarding it, and
  doc comments were added to location_agent.go. Its hunk header and the
  diffstat were updated to match; the blob id on its index line still refers
  to the original commit.

 backend/agents/location_agent.go | 194 +++++++++++++++++++++++++++++++
 backend/agents/orchestrator.go   |  53 +++++----
 backend/events.go                |  13 ++-
 backend/models/locations.go      |  24 ++++-
 4 files changed, 246 insertions(+), 38 deletions(-)
 create mode 100644 backend/agents/location_agent.go

diff --git a/backend/agents/location_agent.go b/backend/agents/location_agent.go
new file mode 100644
index 0000000..25da2e0
--- /dev/null
+++ b/backend/agents/location_agent.go
@@ -0,0 +1,194 @@
+package agents
+
+import (
+	"context"
+	"encoding/json"
+	"os"
+	"screenmark/screenmark/.gen/haystack/haystack/model"
+	"screenmark/screenmark/agents/client"
+	"screenmark/screenmark/models"
+	"time"
+
+	"github.com/charmbracelet/log"
+	"github.com/google/uuid"
+)
+
+// locationPrompt is the prompt handed to RunAgent for the location agent. It
+// describes the tools declared in locationTools.
+const locationPrompt = `
+You are an agent.
+
+The user will send you images and you have to identify if they have any location or a place. This could a picture of a real place, an address, or it's name.
+
+There are various tools you can use to perform this task.
+
+listLocations
+Lists the users already existing locations, you should do this before using createLocation to avoid creating duplicates.
+
+createLocation
+Use this to create a new location. Avoid making duplicates and only create a new location if listLocations doesnt contain the location on the image.
+
+linkLocation
+Links an image to a location.
+
+finish
+Call when there is nothing else to do.
+`
+
+// locationTools is the JSON tool schema handed to RunAgent together with
+// locationPrompt.
+const locationTools = `
+[
+	{
+		"type": "function",
+		"function": {
+			"name": "listLocations",
+			"description": "List the locations the user already has.",
+			"parameters": {
+				"type": "object",
+				"properties": {},
+				"required": []
+			}
+		}
+	},
+	{
+		"type": "function",
+		"function": {
+			"name": "createLocation",
+			"description": "Use to create a new location",
+			"parameters": {
+				"type": "object",
+				"properties": {
+					"name": {
+						"type": "string"
+					},
+					"address": {
+						"type": "string"
+					}
+				},
+				"required": ["name"]
+			}
+		}
+	},
+	{
+		"type": "function",
+		"function": {
+			"name": "linkLocation",
+			"description": "Use to link an already existing location to the image you were sent",
+			"parameters": {
+				"type": "object",
+				"properties": {
+					"locationId": {
+						"type": "string"
+					}
+				},
+				"required": ["locationId"]
+			}
+		}
+	},
+	{
+		"type": "function",
+		"function": {
+			"name": "finish",
+			"description": "Call this when there is nothing left to do.",
+			"parameters": {
+				"type": "object",
+				"properties": {},
+				"required": []
+			}
+		}
+	}
+]`
+
+// LocationAgent pairs an agent client with the location model that its tool
+// handlers read from and write to.
+type LocationAgent struct {
+	client client.AgentClient
+
+	locationModel models.LocationModel
+}
+
+// JSON argument payloads for the location tools.
+type listLocationArguments struct{}
+type createLocationArguments struct {
+	Name    string  `json:"name"`
+	Address *string `json:"address"`
+}
+type linkLocationArguments struct {
+	LocationID string `json:"locationId"`
+}
+
+// NewLocationAgent creates an agent client and registers the listLocations,
+// createLocation and linkLocation tool handlers on it. If the client cannot
+// be created, the error is returned with a zero-value LocationAgent.
+func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error) {
+	agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
+		ReportTimestamp: true,
+		TimeFormat:      time.Kitchen,
+		Prefix:          "Locations 📍",
+	}))
+
+	if err != nil {
+		return LocationAgent{}, err
+	}
+
+	agent := LocationAgent{
+		client:        agentClient,
+		locationModel: locationModel,
+	}
+
+	agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
+		return agent.locationModel.List(context.Background(), info.UserId)
+	})
+
+	// createLocation saves the new location and links it to the current image.
+	agentClient.ToolHandler.AddTool("createLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
+		args := createLocationArguments{}
+		err := json.Unmarshal([]byte(_args), &args)
+		if err != nil {
+			return model.Locations{}, err
+		}
+
+		ctx := context.Background()
+
+		location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{
+			Name:    args.Name,
+			Address: args.Address,
+		})
+
+		if err != nil {
+			return model.Locations{}, err
+		}
+
+		_, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID)
+		if err != nil {
+			return model.Locations{}, err
+		}
+
+		return location, nil
+	})
+
+	agentClient.ToolHandler.AddTool("linkLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
+		args := linkLocationArguments{}
+		err := json.Unmarshal([]byte(_args), &args)
+		if err != nil {
+			return "", err
+		}
+
+		ctx := context.Background()
+
+		locationID, err := uuid.Parse(args.LocationID)
+		if err != nil {
+			return "", err
+		}
+
+		// Surface a failed link instead of reporting "Saved" unconditionally.
+		if _, err := agent.locationModel.SaveToImage(ctx, info.ImageId, locationID); err != nil {
+			return "", err
+		}
+
+		return "Saved", nil
+	})
+
+	return agent, nil
+}
diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go
index 9868bef..19a68d7 100644
--- a/backend/agents/orchestrator.go
+++ b/backend/agents/orchestrator.go
@@ -20,17 +20,15 @@ The agents are available as tool calls.
 
 Agents available:
 
-eventLocationAgent
-Use it when you think the image contains an event or a location of any sort. This can be an event page, a map, an address or a date.
-This could also be a conversation describing an event.
-
 noteAgent
 Use when there is ANY text on the image.
 
 contactAgent
-
 Use it when the image contains information relating a person.
 
+locationAgent
+Use it when the image contains some address or a place.
+
 noAction
 When you think there is no more information to extract from the image.
 
@@ -41,18 +39,6 @@ Do not call the agent if you do not think it is relevant for the image.
 
 const OrchestratorTools = `
 [
-	{
-		"type": "function",
-		"function": {
-			"name": "eventLocationAgent",
-			"description": "Use when there is an event or location on the image. This could be in writing form",
-			"parameters": {
-				"type": "object",
-				"properties": {},
-				"required": []
-			}
-		}
-	},
 	{
 		"type": "function",
 		"function": {
@@ -76,6 +62,18 @@ const OrchestratorTools = `
 				"required": []
 			}
 		}
+	},
+	{
+		"type": "function",
+		"function": {
+			"name": "locationAgent",
+			"description": "Use when then image contains some place, location or address",
+			"parameters": {
+				"type": "object",
+				"properties": {},
+				"required": []
+			}
+		}
 	},
 	{
 		"type": "function",
@@ -101,7 +99,7 @@ type Status struct {
 	Ok bool `json:"ok"`
 }
 
-func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, contactAgent ContactAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
+func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locationAgent LocationAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
 	agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
 		ReportTimestamp: true,
 		TimeFormat:      time.Kitchen,
@@ -112,17 +110,6 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
 		return OrchestratorAgent{}, err
 	}
 
-	agent.ToolHandler.AddTool("eventLocationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
-		// We need a way to keep track of this async?
-		// Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish.
-
-		go eventLocationAgent.client.RunAgent(eventLocationPrompt, eventLocationTools, "finish", info.UserId, info.ImageId, imageName, imageData)
-
-		return Status{
-			Ok: true,
-		}, nil
-	})
-
 	agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
 		go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
 
@@ -139,6 +126,14 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
 		}, nil
 	})
 
+	agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
+		go locationAgent.client.RunAgent(locationPrompt, locationTools, "finish", info.UserId, info.ImageId, imageName, imageData)
+
+		return Status{
+			Ok: true,
+		}, nil
+	})
+
 	agent.ToolHandler.AddTool("noAction", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
 		// To nothing
 
diff --git a/backend/events.go b/backend/events.go
index 39edff8..23f496a 100644
--- a/backend/events.go
+++ b/backend/events.go
@@ -23,7 +23,6 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
 	defer listener.Close()
 
 	locationModel := models.NewLocationModel(db)
-	eventModel := models.NewEventModel(db)
 	noteModel := models.NewNoteModel(db)
 	imageModel := models.NewImageModel(db)
 	contactModel := models.NewContactModel(db)
@@ -42,11 +41,6 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
 			ctx := context.Background()
 
 			go func() {
-				locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel)
-				if err != nil {
-					panic(err)
-				}
-
 				noteAgent, err := agents.NewNoteAgent(noteModel)
 				if err != nil {
 					panic(err)
@@ -57,6 +51,11 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
 					panic(err)
 				}
 
+				locationAgent, err := agents.NewLocationAgent(locationModel)
+				if err != nil {
+					panic(err)
+				}
+
 				image, err := imageModel.GetToProcessWithData(ctx, imageId)
 				if err != nil {
 					log.Println("Failed to GetToProcessWithData")
@@ -70,7 +69,7 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
 					return
 				}
 
-				orchestrator, err := agents.NewOrchestratorAgent(locationAgent, noteAgent, contactAgent, image.Image.ImageName, image.Image.Image)
+				orchestrator, err := agents.NewOrchestratorAgent(noteAgent, contactAgent, locationAgent, image.Image.ImageName, image.Image.Image)
 				if err != nil {
 					panic(err)
 				}
diff --git a/backend/models/locations.go b/backend/models/locations.go
index 7f874c8..b3bfd6c 100644
--- a/backend/models/locations.go
+++ b/backend/models/locations.go
@@ -7,6 +7,7 @@ import (
 	. "screenmark/screenmark/.gen/haystack/haystack/table"
 
 	. "github.com/go-jet/jet/v2/postgres"
+	"github.com/go-jet/jet/v2/qrm"
 	"github.com/google/uuid"
 )
 
@@ -51,13 +52,32 @@ func (m LocationModel) Save(ctx context.Context, userId uuid.UUID, location mode
 }
 
 func (m LocationModel) SaveToImage(ctx context.Context, imageId uuid.UUID, locationId uuid.UUID) (model.ImageLocations, error) {
+	imageLocation := model.ImageLocations{}
+
+	checkExistingStmt := ImageLocations.
+		SELECT(ImageLocations.AllColumns).
+		WHERE(
+			ImageLocations.ImageID.EQ(UUID(imageId)).
+				AND(ImageLocations.LocationID.EQ(UUID(locationId))),
+		)
+
+	err := checkExistingStmt.QueryContext(ctx, m.dbPool, &imageLocation)
+	if err != nil && err != qrm.ErrNoRows {
+		// Anything other than "no rows" is a genuine query failure.
+		return model.ImageLocations{}, err
+	}
+
+	if err == nil {
+		// This image/location pair already has a row; return it as-is.
+		return imageLocation, nil
+	}
+
 	insertImageLocationStmt := ImageLocations.
 		INSERT(ImageLocations.ImageID, ImageLocations.LocationID).
 		VALUES(imageId, locationId).
 		RETURNING(ImageLocations.AllColumns)
 
-	imageLocation := model.ImageLocations{}
-	err := insertImageLocationStmt.QueryContext(ctx, m.dbPool, &imageLocation)
+	err = insertImageLocationStmt.QueryContext(ctx, m.dbPool, &imageLocation)
 
 	return imageLocation, err
 }