diff --git a/backend/agents/agent.go b/backend/agents/agent.go index d1f869f..231037b 100644 --- a/backend/agents/agent.go +++ b/backend/agents/agent.go @@ -393,6 +393,49 @@ func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) { return imageInfo, nil } +func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) { + jsonAiRequest, err := json.Marshal(request) + if err != nil { + return AgentResponse{}, err + } + + httpRequest, err := client.getRequest(jsonAiRequest) + if err != nil { + return AgentResponse{}, err + } + + resp, err := client.Do(httpRequest) + if err != nil { + return AgentResponse{}, err + } + + response, err := io.ReadAll(resp.Body) + if err != nil { + return AgentResponse{}, err + } + + agentResponse := AgentResponse{} + err = json.Unmarshal(response, &agentResponse) + + if err != nil { + return AgentResponse{}, err + } + + log.Println(string(response)) + + toolCalls := agentResponse.Choices[0].Message.ToolCalls + if len(toolCalls) > 0 { + // Should for sure be more flexible. + request.AddToolCall(AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: toolCalls, + }) + } + + return agentResponse, nil +} + func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData) if err != nil { diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 9b3ace1..8c0c5d0 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -3,8 +3,9 @@ package agents import ( "context" "encoding/json" - "io" + "errors" "log" + "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/models" "github.com/google/uuid" @@ -15,6 +16,8 @@ You are an agent that extracts events and locations from an image. Your job is to check if an image has an event or a location and use the correct tools to extract this information. If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location. +Only create an event if you see an event on the image, not all locations have an associated event. +DO NOT CREATE EVENTS It is possible that there is no location or event on an image. @@ -79,6 +82,17 @@ const TOOLS = ` "required": ["name"] } } + }, + { + "type": "function", + "function": { + "name": "finish", + "description": "Nothing else to do, call this function.", + "parameters": { + "type": "object", + "properties": {} + } + } } ] ` @@ -90,15 +104,92 @@ type EventLocationAgent struct { locationModel models.LocationModel } +type CreateLocationArguments struct { + Name string `json:"name"` + Address *string `json:"address,omitempty"` + Coordinates *string `json:"coordinates,omitempty"` +} + +func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, request *AgentRequestBody) error { + latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1] + + toolCall, ok := latestMessage.(AgentAssistantToolCall) + + if !ok { + return errors.New("Latest message is not a tool_call") + } + + for _, call := range toolCall.ToolCalls { + if call.Function.Name == "listLocations" { + log.Println("Function call: listLocations") + + locations, err := agent.locationModel.List(context.Background(), userId) + if err != nil { + return err + } + + jsonLocations, err := json.Marshal(locations) + if err != nil { + return err + } + + request.AddText(AgentTextMessage{ + Role: "tool", + Name: "listLocations", + Content: string(jsonLocations), + ToolCallId: call.Id, + }) + } else if call.Function.Name == "createLocation" { + log.Println("Function call: createLocation") + + locationArguments := CreateLocationArguments{} + err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments) + if err != nil { + return err + } + + locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{ + Name: locationArguments.Name, + Address: locationArguments.Address, + Coordinates: locationArguments.Coordinates, + }}) + + if err != nil { + return err + } + + createdLocation := locations[0] + jsonCreatedLocation, err := json.Marshal(createdLocation) + if err != nil { + return err + } + + request.AddText(AgentTextMessage{ + Role: "tool", + Name: "createLocation", + Content: string(jsonCreatedLocation), + ToolCallId: call.Id, + }) + } else if call.Function.Name == "finish" { + log.Println("Finished!") + return errors.New("hmmm, this isnt actually an error but hey") + } else { + return errors.New("Unknown tool_call") + } + } + + return nil +} + func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error { var tools any err := json.Unmarshal([]byte(TOOLS), &tools) - auto := "auto" + toolChoice := "any" request := AgentRequestBody{ Tools: &tools, - ToolChoice: &auto, + ToolChoice: &toolChoice, Model: "pixtral-12b-2409", Temperature: 0.3, ResponseFormat: ResponseFormat{ @@ -115,81 +206,26 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, request.AddImage(imageName, imageData) - jsonAiRequest, err := json.Marshal(request) + _, err = agent.client.Request(&request) if err != nil { return err } - httpRequest, err := agent.client.getRequest(jsonAiRequest) - if err != nil { - return err + err = agent.HandleToolCall(userId, &request) + for err == nil { + log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) + + response, requestError := agent.client.Request(&request) + if requestError != nil { + return requestError + } + + log.Println(response) + + err = agent.HandleToolCall(userId, &request) } - resp, err := agent.client.Do(httpRequest) - if err != nil { - return err - } - - response, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - log.Println(string(response)) - - agentResponse := AgentResponse{} - err = json.Unmarshal(response, &agentResponse) - - toolCalls := agentResponse.Choices[0].Message.ToolCalls[0] - - if toolCalls.Function.Name == "listLocations" { - locations, err := agent.locationModel.List(context.Background(), userId) - if err != nil { - return err - } - - jsonLocations, err := json.Marshal(locations) - if err != nil { - return err - } - - request.AddToolCall(AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{toolCalls}, - }) - - request.AddText(AgentTextMessage{ - Role: "tool", - Name: "listLocations", - Content: string(jsonLocations), - ToolCallId: toolCalls.Id, - }) - - jsonAiRequest, err := json.Marshal(request) - if err != nil { - return err - } - - httpRequest, err := agent.client.getRequest(jsonAiRequest) - if err != nil { - return err - } - - resp, err := agent.client.Do(httpRequest) - if err != nil { - return err - } - - response, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - log.Println(string(response)) - } - - return nil + return err } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) { diff --git a/backend/main.go b/backend/main.go index 1b50b0a..434658d 100644 --- a/backend/main.go +++ b/backend/main.go @@ -121,7 +121,9 @@ func main() { } log.Println("Calling locationAgent!") - locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image) + err = locationAgent.GetLocations(image.UserID, image.Image.ImageName, image.Image.Image) + + log.Println(err) return diff --git a/backend/models/locations.go b/backend/models/locations.go index 96a8f31..b18acda 100644 --- a/backend/models/locations.go +++ b/backend/models/locations.go @@ -55,7 +55,7 @@ func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Loca return locations, err } -func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (model.Locations, error) { +func (m LocationModel) Save(ctx context.Context, locations []model.Locations) ([]model.Locations, error) { insertLocationStmt := Locations. INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description) @@ -67,7 +67,7 @@ func (m LocationModel) Save(ctx context.Context, locations []model.Locations) (m log.Println(insertLocationStmt.DebugSql()) - insertedLocation := model.Locations{} + insertedLocation := []model.Locations{} err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation) return insertedLocation, err @@ -84,9 +84,11 @@ func (m LocationModel) SaveToImage(ctx context.Context, imageId uuid.UUID, locat return err } + // TODO: doesnt work if array is more than 1. BIG TODO + insertImageLocationStmt := ImageLocations. INSERT(ImageLocations.ImageID, ImageLocations.LocationID). - VALUES(imageId, location.ID) + VALUES(imageId, location[0].ID) _, err = insertImageLocationStmt.ExecContext(ctx, m.dbPool) diff --git a/backend/tools.json b/backend/tools.json index 13652dc..d957b21 100644 --- a/backend/tools.json +++ b/backend/tools.json @@ -54,5 +54,16 @@ "required": ["name"] } } + }, + { + "type": "function", + "function": { + "name": "finish", + "description": "Nothing else to do, call this function.", + "parameters": { + "type": "object", + "properties": {} + } + } } ]