diff --git a/backend/.gen/haystack/haystack/model/user_locations.go b/backend/.gen/haystack/haystack/model/user_locations.go new file mode 100644 index 0000000..95c3a7e --- /dev/null +++ b/backend/.gen/haystack/haystack/model/user_locations.go @@ -0,0 +1,18 @@ +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package model + +import ( + "github.com/google/uuid" +) + +type UserLocations struct { + ID uuid.UUID `sql:"primary_key"` + LocationID uuid.UUID + UserID uuid.UUID +} diff --git a/backend/.gen/haystack/haystack/table/table_use_schema.go b/backend/.gen/haystack/haystack/table/table_use_schema.go index 8f5d510..6d0922f 100644 --- a/backend/.gen/haystack/haystack/table/table_use_schema.go +++ b/backend/.gen/haystack/haystack/table/table_use_schema.go @@ -20,6 +20,7 @@ func UseSchema(schema string) { Locations = Locations.FromSchema(schema) UserImages = UserImages.FromSchema(schema) UserImagesToProcess = UserImagesToProcess.FromSchema(schema) + UserLocations = UserLocations.FromSchema(schema) UserTags = UserTags.FromSchema(schema) Users = Users.FromSchema(schema) } diff --git a/backend/.gen/haystack/haystack/table/user_locations.go b/backend/.gen/haystack/haystack/table/user_locations.go new file mode 100644 index 0000000..3208f25 --- /dev/null +++ b/backend/.gen/haystack/haystack/table/user_locations.go @@ -0,0 +1,81 @@ +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/postgres" +) + +var UserLocations = newUserLocationsTable("haystack", "user_locations", "") + +type userLocationsTable struct { + postgres.Table + + // Columns + ID postgres.ColumnString + LocationID postgres.ColumnString + UserID postgres.ColumnString + + AllColumns postgres.ColumnList + MutableColumns postgres.ColumnList +} + +type UserLocationsTable struct { + userLocationsTable + + EXCLUDED userLocationsTable +} + +// AS creates new UserLocationsTable with assigned alias +func (a UserLocationsTable) AS(alias string) *UserLocationsTable { + return newUserLocationsTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new UserLocationsTable with assigned schema name +func (a UserLocationsTable) FromSchema(schemaName string) *UserLocationsTable { + return newUserLocationsTable(schemaName, a.TableName(), a.Alias()) +} + +// WithPrefix creates new UserLocationsTable with assigned table prefix +func (a UserLocationsTable) WithPrefix(prefix string) *UserLocationsTable { + return newUserLocationsTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new UserLocationsTable with assigned table suffix +func (a UserLocationsTable) WithSuffix(suffix string) *UserLocationsTable { + return newUserLocationsTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + +func newUserLocationsTable(schemaName, tableName, alias string) *UserLocationsTable { + return &UserLocationsTable{ + userLocationsTable: newUserLocationsTableImpl(schemaName, tableName, alias), + EXCLUDED: newUserLocationsTableImpl("", "excluded", ""), + } +} + +func newUserLocationsTableImpl(schemaName, tableName, alias string) userLocationsTable { + var ( + IDColumn = postgres.StringColumn("id") + LocationIDColumn = postgres.StringColumn("location_id") + UserIDColumn = postgres.StringColumn("user_id") + allColumns = postgres.ColumnList{IDColumn, LocationIDColumn, UserIDColumn} + mutableColumns = postgres.ColumnList{LocationIDColumn, UserIDColumn} + ) + + return userLocationsTable{ + Table: postgres.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + ID: IDColumn, + LocationID: LocationIDColumn, + UserID: UserIDColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 8ea94df..3212f67 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "log" + "reflect" "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/models" @@ -21,6 +22,10 @@ Only create an event if you see an event on the image, not all locations have an It is possible that there is no location or event on an image. You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible. + +Do not create an event if you don't see any dates, or a name indicating an event. + +Always reuse existing locations from listLocations . Do not create duplicates. ` // TODO: this should be read directly from a file on load. @@ -30,7 +35,7 @@ const TOOLS = ` "type": "function", "function": { "name": "createLocation", - "description": "Creates a location", + "description": "Creates a location. No not use if you think an existing location is suitable!", "parameters": { "type": "object", "properties": { @@ -63,12 +68,13 @@ const TOOLS = ` "type": "function", "function": { "name": "attachImageLocation", - "description": "Add a location to an image", + "description": "Add a location to an image. You must use UUID.", "parameters": { "type": "object", "properties": { "locationId": { - "type": "string" + "type": "string", + "description": "UUID of an existing location, you can use listLocations to get values, or use the return value of createLocation" } }, "required": ["locationId"] @@ -117,8 +123,12 @@ type EventLocationAgent struct { eventModel models.EventModel locationModel models.LocationModel + + toolHandler ToolsHandlers } +type ListLocationArguments struct{} + type CreateLocationArguments struct { Name string `json:"name"` Address *string `json:"address,omitempty"` @@ -129,97 +139,6 @@ type AttachImageLocationArguments struct { LocationId string `json:"locationId"` } -func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, imageId uuid.UUID, request *AgentRequestBody) error { - latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1] - - toolCall, ok := latestMessage.(AgentAssistantToolCall) - - if !ok { - return errors.New("Latest message is not a tool_call") - } - - for _, call := range toolCall.ToolCalls { - if call.Function.Name == "listLocations" { - log.Println("Function call: listLocations") - - locations, err := agent.locationModel.List(context.Background(), userId) - if err != nil { - return err - } - - jsonLocations, err := json.Marshal(locations) - if err != nil { - return err - } - - request.AddText(AgentTextMessage{ - Role: "tool", - Name: "listLocations", - Content: string(jsonLocations), - ToolCallId: call.Id, - }) - } else if call.Function.Name == "createLocation" { - log.Println("Function call: createLocation") - - locationArguments := CreateLocationArguments{} - err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments) - if err != nil { - return err - } - - locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{ - Name: locationArguments.Name, - Address: locationArguments.Address, - Coordinates: locationArguments.Coordinates, - }}) - - if err != nil { - return err - } - - createdLocation := locations[0] - jsonCreatedLocation, err := json.Marshal(createdLocation) - if err != nil { - return err - } - - request.AddText(AgentTextMessage{ - Role: "tool", - Name: "createLocation", - Content: string(jsonCreatedLocation), - ToolCallId: call.Id, - }) - } else if call.Function.Name == "attachImageLocation" { - log.Println("Function call: attachImageLocation") - - attachLocationArguments := AttachImageLocationArguments{} - err := json.Unmarshal([]byte(call.Function.Arguments), &attachLocationArguments) - if err != nil { - return err - } - - _, err = agent.locationModel.SaveToImage(context.Background(), imageId, uuid.MustParse(attachLocationArguments.LocationId)) - if err != nil { - return err - } - - request.AddText(AgentTextMessage{ - Role: "tool", - Name: "attachImageLocation", - Content: "OK", - ToolCallId: call.Id, - }) - } else if call.Function.Name == "finish" { - log.Println("Finished!") - return errors.New("hmmm, this isnt actually an error but hey") - } else { - return errors.New("Unknown tool_call") - } - } - - return nil -} - func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { var tools any err := json.Unmarshal([]byte(TOOLS), &tools) @@ -241,8 +160,6 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID return err } - log.Println(request) - request.AddImage(imageName, imageData) _, err = agent.client.Request(&request) @@ -250,7 +167,13 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID return err } - err = agent.HandleToolCall(userId, imageId, &request) + toolHandlerInfo := ToolHandlerInfo{ + imageId: imageId, + userId: userId, + } + + _, err = agent.toolHandler.Handle(toolHandlerInfo, &request) + for err == nil { log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) @@ -261,21 +184,135 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID log.Println(response) - err = agent.HandleToolCall(userId, imageId, &request) + a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request) + + err = innerErr + + log.Println(a) + log.Println("--------------------------") } return err } +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) { + agentMessage := request.Messages[len(request.Messages)-1] + + toolCall, ok := agentMessage.(AgentAssistantToolCall) + if !ok { + return "", errors.New("Latest message was not a tool call.") + } + + fnName := toolCall.ToolCalls[0].Function.Name + arguments := toolCall.ToolCalls[0].Function.Arguments + + if fnName == "finish" { + return "", errors.New("This is the end! Maybe we just return a boolean.") + } + + fn, exists := handler.Handlers[fnName] + if !exists { + return "", errors.New("Could not find tool with this name.") + } + + // holy jesus what the fuck. + parseMethod := reflect.ValueOf(fn).Field(1) + if !parseMethod.IsValid() { + return "", errors.New("Parse method not found") + } + + parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)}) + if !parsedArgs[1].IsNil() { + return "", parsedArgs[1].Interface().(error) + } + + log.Printf("Calling: %s\n", fnName) + + fnMethod := reflect.ValueOf(fn).Field(2) + if !fnMethod.IsValid() { + return "", errors.New("Fn method not found") + } + + response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])}) + if !response[1].IsNil() { + return "", response[1].Interface().(error) + } + + stringResponse, err := json.Marshal(response[0].Interface()) + if err != nil { + return "", err + } + + request.AddText(AgentTextMessage{ + Role: "tool", + Name: "createLocation", + Content: string(stringResponse), + ToolCallId: toolCall.ToolCalls[0].Id, + }) + + return string(stringResponse), nil +} + func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) { client, err := CreateAgentClient(eventLocationPrompt) if err != nil { return EventLocationAgent{}, err } - return EventLocationAgent{ + agent := EventLocationAgent{ client: client, locationModel: locationModel, eventModel: eventModel, - }, nil + } + + toolHandler := ToolsHandlers{ + Handlers: make(map[string]ToolHandlerInterface), + } + + toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{ + FunctionName: "listLocations", + Parse: func(stringArgs string) (ListLocationArguments, error) { + args := ListLocationArguments{} + err := json.Unmarshal([]byte(stringArgs), &args) + + return args, err + }, + Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) { + return agent.locationModel.List(context.Background(), info.userId) + }, + } + + toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{ + FunctionName: "createLocation", + Parse: func(stringArgs string) (CreateLocationArguments, error) { + args := CreateLocationArguments{} + err := json.Unmarshal([]byte(stringArgs), &args) + + return args, err + }, + Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) { + return agent.locationModel.Save(context.Background(), info.userId, model.Locations{ + Name: args.Name, + Address: args.Address, + Coordinates: args.Coordinates, + }) + }, + } + + toolHandler.Handlers["attachImageLocation"] = ToolHandler[AttachImageLocationArguments, model.ImageLocations]{ + FunctionName: "attachImageLocation", + Parse: func(stringArgs string) (AttachImageLocationArguments, error) { + args := AttachImageLocationArguments{} + err := json.Unmarshal([]byte(stringArgs), &args) + + return args, err + }, + Fn: func(info ToolHandlerInfo, args AttachImageLocationArguments, call ToolCall) (model.ImageLocations, error) { + return agent.locationModel.SaveToImage(context.Background(), info.imageId, uuid.MustParse(args.LocationId)) + }, + } + + agent.toolHandler = toolHandler + + return agent, nil } diff --git a/backend/agents/tools_handler.go b/backend/agents/tools_handler.go new file mode 100644 index 0000000..5981d50 --- /dev/null +++ b/backend/agents/tools_handler.go @@ -0,0 +1,26 @@ +package agents + +import "github.com/google/uuid" + +type ToolHandlerInfo struct { + userId uuid.UUID + imageId uuid.UUID +} + +type ToolHandler[TArgs any, TResp any] struct { + FunctionName string + Parse func(args string) (TArgs, error) + Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error) +} + +type ToolHandlerInterface interface { + GetFunctionName() string +} + +func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string { + return handler.FunctionName +} + +type ToolsHandlers struct { + Handlers map[string]ToolHandlerInterface +} diff --git a/backend/main.go b/backend/main.go index edc572e..9b803d8 100644 --- a/backend/main.go +++ b/backend/main.go @@ -120,6 +120,13 @@ func main() { return } + userImage, err := imageModel.FinishProcessing(ctx, image.ID) + if err != nil { + log.Println("Failed to FinishProcessing") + log.Println(err) + return + } + log.Println("Calling locationAgent!") err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) @@ -137,13 +144,6 @@ func main() { log.Println("-----") log.Println(imageInfo) - userImage, err := imageModel.FinishProcessing(ctx, image.ID) - if err != nil { - log.Println("Failed to FinishProcessing") - log.Println(err) - return - } - log.Println(userImage) err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags) diff --git a/backend/models/locations.go b/backend/models/locations.go index 7e05d54..49509bc 100644 --- a/backend/models/locations.go +++ b/backend/models/locations.go @@ -3,11 +3,12 @@ package models import ( "context" "database/sql" - . "github.com/go-jet/jet/v2/postgres" "log" "screenmark/screenmark/.gen/haystack/haystack/model" . "screenmark/screenmark/.gen/haystack/haystack/table" + . "github.com/go-jet/jet/v2/postgres" + "github.com/google/uuid" ) @@ -15,60 +16,37 @@ type LocationModel struct { dbPool *sql.DB } -// This looks stupid -func getValues(location model.Locations) []any { - arr := make([]any, 0) - - if location.Address != nil { - arr = append(arr, *location.Address) - } else { - arr = append(arr, nil) - } - - if location.Coordinates != nil { - arr = append(arr, *location.Coordinates) - } else { - arr = append(arr, nil) - } - - if location.Description != nil { - arr = append(arr, *location.Description) - } else { - arr = append(arr, nil) - } - - return arr -} - func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) { - listLocationsStmt := SELECT(Locations.AllColumns, ImageLocations.AllColumns, UserImages.AllColumns). + listLocationsStmt := SELECT(Locations.AllColumns). FROM( Locations. - INNER_JOIN(ImageLocations, ImageLocations.LocationID.EQ(Locations.ID)). - INNER_JOIN(UserImages, UserImages.ImageID.EQ(ImageLocations.ImageID)), - ).WHERE(UserImages.UserID.EQ(UUID(userId))) + INNER_JOIN(UserLocations, UserLocations.LocationID.EQ(Locations.ID)), + ). + WHERE(UserLocations.UserID.EQ(UUID(userId))) locations := []model.Locations{} err := listLocationsStmt.QueryContext(ctx, m.dbPool, &locations) - return locations, err } -func (m LocationModel) Save(ctx context.Context, locations []model.Locations) ([]model.Locations, error) { +func (m LocationModel) Save(ctx context.Context, userId uuid.UUID, location model.Locations) (model.Locations, error) { insertLocationStmt := Locations. - INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description) + INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description). + VALUES(location.Name, location.Address, location.Coordinates, location.Description). + RETURNING(Locations.AllColumns) - for _, location := range locations { - insertLocationStmt = insertLocationStmt.VALUES(location.Name, getValues(location)...) + insertedLocation := model.Locations{} + err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation) + if err != nil { + return model.Locations{}, err } - insertLocationStmt = insertLocationStmt.RETURNING(Locations.AllColumns) + insertUserLocationStmt := UserLocations. + INSERT(UserLocations.UserID, UserLocations.LocationID). + VALUES(userId, insertedLocation.ID) - log.Println(insertLocationStmt.DebugSql()) - - insertedLocation := []model.Locations{} - err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation) + _, err = insertUserLocationStmt.ExecContext(ctx, m.dbPool) return insertedLocation, err } diff --git a/backend/schema.sql b/backend/schema.sql index a2d254e..e2e9ee1 100644 --- a/backend/schema.sql +++ b/backend/schema.sql @@ -64,6 +64,12 @@ CREATE TABLE haystack.image_locations ( image_id UUID NOT NULL REFERENCES haystack.image (id) ); +CREATE TABLE haystack.user_locations ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + location_id UUID NOT NULL REFERENCES haystack.locations (id), + user_id UUID NOT NULL REFERENCES haystack.users (id) +); + CREATE TABLE haystack.events ( id uuid PRIMARY KEY DEFAULT gen_random_uuid(),