341 lines
9.4 KiB
Go
341 lines
9.4 KiB
Go
package agents
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"reflect"

	"screenmark/screenmark/.gen/haystack/haystack/model"
	"screenmark/screenmark/models"

	"github.com/google/uuid"
)
|
|
|
|
// eventLocationPrompt is the system prompt for the event/location agent. It
// instructs the model to look for events and locations in an image, to reuse
// locations returned by listLocations instead of creating duplicates, and to
// skip event creation when the image shows no date or event name.
//
// The text is sent to the model verbatim, so it must not contain stray
// formatting characters.
const eventLocationPrompt = `
You are an agent that extracts events and locations from an image.
Your job is to check if an image has an event or a location and use the correct tools to extract this information.

If you find an event, you should look for a location for this event on the image, it is possible an event doesn't have a location.
Only create an event if you see an event on the image, not all locations have an associated event.

It is possible that there is no location or event on an image.

You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.

Always reuse existing locations from listLocations. Do not create duplicates.

Do not create an event if you don't see any dates, or a name indicating an event.

Events can have an associated location, if you think there is a location, then you must either use a location from listLocations or you must create it first.
Wherever possible, find the location in the image.
`
|
|
|
|
// TOOLS is the JSON tool schema advertised to the model. It declares four
// functions: createLocation, listLocations, createEvent and finish. The
// string is decoded with encoding/json before being attached to a request,
// so it must stay valid JSON.
//
// TODO: this should be read directly from a file on load.
const TOOLS = `
[
	{
		"type": "function",
		"function": {
			"name": "createLocation",
			"description": "Creates a location. Do not use if you think an existing location is suitable!",
			"parameters": {
				"type": "object",
				"properties": {
					"name": {
						"type": "string"
					},
					"coordinates": {
						"type": "string"
					},
					"address": {
						"type": "string"
					}
				},
				"required": ["name"]
			}
		}
	},
	{
		"type": "function",
		"function": {
			"name": "listLocations",
			"description": "Lists the locations available",
			"parameters": {
				"type": "object",
				"properties": {}
			}
		}
	},
	{
		"type": "function",
		"function": {
			"name": "createEvent",
			"description": "Creates a new event",
			"parameters": {
				"type": "object",
				"properties": {
					"name": {
						"type": "string"
					},
					"datetime": {
						"type": "string"
					},
					"locationId": {
						"type": "string",
						"description": "The ID of the location, available by listLocations"
					}
				},
				"required": ["name"]
			}
		}
	},
	{
		"type": "function",
		"function": {
			"name": "finish",
			"description": "Nothing else to do, call this function.",
			"parameters": {
				"type": "object",
				"properties": {}
			}
		}
	}
]
`
|
|
|
|
type EventLocationAgent struct {
|
|
client AgentClient
|
|
|
|
eventModel models.EventModel
|
|
locationModel models.LocationModel
|
|
|
|
toolHandler ToolsHandlers
|
|
}
|
|
|
|
// ListLocationArguments is the argument payload of the listLocations tool,
// which takes no parameters.
type ListLocationArguments struct{}
|
|
|
|
// CreateLocationArguments is the argument payload of the createLocation
// tool. Only Name is required by the tool schema; Address and Coordinates
// are nil when the model omits them.
type CreateLocationArguments struct {
	Name        string  `json:"name"`
	Address     *string `json:"address,omitempty"`
	Coordinates *string `json:"coordinates,omitempty"`
}
|
|
|
|
// AttachImageLocationArguments is the argument payload of the
// attachImageLocation tool. LocationId is carried as a string and parsed
// into a UUID by the handler.
type AttachImageLocationArguments struct {
	LocationId string `json:"locationId"`
}
|
|
|
|
// CreateEventArguments is the argument payload of the createEvent tool.
// Only Name is required by the tool schema, so Datetime and LocationId are
// empty strings when the model omits them.
type CreateEventArguments struct {
	Name       string `json:"name"`
	Datetime   string `json:"datetime"`
	LocationId string `json:"locationId"`
}
|
|
|
|
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
|
var tools any
|
|
err := json.Unmarshal([]byte(TOOLS), &tools)
|
|
|
|
toolChoice := "any"
|
|
|
|
request := AgentRequestBody{
|
|
Tools: &tools,
|
|
ToolChoice: &toolChoice,
|
|
Model: "pixtral-12b-2409",
|
|
Temperature: 0.3,
|
|
ResponseFormat: ResponseFormat{
|
|
Type: "text",
|
|
},
|
|
}
|
|
|
|
err = request.AddSystem(eventLocationPrompt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
request.AddImage(imageName, imageData)
|
|
|
|
_, err = agent.client.Request(&request)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
toolHandlerInfo := ToolHandlerInfo{
|
|
imageId: imageId,
|
|
userId: userId,
|
|
}
|
|
|
|
_, err = agent.toolHandler.Handle(toolHandlerInfo, &request)
|
|
|
|
for err == nil {
|
|
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
|
|
|
response, requestError := agent.client.Request(&request)
|
|
if requestError != nil {
|
|
return requestError
|
|
}
|
|
|
|
log.Println(response)
|
|
|
|
a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request)
|
|
|
|
err = innerErr
|
|
|
|
log.Println(a)
|
|
log.Println("--------------------------")
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
|
|
agentMessage := request.Messages[len(request.Messages)-1]
|
|
|
|
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
|
if !ok {
|
|
return "", errors.New("Latest message was not a tool call.")
|
|
}
|
|
|
|
fnName := toolCall.ToolCalls[0].Function.Name
|
|
arguments := toolCall.ToolCalls[0].Function.Arguments
|
|
|
|
if fnName == "finish" {
|
|
return "", errors.New("This is the end! Maybe we just return a boolean.")
|
|
}
|
|
|
|
fn, exists := handler.Handlers[fnName]
|
|
if !exists {
|
|
return "", errors.New("Could not find tool with this name.")
|
|
}
|
|
|
|
// holy jesus what the fuck.
|
|
parseMethod := reflect.ValueOf(fn).Field(1)
|
|
if !parseMethod.IsValid() {
|
|
return "", errors.New("Parse method not found")
|
|
}
|
|
|
|
parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)})
|
|
if !parsedArgs[1].IsNil() {
|
|
return "", parsedArgs[1].Interface().(error)
|
|
}
|
|
|
|
log.Printf("Calling: %s\n", fnName)
|
|
|
|
fnMethod := reflect.ValueOf(fn).Field(2)
|
|
if !fnMethod.IsValid() {
|
|
return "", errors.New("Fn method not found")
|
|
}
|
|
|
|
response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])})
|
|
if !response[1].IsNil() {
|
|
return "", response[1].Interface().(error)
|
|
}
|
|
|
|
stringResponse, err := json.Marshal(response[0].Interface())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
request.AddText(AgentTextMessage{
|
|
Role: "tool",
|
|
Name: "createLocation",
|
|
Content: string(stringResponse),
|
|
ToolCallId: toolCall.ToolCalls[0].Id,
|
|
})
|
|
|
|
return string(stringResponse), nil
|
|
}
|
|
|
|
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
|
client, err := CreateAgentClient(eventLocationPrompt)
|
|
if err != nil {
|
|
return EventLocationAgent{}, err
|
|
}
|
|
|
|
agent := EventLocationAgent{
|
|
client: client,
|
|
locationModel: locationModel,
|
|
eventModel: eventModel,
|
|
}
|
|
|
|
toolHandler := ToolsHandlers{
|
|
Handlers: make(map[string]ToolHandlerInterface),
|
|
}
|
|
|
|
toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
|
|
FunctionName: "listLocations",
|
|
Parse: func(stringArgs string) (ListLocationArguments, error) {
|
|
args := ListLocationArguments{}
|
|
err := json.Unmarshal([]byte(stringArgs), &args)
|
|
|
|
return args, err
|
|
},
|
|
Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) {
|
|
return agent.locationModel.List(context.Background(), info.userId)
|
|
},
|
|
}
|
|
|
|
toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
|
|
FunctionName: "createLocation",
|
|
Parse: func(stringArgs string) (CreateLocationArguments, error) {
|
|
args := CreateLocationArguments{}
|
|
err := json.Unmarshal([]byte(stringArgs), &args)
|
|
|
|
return args, err
|
|
},
|
|
Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) {
|
|
return agent.locationModel.Save(context.Background(), info.userId, model.Locations{
|
|
Name: args.Name,
|
|
Address: args.Address,
|
|
Coordinates: args.Coordinates,
|
|
})
|
|
},
|
|
}
|
|
|
|
// I'm not sure this one actually makes sense either.
|
|
// I think the earlier tool can do more.
|
|
toolHandler.Handlers["attachImageLocation"] = ToolHandler[AttachImageLocationArguments, model.ImageLocations]{
|
|
FunctionName: "attachImageLocation",
|
|
Parse: func(stringArgs string) (AttachImageLocationArguments, error) {
|
|
args := AttachImageLocationArguments{}
|
|
err := json.Unmarshal([]byte(stringArgs), &args)
|
|
|
|
return args, err
|
|
},
|
|
Fn: func(info ToolHandlerInfo, args AttachImageLocationArguments, call ToolCall) (model.ImageLocations, error) {
|
|
return agent.locationModel.SaveToImage(context.Background(), info.imageId, uuid.MustParse(args.LocationId))
|
|
},
|
|
}
|
|
|
|
toolHandler.Handlers["createEvent"] = ToolHandler[CreateEventArguments, model.Events]{
|
|
FunctionName: "createEvent",
|
|
Parse: func(stringArgs string) (CreateEventArguments, error) {
|
|
args := CreateEventArguments{}
|
|
err := json.Unmarshal([]byte(stringArgs), &args)
|
|
|
|
return args, err
|
|
},
|
|
Fn: func(info ToolHandlerInfo, args CreateEventArguments, call ToolCall) (model.Events, error) {
|
|
ctx := context.Background()
|
|
|
|
event, err := agent.eventModel.Save(ctx, info.userId, model.Events{
|
|
Name: args.Name,
|
|
})
|
|
|
|
if err != nil {
|
|
return event, err
|
|
}
|
|
|
|
locationId, err := uuid.Parse(args.LocationId)
|
|
if err != nil {
|
|
return event, err
|
|
}
|
|
|
|
return agent.eventModel.UpdateLocation(ctx, event.ID, locationId)
|
|
},
|
|
}
|
|
|
|
agent.toolHandler = toolHandler
|
|
|
|
return agent, nil
|
|
}
|