Haystack/backend/agents/event_location_agent.go

394 lines
11 KiB
Go

package agents
import (
"context"
"encoding/json"
"errors"
"log"
"reflect"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/models"
"time"
"github.com/google/uuid"
)
// eventLocationPrompt is the system prompt for the event/location extraction
// agent; it is also passed to CreateAgentClient by NewLocationEventAgent.
//
// It must only reference tools that TOOLS actually declares (listLocations,
// createLocation, createEvent, finish). A call to an undeclared tool makes
// ToolsHandlers.Handle return an error, which ends the GetLocations run.
// Organizers have no tool of their own: they are passed by name through
// createEvent's organizerName argument.
//
// NOTE: this wording has had little tuning; expect to iterate on it.
const eventLocationPrompt = `
You are an agent that extracts events, locations, and organizers from an image. Your primary task is to make sure the right location exists before creating an event. Follow these steps:
1. Identify and Create Locations:
Check if the image contains a location.
If a location is found, call listLocations and check whether it already exists.
Always reuse existing locations from listLocations to avoid duplicates.
Only if the location does not exist, create it with createLocation.
2. Identify and Create Events:
Check if the image contains an event. An event should have a name and a date.
If an event is found, create it with createEvent. Pass the locationId of the location from step 1, and pass the name of the event's organizer as organizerName. There is no separate tool for creating organizers.
Events must have an associated location and organizer. Do not create an event without these.
If possible, return a start time and an end time as ISO 8601 datetime strings in UTC, for example 2024-05-01T19:00:00Z.
3. Handling Images Without Events or Locations:
It is possible that the image does not contain an event or a location. In such cases, do not create an event.
4. Finishing:
When there is nothing left to do, call finish.
`
// TOOLS is the JSON tool schema sent to the model with every request made by
// GetLocations. Each function name must have a matching entry in
// ToolsHandlers.Handlers (registered in NewLocationEventAgent), except
// "finish", which ToolsHandlers.Handle treats as the end of the conversation.
// Parameter names must match the JSON tags of the *Arguments structs.
//
// TODO: this should be read directly from a file on load.
const TOOLS = `
[
  {
    "type": "function",
    "function": {
      "name": "createLocation",
      "description": "Creates a location. Do not use if you think an existing location is suitable!",
      "parameters": {
        "type": "object",
        "properties": {
          "name": {
            "type": "string"
          },
          "address": {
            "type": "string"
          }
        },
        "required": ["name"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "listLocations",
      "description": "Lists the locations available",
      "parameters": {
        "type": "object",
        "properties": {}
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "createEvent",
      "description": "Creates a new event and links it to a location and an organizer",
      "parameters": {
        "type": "object",
        "properties": {
          "name": {
            "type": "string"
          },
          "startDateTime": {
            "type": "string",
            "description": "The start time as an ISO 8601 string in UTC, e.g. 2024-05-01T19:00:00Z"
          },
          "endDateTime": {
            "type": "string",
            "description": "The end time as an ISO 8601 string in UTC, e.g. 2024-05-01T22:00:00Z"
          },
          "locationId": {
            "type": "string",
            "description": "The ID of the location, available by listLocations"
          },
          "organizerName": {
            "type": "string",
            "description": "The name of the organizer"
          }
        },
        "required": ["name"]
      }
    }
  },
  {
    "type": "function",
    "function": {
      "name": "finish",
      "description": "Call this function when there is nothing else to do.",
      "parameters": {
        "type": "object",
        "properties": {}
      }
    }
  }
]
`
// EventLocationAgent runs a tool-calling conversation with a model to extract
// events and locations from an image and stores the results through its
// models. Build it with NewLocationEventAgent, which also registers the tool
// handlers it dispatches to.
type EventLocationAgent struct {
	// client performs the model requests.
	client AgentClient

	// Persistence used by the tool handlers.
	eventModel    models.EventModel
	locationModel models.LocationModel
	contactModel  models.ContactModel

	// toolHandler maps tool names from the model onto Go implementations.
	toolHandler ToolsHandlers
}
// Tool argument payloads. The model sends tool arguments as a JSON string that
// is decoded straight into these structs, so the JSON tags have to match the
// parameter names declared in TOOLS.
type (
	// ListLocationArguments is the empty argument object of listLocations.
	ListLocationArguments struct{}

	// ListOrganizerArguments is an empty argument object, presumably for an
	// organizer-listing tool (not registered by NewLocationEventAgent).
	ListOrganizerArguments struct{}

	// CreateLocationArguments holds the arguments of createLocation.
	CreateLocationArguments struct {
		Name        string  `json:"name"`
		Address     *string `json:"address,omitempty"`
		Coordinates *string `json:"coordinates,omitempty"`
	}

	// CreateOrganizerArguments describes an organizer (contact) to create.
	CreateOrganizerArguments struct {
		Name        string  `json:"name"`
		PhoneNumber *string `json:"phoneNumber,omitempty"`
		Email       *string `json:"email,omitempty"`
	}

	// AttachImageLocationArguments references an existing location by ID.
	AttachImageLocationArguments struct {
		LocationId string `json:"locationId"`
	}

	// CreateEventArguments holds the arguments of createEvent. All values
	// arrive as strings and are parsed by the handler.
	CreateEventArguments struct {
		Name          string `json:"name"`
		StartDateTime string `json:"startDateTime"`
		EndDateTime   string `json:"endDateTime"`
		LocationId    string `json:"locationId"`
		OrganizerName string `json:"organizerName"`
	}
)
// errAgentFinished is returned by ToolsHandlers.Handle when the model calls
// the "finish" tool. It is a sentinel so the conversation loop can tell a
// normal end of conversation apart from a real failure. The message is kept
// identical to the previous inline error.
var errAgentFinished = errors.New("This is the end! Maybe we just return a boolean.")

