Haystack/backend/agents/event_location_agent.go

282 lines
7.3 KiB
Go

package agents
import (
"context"
"encoding/json"
"errors"
"log"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/models"
"github.com/google/uuid"
)
// eventLocationPrompt is the system prompt for the event/location agent.
// It is added to each request by GetLocations and is also passed to
// CreateAgentClient when the agent is constructed. The text is sent to the
// model verbatim, so any edit here changes the agent's behaviour.
const eventLocationPrompt = `
You are an agent that extracts events and locations from an image.
Your job is to check if an image has an event or a location and use the correct tools to extract this information.
If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
Only create an event if you see an event on the image, not all locations have an associated event.
It is possible that there is no location or event on an image.
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
`
// TOOLS is the JSON description of the functions the model may call:
// createLocation, listLocations, attachImageLocation, createEvent and finish.
// GetLocations decodes it with json.Unmarshal and attaches the result to the
// request as its Tools field.
//
// NOTE(review): createEvent is advertised here, but HandleToolCall has no
// branch for it, so a createEvent call ends the run with "Unknown tool_call".
//
// TODO: this should be read directly from a file on load.
const TOOLS = `
[
{
"type": "function",
"function": {
"name": "createLocation",
"description": "Creates a location",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"coordinates": {
"type": "string"
},
"address": {
"type": "string"
}
},
"required": ["name"]
}
}
},
{
"type": "function",
"function": {
"name": "listLocations",
"description": "Lists the locations available",
"parameters": {
"type": "object",
"properties": {}
}
}
},
{
"type": "function",
"function": {
"name": "attachImageLocation",
"description": "Add a location to an image",
"parameters": {
"type": "object",
"properties": {
"locationId": {
"type": "string"
}
},
"required": ["locationId"]
}
}
},
{
"type": "function",
"function": {
"name": "createEvent",
"description": "Creates a new event",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"datetime": {
"type": "string"
},
"locationId": {
"type": "string",
"description": "The ID of the location, available by listLocations"
}
},
"required": ["name"]
}
}
},
{
"type": "function",
"function": {
"name": "finish",
"description": "Nothing else to do, call this function.",
"parameters": {
"type": "object",
"properties": {}
}
}
}
]
`
// EventLocationAgent drives a tool-calling conversation with a model in order
// to extract locations (and, per the prompt, events) from an image.
type EventLocationAgent struct {
// client sends the accumulated request to the model (see GetLocations).
client AgentClient
// eventModel is stored by NewLocationEventAgent; none of the methods
// visible in this file read it yet.
eventModel models.EventModel
// locationModel backs the listLocations, createLocation and
// attachImageLocation tool calls.
locationModel models.LocationModel
}
// CreateLocationArguments is the shape HandleToolCall decodes the JSON
// arguments of a createLocation tool call into. Address and Coordinates are
// pointers so that absent values stay nil.
type CreateLocationArguments struct {
Name string `json:"name"`
Address *string `json:"address,omitempty"`
Coordinates *string `json:"coordinates,omitempty"`
}
// AttachImageLocationArguments is the shape HandleToolCall decodes the JSON
// arguments of an attachImageLocation tool call into.
type AttachImageLocationArguments struct {
// LocationId is the textual UUID of the location to attach to the image.
LocationId string `json:"locationId"`
}
// HandleToolCall executes every tool call carried by the latest message in
// request and appends one "tool" message per call holding its result.
//
// Supported tools:
//   - listLocations: lists the locations of userId and replies with them as JSON.
//   - createLocation: saves a new location and replies with it as JSON.
//   - attachImageLocation: links an existing location to imageId and replies "OK".
//   - finish: stops processing.
//
// It returns an error when the conversation is empty, when the latest message
// is not an AgentAssistantToolCall, when a tool name is unknown, when tool
// arguments cannot be decoded (including a malformed locationId), or when the
// location model fails.
//
// NOTE: "finish" is also reported through a non-nil error; callers use any
// non-nil return as the signal to stop looping.
//
// NOTE(review): the tool schema also advertises createEvent, which is not
// handled here and therefore yields "Unknown tool_call".
func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, imageId uuid.UUID, request *AgentRequestBody) error {
	messages := request.AgentMessages.Messages
	if len(messages) == 0 {
		// Previously this indexed messages[-1] and panicked.
		return errors.New("no messages to handle")
	}

	toolCall, ok := messages[len(messages)-1].(AgentAssistantToolCall)
	if !ok {
		return errors.New("Latest message is not a tool_call")
	}

	ctx := context.Background()

	for _, call := range toolCall.ToolCalls {
		switch call.Function.Name {
		case "listLocations":
			log.Println("Function call: listLocations")

			locations, err := agent.locationModel.List(ctx, userId)
			if err != nil {
				return err
			}

			jsonLocations, err := json.Marshal(locations)
			if err != nil {
				return err
			}

			request.AddText(AgentTextMessage{
				Role:       "tool",
				Name:       "listLocations",
				Content:    string(jsonLocations),
				ToolCallId: call.Id,
			})

		case "createLocation":
			log.Println("Function call: createLocation")

			locationArguments := CreateLocationArguments{}
			if err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments); err != nil {
				return err
			}

			locations, err := agent.locationModel.Save(ctx, []model.Locations{{
				Name:        locationArguments.Name,
				Address:     locationArguments.Address,
				Coordinates: locationArguments.Coordinates,
			}})
			if err != nil {
				return err
			}
			if len(locations) == 0 {
				// Guard the index below instead of panicking on an empty result.
				return errors.New("createLocation: no location was saved")
			}

			jsonCreatedLocation, err := json.Marshal(locations[0])
			if err != nil {
				return err
			}

			request.AddText(AgentTextMessage{
				Role:       "tool",
				Name:       "createLocation",
				Content:    string(jsonCreatedLocation),
				ToolCallId: call.Id,
			})

		case "attachImageLocation":
			log.Println("Function call: attachImageLocation")

			attachLocationArguments := AttachImageLocationArguments{}
			if err := json.Unmarshal([]byte(call.Function.Arguments), &attachLocationArguments); err != nil {
				return err
			}

			// The ID is produced by the model, so it is untrusted input:
			// parse it and report failure rather than panic via MustParse.
			locationId, err := uuid.Parse(attachLocationArguments.LocationId)
			if err != nil {
				return err
			}

			if _, err = agent.locationModel.SaveToImage(ctx, imageId, locationId); err != nil {
				return err
			}

			request.AddText(AgentTextMessage{
				Role:       "tool",
				Name:       "attachImageLocation",
				Content:    "OK",
				ToolCallId: call.Id,
			})

		case "finish":
			log.Println("Finished!")
			// Not a real failure: this is how the conversation loop is told to stop.
			return errors.New("hmmm, this isnt actually an error but hey")

		default:
			return errors.New("Unknown tool_call")
		}
	}

	return nil
}
// GetLocations runs the tool-calling conversation for a single image.
//
// It builds a request from the TOOLS schema, the system prompt and the image,
// sends it, and then alternates between HandleToolCall and a further request
// until HandleToolCall returns a non-nil error.
//
// userId and imageId are forwarded to HandleToolCall; imageName and imageData
// are attached to the request via AddImage.
//
// NOTE: this function never returns nil. The loop only ends on a non-nil
// error, and the "finish" tool is itself reported as an error by
// HandleToolCall, so a completed run also surfaces as an error value.
//
// NOTE(review): nothing bounds the number of round trips if the model never
// calls "finish" — consider an iteration cap.
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
	var tools any
	// This error used to be overwritten by the next assignment to err, so a
	// malformed TOOLS schema went unnoticed and a nil tool list was sent.
	if err := json.Unmarshal([]byte(TOOLS), &tools); err != nil {
		return err
	}

	toolChoice := "any"

	request := AgentRequestBody{
		Tools:       &tools,
		ToolChoice:  &toolChoice,
		Model:       "pixtral-12b-2409",
		Temperature: 0.3,
		ResponseFormat: ResponseFormat{
			Type: "text",
		},
	}

	if err := request.AddSystem(eventLocationPrompt); err != nil {
		return err
	}

	log.Println(request)

	request.AddImage(imageName, imageData)

	if _, err := agent.client.Request(&request); err != nil {
		return err
	}

	err := agent.HandleToolCall(userId, imageId, &request)
	for err == nil {
		log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])

		response, requestError := agent.client.Request(&request)
		if requestError != nil {
			return requestError
		}

		log.Println(response)

		err = agent.HandleToolCall(userId, imageId, &request)
	}

	return err
}
// NewLocationEventAgent builds an EventLocationAgent around the given
// location and event models. The agent client is created with
// CreateAgentClient using the event/location system prompt; if that fails,
// the zero-value agent is returned together with the error.
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
	var agent EventLocationAgent

	agentClient, err := CreateAgentClient(eventLocationPrompt)
	if err != nil {
		// agent is still the zero value here.
		return agent, err
	}

	agent.client = agentClient
	agent.locationModel = locationModel
	agent.eventModel = eventModel

	return agent, nil
}