diff --git a/backend/agents/agent.go b/backend/agents/agent.go index cbb5b2f..d1f869f 100644 --- a/backend/agents/agent.go +++ b/backend/agents/agent.go @@ -48,8 +48,10 @@ type AgentMessage interface { } type AgentTextMessage struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ToolCallId string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` } func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { @@ -57,6 +59,16 @@ func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { return json.Marshal(textContent) } +type AgentAssistantToolCall struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` +} + +func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) { + return json.Marshal(toolCall) +} + type AgentArrayMessage struct { Role string `json:"role"` Content []AgentContent `json:"content"` @@ -66,6 +78,14 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { return json.Marshal(arrayContent) } +func (content *AgentMessages) AddText(message AgentTextMessage) { + content.Messages = append(content.Messages, message) +} + +func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) { + content.Messages = append(content.Messages, toolCall) +} + func (content *AgentMessages) AddImage(imageName string, image []byte) error { extension := filepath.Ext(imageName) if len(extension) == 0 { @@ -320,9 +340,21 @@ func getCompletionsForImage(model string, temperature float64, prompt string, im return request, nil } +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type ToolCall struct { + Index int `json:"index"` + Id string `json:"id"` + Function FunctionCall `json:"function"` +} + type ResponseChoiceMessage struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` } type ResponseChoice struct { diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 2ec4d23..9b3ace1 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -1,9 +1,13 @@ package agents import ( + "context" "encoding/json" "io" "log" + "screenmark/screenmark/models" + + "github.com/google/uuid" ) const eventLocationPrompt = ` @@ -13,6 +17,8 @@ Your job is to check if an image has an event or a location and use the correct If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location. It is possible that there is no location or event on an image. + +You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible. ` // TODO: this should be read directly from a file on load. @@ -77,9 +83,14 @@ const TOOLS = ` ] ` -type EventLocationAgent = AgentClient +type EventLocationAgent struct { + client AgentClient -func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) error { + eventModel models.EventModel + locationModel models.LocationModel +} + +func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error { var tools any err := json.Unmarshal([]byte(TOOLS), &tools) @@ -109,12 +120,12 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) return err } - httpRequest, err := agent.getRequest(jsonAiRequest) + httpRequest, err := agent.client.getRequest(jsonAiRequest) if err != nil { return err } - resp, err := agent.Do(httpRequest) + resp, err := agent.client.Do(httpRequest) if err != nil { return err } @@ -126,14 +137,70 @@ func (agent EventLocationAgent) GetLocations(imageName string, imageData []byte) log.Println(string(response)) + agentResponse := AgentResponse{} + err = json.Unmarshal(response, &agentResponse) + + toolCalls := agentResponse.Choices[0].Message.ToolCalls[0] + + if toolCalls.Function.Name == "listLocations" { + locations, err := agent.locationModel.List(context.Background(), userId) + if err != nil { + return err + } + + jsonLocations, err := json.Marshal(locations) + if err != nil { + return err + } + + request.AddToolCall(AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{toolCalls}, + }) + + request.AddText(AgentTextMessage{ + Role: "tool", + Name: "listLocations", + Content: string(jsonLocations), + ToolCallId: toolCalls.Id, + }) + + jsonAiRequest, err := json.Marshal(request) + if err != nil { + return err + } + + httpRequest, err := agent.client.getRequest(jsonAiRequest) + if err != nil { + return err + } + + resp, err := agent.client.Do(httpRequest) + if err != nil { + return err + } + + response, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + log.Println(string(response)) + } + return nil } -func NewLocationEventAgent() (EventLocationAgent, error) { - agent, err := CreateAgentClient(eventLocationPrompt) +func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) { + client, err := CreateAgentClient(eventLocationPrompt) if err != nil { return EventLocationAgent{}, err } - return agent, nil + return EventLocationAgent{ + client: client, + locationModel: locationModel, + eventModel: eventModel, + }, nil } diff --git a/backend/main.go b/backend/main.go index e5e5b06..1b50b0a 100644 --- a/backend/main.go +++ b/backend/main.go @@ -108,7 +108,7 @@ func main() { panic(err) } - locationAgent, err := agents.NewLocationEventAgent() + locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel) if err != nil { panic(err) } @@ -121,7 +121,9 @@ func main() { } log.Println("Calling locationAgent!") - locationAgent.GetLocations(image.Image.ImageName, image.Image.Image) + locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image) + + return imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image) if err != nil { diff --git a/backend/models/locations.go b/backend/models/locations.go index 60d786a..96a8f31 100644 --- a/backend/models/locations.go +++ b/backend/models/locations.go @@ -3,6 +3,7 @@ package models import ( "context" "database/sql" + . "github.com/go-jet/jet/v2/postgres" "log" "screenmark/screenmark/.gen/haystack/haystack/model" . "screenmark/screenmark/.gen/haystack/haystack/table" @@ -39,6 +40,21 @@ func getValues(location model.Locations) []any { return arr } +func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) { + listLocationsStmt := SELECT(Locations.AllColumns, ImageLocations.AllColumns, UserImages.AllColumns). + FROM( + Locations. + INNER_JOIN(ImageLocations, ImageLocations.LocationID.EQ(Locations.ID)). + INNER_JOIN(UserImages, UserImages.ImageID.EQ(ImageLocations.ImageID)), + ).WHERE(UserImages.UserID.EQ(UUID(userId))) + + locations := []model.Locations{} + + err := listLocationsStmt.QueryContext(ctx, m.dbPool, &locations) + + return locations, err +} + func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) { insertLocationStmt := Locations. INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)