// GetLocations runs the extraction conversation for a single image. It sends
// the system prompt and the image to the model, then alternates model requests
// with tool execution (ToolsHandlers.Handle) until the model calls "finish".
//
// It returns nil when the model finishes normally. Otherwise it returns the
// first error from decoding TOOLS, building the request, calling the model, or
// running a tool. It also fails if the model has not finished within maxTurns
// round-trips, so a model that never calls "finish" cannot loop forever.
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
	// Generous bound on model round-trips; a normal run needs only a few
	// (list locations, create location, create event, finish).
	const maxTurns = 20

	var tools any
	if err := json.Unmarshal([]byte(TOOLS), &tools); err != nil {
		return err
	}

	toolChoice := "any"
	request := AgentRequestBody{
		Tools:       &tools,
		ToolChoice:  &toolChoice,
		Model:       "pixtral-12b-2409",
		Temperature: 0.3,
		ResponseFormat: ResponseFormat{
			Type: "text",
		},
	}

	if err := request.AddSystem(eventLocationPrompt); err != nil {
		return err
	}
	request.AddImage(imageName, imageData)

	info := ToolHandlerInfo{
		imageId: imageId,
		userId:  userId,
	}

	for turn := 0; turn < maxTurns; turn++ {
		response, err := agent.client.Request(&request)
		if err != nil {
			return err
		}
		log.Println(response)

		result, err := agent.toolHandler.Handle(info, &request)
		if errors.Is(err, errAgentFinished) {
			return nil
		}
		if err != nil {
			return err
		}

		// Handle succeeded, so it appended the tool reply as the latest message.
		log.Printf("Latest message: %+v\n", request.Messages[len(request.Messages)-1])
		log.Println(result)
		log.Println("--------------------------")
	}
	return errors.New("agent did not call finish within the maximum number of turns")
}

