// Package agents contains model-backed agents that use tool calls to extract
// structured data, such as locations and events, from images.
package agents
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"log"
|
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
|
"screenmark/screenmark/models"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// eventLocationPrompt is the system prompt for the event/location extraction
// agent. It is handed to CreateAgentClient and also added to each request as
// the system message.
//
// NOTE(review): the prompt both describes when to create an event and says
// "DO NOT CREATE EVENTS". HandleToolCall has no createEvent case, so the
// blanket ban looks deliberate for now — confirm before relaxing it.
const eventLocationPrompt = `
You are an agent that extracts events and locations from an image.
Your job is to check if an image has an event or a location and use the correct tools to extract this information.

If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
Only create an event if you see an event on the image, not all locations have an associated event.
DO NOT CREATE EVENTS

It is possible that there is no location or event on an image.

You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
`
|
|
|
|
// TOOLS is the JSON array of function-tool definitions offered to the model:
// createLocation, listLocations, createEvent and finish. GetLocations
// unmarshals it on every call, so it must stay valid JSON.
//
// NOTE(review): createEvent is declared here, but HandleToolCall has no case
// for it and answers with an "Unknown tool_call" error.
//
// TODO: this should be read directly from a file on load.
const TOOLS = `
[
	{
		"type": "function",
		"function": {
			"name": "createLocation",
			"description": "Creates a location",
			"parameters": {
				"type": "object",
				"properties": {
					"name": {
						"type": "string"
					},
					"coordinates": {
						"type": "string"
					},
					"address": {
						"type": "string"
					}
				},
				"required": ["name"]
			}
		}
	},
	{
		"type": "function",
		"function": {
			"name": "listLocations",
			"description": "Lists the locations available",
			"parameters": {
				"type": "object",
				"properties": {}
			}
		}
	},
	{
		"type": "function",
		"function": {
			"name": "createEvent",
			"description": "Creates a new event",
			"parameters": {
				"type": "object",
				"properties": {
					"name": {
						"type": "string"
					},
					"datetime": {
						"type": "string"
					},
					"locationId": {
						"type": "string",
						"description": "The ID of the location, available by listLocations"
					}
				},
				"required": ["name"]
			}
		}
	},
	{
		"type": "function",
		"function": {
			"name": "finish",
			"description": "Nothing else to do, call this function.",
			"parameters": {
				"type": "object",
				"properties": {}
			}
		}
	}
]
`
|
|
|
|
type EventLocationAgent struct {
|
|
client AgentClient
|
|
|
|
eventModel models.EventModel
|
|
locationModel models.LocationModel
|
|
}
|
|
|
|
// CreateLocationArguments is the decoded argument payload of a
// "createLocation" tool call. Only Name is required by the tool schema;
// Address and Coordinates stay nil when the model omits them.
type CreateLocationArguments struct {
	Name        string  `json:"name"`
	Address     *string `json:"address,omitempty"`
	Coordinates *string `json:"coordinates,omitempty"`
}
|
|
|
|
func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, request *AgentRequestBody) error {
|
|
latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]
|
|
|
|
toolCall, ok := latestMessage.(AgentAssistantToolCall)
|
|
|
|
if !ok {
|
|
return errors.New("Latest message is not a tool_call")
|
|
}
|
|
|
|
for _, call := range toolCall.ToolCalls {
|
|
if call.Function.Name == "listLocations" {
|
|
log.Println("Function call: listLocations")
|
|
|
|
locations, err := agent.locationModel.List(context.Background(), userId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
jsonLocations, err := json.Marshal(locations)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
request.AddText(AgentTextMessage{
|
|
Role: "tool",
|
|
Name: "listLocations",
|
|
Content: string(jsonLocations),
|
|
ToolCallId: call.Id,
|
|
})
|
|
} else if call.Function.Name == "createLocation" {
|
|
log.Println("Function call: createLocation")
|
|
|
|
locationArguments := CreateLocationArguments{}
|
|
err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{
|
|
Name: locationArguments.Name,
|
|
Address: locationArguments.Address,
|
|
Coordinates: locationArguments.Coordinates,
|
|
}})
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
createdLocation := locations[0]
|
|
jsonCreatedLocation, err := json.Marshal(createdLocation)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
request.AddText(AgentTextMessage{
|
|
Role: "tool",
|
|
Name: "createLocation",
|
|
Content: string(jsonCreatedLocation),
|
|
ToolCallId: call.Id,
|
|
})
|
|
} else if call.Function.Name == "finish" {
|
|
log.Println("Finished!")
|
|
return errors.New("hmmm, this isnt actually an error but hey")
|
|
} else {
|
|
return errors.New("Unknown tool_call")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageName string, imageData []byte) error {
|
|
var tools any
|
|
err := json.Unmarshal([]byte(TOOLS), &tools)
|
|
|
|
toolChoice := "any"
|
|
|
|
request := AgentRequestBody{
|
|
Tools: &tools,
|
|
ToolChoice: &toolChoice,
|
|
Model: "pixtral-12b-2409",
|
|
Temperature: 0.3,
|
|
ResponseFormat: ResponseFormat{
|
|
Type: "text",
|
|
},
|
|
}
|
|
|
|
err = request.AddSystem(eventLocationPrompt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Println(request)
|
|
|
|
request.AddImage(imageName, imageData)
|
|
|
|
_, err = agent.client.Request(&request)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = agent.HandleToolCall(userId, &request)
|
|
for err == nil {
|
|
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
|
|
|
response, requestError := agent.client.Request(&request)
|
|
if requestError != nil {
|
|
return requestError
|
|
}
|
|
|
|
log.Println(response)
|
|
|
|
err = agent.HandleToolCall(userId, &request)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
|
client, err := CreateAgentClient(eventLocationPrompt)
|
|
if err != nil {
|
|
return EventLocationAgent{}, err
|
|
}
|
|
|
|
return EventLocationAgent{
|
|
client: client,
|
|
locationModel: locationModel,
|
|
eventModel: eventModel,
|
|
}, nil
|
|
}
|