package agents

import (
	"context"
	"encoding/json"
	"errors"
	"log"
	"screenmark/screenmark/.gen/haystack/haystack/model"
	"screenmark/screenmark/models"

	"github.com/google/uuid"
)

// eventLocationPrompt is the system prompt for the event/location agent. It is
// added to each request in GetLocations via AddSystem and is also passed to
// CreateAgentClient in NewLocationEventAgent.
//
// NOTE(review): the prompt first explains when to create an event and then
// says "DO NOT CREATE EVENTS". This looks like a temporary switch while the
// createEvent tool has no handler (see TOOLS) - confirm it is intentional.
const eventLocationPrompt = ` You are an agent that extracts events and locations from an image. Your job is to check if an image has an event or a location and use the correct tools to extract this information. If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location. Only create an event if you see an event on the image, not all locations have an associated event. DO NOT CREATE EVENTS It is possible that there is no location or event on an image. You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible. `

// TOOLS is a JSON array describing the tools offered to the model:
// createLocation, listLocations, createEvent and finish. GetLocations
// unmarshals it and sends the result as the request's Tools field.
//
// NOTE(review): HandleToolCall only handles listLocations, createLocation and
// finish; a createEvent call falls through to the "Unknown tool_call" error.
//
// TODO: this should be read directly from a file on load.
const TOOLS = ` [ { "type": "function", "function": { "name": "createLocation", "description": "Creates a location", "parameters": { "type": "object", "properties": { "name": { "type": "string" }, "coordinates": { "type": "string" }, "address": { "type": "string" } }, "required": ["name"] } } }, { "type": "function", "function": { "name": "listLocations", "description": "Lists the locations available", "parameters": { "type": "object", "properties": {} } } }, { "type": "function", "function": { "name": "createEvent", "description": "Creates a new event", "parameters": { "type": "object", "properties": { "name": { "type": "string" }, "datetime": { "type": "string" }, "locationId": { "type": "string", "description": "The ID of the location, available by listLocations" } }, "required": ["name"] } } }, { "type": "function", "function": { "name": "finish", "description": "Nothing else to do, call this function.", "parameters": { "type": "object", "properties": {} } } } ] `

// EventLocationAgent bundles the agent client with the models used to serve
// the model's tool calls.
type EventLocationAgent struct {
	// client sends requests to the model.
	client AgentClient

	// eventModel is stored by NewLocationEventAgent but is not read by any
	// method visible in this file.
	eventModel models.EventModel

	// locationModel backs the listLocations and createLocation tools.
	locationModel models.LocationModel
}
type CreateLocationArguments struct { Name string `json:"name"` Address *string `json:"address,omitempty"` Coordinates *string `json:"coordinates,omitempty"` } func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, request *AgentRequestBody) error { latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1] toolCall, ok := latestMessage.(AgentAssistantToolCall) if !ok { return errors.New("Latest message is not a tool_call") } for _, call := range toolCall.ToolCalls { if call.Function.Name == "listLocations" { log.Println("Function call: listLocations") locations, err := agent.locationModel.List(context.Background(), userId) if err != nil { return err } jsonLocations, err := json.Marshal(locations) if err != nil { return err } request.AddText(AgentTextMessage{ Role: "tool", Name: "listLocations", Content: string(jsonLocations), ToolCallId: call.Id, }) } else if call.Function.Name == "createLocation" { log.Println("Function call: createLocation") locationArguments := CreateLocationArguments{} err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments) if err != nil { return err } locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{ Name: locationArguments.Name, Address: locationArguments.Address, Coordinates: locationArguments.Coordinates, }}) if err != nil { return err } createdLocation := locations[0] jsonCreatedLocation, err := json.Marshal(createdLocation) if err != nil { return err } request.AddText(AgentTextMessage{ Role: "tool", Name: "createLocation", Content: string(jsonCreatedLocation), ToolCallId: call.Id, }) } else if call.Function.Name == "finish" { log.Println("Finished!") return errors.New("hmmm, this isnt actually an error but hey") } else { return errors.New("Unknown tool_call") } } return nil } func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error { var tools any err := json.Unmarshal([]byte(TOOLS), &tools) 
toolChoice := "any" request := AgentRequestBody{ Tools: &tools, ToolChoice: &toolChoice, Model: "pixtral-12b-2409", Temperature: 0.3, ResponseFormat: ResponseFormat{ Type: "text", }, } err = request.AddSystem(eventLocationPrompt) if err != nil { return err } log.Println(request) request.AddImage(imageName, imageData) _, err = agent.client.Request(&request) if err != nil { return err } err = agent.HandleToolCall(userId, &request) for err == nil { log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) response, requestError := agent.client.Request(&request) if requestError != nil { return requestError } log.Println(response) err = agent.HandleToolCall(userId, &request) } return err } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) { client, err := CreateAgentClient(eventLocationPrompt) if err != nil { return EventLocationAgent{}, err } return EventLocationAgent{ client: client, locationModel: locationModel, eventModel: eventModel, }, nil }