// Handle runs the first tool call in the latest message of request and
// appends the tool's JSON-encoded result to request as a "tool" message
// carrying the tool's name and call ID. It returns that JSON string.
//
// It returns an error when the latest message is not a tool call, the tool is
// unknown, its arguments fail to parse, or the tool itself fails. A call to
// "finish" returns errAgentFinished and appends nothing.
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
	if len(request.Messages) == 0 {
		return "", errors.New("Latest message was not a tool call.")
	}
	toolCall, ok := request.Messages[len(request.Messages)-1].(AgentAssistantToolCall)
	if !ok || len(toolCall.ToolCalls) == 0 {
		return "", errors.New("Latest message was not a tool call.")
	}

	// NOTE(review): only the first tool call is executed; any further parallel
	// calls in the same message get no reply.
	call := toolCall.ToolCalls[0]
	fnName := call.Function.Name

	if fnName == "finish" {
		return "", errAgentFinished
	}

	fn, exists := handler.Handlers[fnName]
	if !exists {
		return "", errors.New("Could not find tool with this name: " + fnName)
	}

	// Handlers are generic ToolHandler[Args, Result] values stored behind a
	// non-generic interface, so the typed Parse and Fn fields are reached via
	// reflection. Look them up by name (not position) and check their kind so a
	// malformed handler yields an error instead of a panic.
	handlerValue := reflect.Indirect(reflect.ValueOf(fn))
	if handlerValue.Kind() != reflect.Struct {
		return "", errors.New("Parse method not found")
	}

	parseMethod := handlerValue.FieldByName("Parse")
	if parseMethod.Kind() != reflect.Func || parseMethod.IsNil() {
		return "", errors.New("Parse method not found")
	}
	parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(call.Function.Arguments)})
	if !parsedArgs[1].IsNil() {
		return "", parsedArgs[1].Interface().(error)
	}

	log.Printf("Calling: %s\n", fnName)
	fnMethod := handlerValue.FieldByName("Fn")
	if fnMethod.Kind() != reflect.Func || fnMethod.IsNil() {
		return "", errors.New("Fn method not found")
	}
	response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(call)})
	if !response[1].IsNil() {
		return "", response[1].Interface().(error)
	}

	stringResponse, err := json.Marshal(response[0].Interface())
	if err != nil {
		return "", err
	}

	request.AddText(AgentTextMessage{
		Role: "tool",
		// The reply must name the tool that actually ran.
		Name:       fnName,
		Content:    string(stringResponse),
		ToolCallId: call.Id,
	})
	return string(stringResponse), nil
}
// NewLocationEventAgent creates an EventLocationAgent backed by the given
// models and registers the handlers for the tools declared in TOOLS:
// listLocations, createLocation and createEvent. "finish" needs no handler;
// ToolsHandlers.Handle recognises it before looking up Handlers.
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
	client, err := CreateAgentClient(eventLocationPrompt)
	if err != nil {
		return EventLocationAgent{}, err
	}

	agent := EventLocationAgent{
		client:        client,
		locationModel: locationModel,
		eventModel:    eventModel,
		contactModel:  contactModel,
	}

	toolHandler := ToolsHandlers{
		Handlers: make(map[string]ToolHandlerInterface),
	}

	// listLocations: return every location belonging to the user.
	toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
		FunctionName: "listLocations",
		Parse: func(stringArgs string) (ListLocationArguments, error) {
			args := ListLocationArguments{}
			err := json.Unmarshal([]byte(stringArgs), &args)
			return args, err
		},
		Fn: func(info ToolHandlerInfo, _ ListLocationArguments, _ ToolCall) ([]model.Locations, error) {
			return agent.locationModel.List(context.Background(), info.userId)
		},
	}

	// createLocation: save a new location and attach it to the image.
	toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
		FunctionName: "createLocation",
		Parse: func(stringArgs string) (CreateLocationArguments, error) {
			args := CreateLocationArguments{}
			err := json.Unmarshal([]byte(stringArgs), &args)
			return args, err
		},
		Fn: func(info ToolHandlerInfo, args CreateLocationArguments, _ ToolCall) (model.Locations, error) {
			ctx := context.Background()
			location, err := agent.locationModel.Save(ctx, info.userId, model.Locations{
				Name:    args.Name,
				Address: args.Address,
			})
			if err != nil {
				return location, err
			}
			_, err = agent.locationModel.SaveToImage(ctx, info.imageId, location.ID)
			return location, err
		},
	}

	// createEvent: save the event, attach it to the image, and link the
	// location and organizer when they were supplied.
	toolHandler.Handlers["createEvent"] = ToolHandler[CreateEventArguments, model.Events]{
		FunctionName: "createEvent",
		Parse: func(stringArgs string) (CreateEventArguments, error) {
			args := CreateEventArguments{}
			err := json.Unmarshal([]byte(stringArgs), &args)
			return args, err
		},
		Fn: func(info ToolHandlerInfo, args CreateEventArguments, _ ToolCall) (model.Events, error) {
			ctx := context.Background()

			// Validate every argument before writing anything, so a bad value
			// cannot leave a half-linked event or a stray contact behind.
			startTime, err := parseEventDateTime(args.StartDateTime)
			if err != nil {
				return model.Events{}, err
			}
			endTime, err := parseEventDateTime(args.EndDateTime)
			if err != nil {
				return model.Events{}, err
			}
			hasLocation := args.LocationId != ""
			var locationID uuid.UUID
			if hasLocation {
				locationID, err = uuid.Parse(args.LocationId)
				if err != nil {
					return model.Events{}, err
				}
			}

			event, err := agent.eventModel.Save(ctx, info.userId, model.Events{
				Name:          args.Name,
				StartDateTime: startTime,
				EndDateTime:   endTime,
			})
			if err != nil {
				return event, err
			}
			if _, err := agent.eventModel.SaveToImage(ctx, info.imageId, event.ID); err != nil {
				return event, err
			}

			if hasLocation {
				event, err = agent.eventModel.UpdateLocation(ctx, event.ID, locationID)
				if err != nil {
					return event, err
				}
			}

			if args.OrganizerName == "" {
				return event, nil
			}
			// The organizer is stored as a contact named after the organizer,
			// not after the event.
			organizer, err := agent.contactModel.Save(ctx, info.userId, model.Contacts{
				Name: args.OrganizerName,
			})
			if err != nil {
				return event, err
			}
			if _, err := agent.contactModel.SaveToImage(ctx, info.imageId, organizer.ID); err != nil {
				return event, err
			}
			return agent.eventModel.UpdateOrganizer(ctx, event.ID, organizer.ID)
		},
	}

	agent.toolHandler = toolHandler
	return agent, nil
}

// parseEventDateTime parses a datetime string produced by the model. An empty
// string means "not provided" and yields (nil, nil). RFC 3339 is tried first
// (it accepts "Z" and numeric offsets, plus optional fractional seconds); then
// zone-less date-times and plain dates, which time.Parse interprets as UTC.
// On failure it returns the RFC 3339 parse error.
func parseEventDateTime(value string) (*time.Time, error) {
	if value == "" {
		return nil, nil
	}
	layouts := [...]string{
		time.RFC3339,
		"2006-01-02T15:04:05",
		"2006-01-02T15:04",
		"2006-01-02",
	}
	var firstErr error
	for _, layout := range layouts {
		parsed, err := time.Parse(layout, value)
		if err == nil {
			return &parsed, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return nil, firstErr
}