feat(tool-calling) Big refactor on how tool calling is handled
these commits are too big
This commit is contained in:
18
backend/.gen/haystack/haystack/model/user_locations.go
Normal file
18
backend/.gen/haystack/haystack/model/user_locations.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
//
|
||||||
|
// Code generated by go-jet DO NOT EDIT.
|
||||||
|
//
|
||||||
|
// WARNING: Changes to this file may cause incorrect behavior
|
||||||
|
// and will be lost if the code is regenerated
|
||||||
|
//
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserLocations struct {
|
||||||
|
ID uuid.UUID `sql:"primary_key"`
|
||||||
|
LocationID uuid.UUID
|
||||||
|
UserID uuid.UUID
|
||||||
|
}
|
@ -20,6 +20,7 @@ func UseSchema(schema string) {
|
|||||||
Locations = Locations.FromSchema(schema)
|
Locations = Locations.FromSchema(schema)
|
||||||
UserImages = UserImages.FromSchema(schema)
|
UserImages = UserImages.FromSchema(schema)
|
||||||
UserImagesToProcess = UserImagesToProcess.FromSchema(schema)
|
UserImagesToProcess = UserImagesToProcess.FromSchema(schema)
|
||||||
|
UserLocations = UserLocations.FromSchema(schema)
|
||||||
UserTags = UserTags.FromSchema(schema)
|
UserTags = UserTags.FromSchema(schema)
|
||||||
Users = Users.FromSchema(schema)
|
Users = Users.FromSchema(schema)
|
||||||
}
|
}
|
||||||
|
81
backend/.gen/haystack/haystack/table/user_locations.go
Normal file
81
backend/.gen/haystack/haystack/table/user_locations.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
//
|
||||||
|
// Code generated by go-jet DO NOT EDIT.
|
||||||
|
//
|
||||||
|
// WARNING: Changes to this file may cause incorrect behavior
|
||||||
|
// and will be lost if the code is regenerated
|
||||||
|
//
|
||||||
|
|
||||||
|
package table
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-jet/jet/v2/postgres"
|
||||||
|
)
|
||||||
|
|
||||||
|
var UserLocations = newUserLocationsTable("haystack", "user_locations", "")
|
||||||
|
|
||||||
|
type userLocationsTable struct {
|
||||||
|
postgres.Table
|
||||||
|
|
||||||
|
// Columns
|
||||||
|
ID postgres.ColumnString
|
||||||
|
LocationID postgres.ColumnString
|
||||||
|
UserID postgres.ColumnString
|
||||||
|
|
||||||
|
AllColumns postgres.ColumnList
|
||||||
|
MutableColumns postgres.ColumnList
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserLocationsTable struct {
|
||||||
|
userLocationsTable
|
||||||
|
|
||||||
|
EXCLUDED userLocationsTable
|
||||||
|
}
|
||||||
|
|
||||||
|
// AS creates new UserLocationsTable with assigned alias
|
||||||
|
func (a UserLocationsTable) AS(alias string) *UserLocationsTable {
|
||||||
|
return newUserLocationsTable(a.SchemaName(), a.TableName(), alias)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema creates new UserLocationsTable with assigned schema name
|
||||||
|
func (a UserLocationsTable) FromSchema(schemaName string) *UserLocationsTable {
|
||||||
|
return newUserLocationsTable(schemaName, a.TableName(), a.Alias())
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPrefix creates new UserLocationsTable with assigned table prefix
|
||||||
|
func (a UserLocationsTable) WithPrefix(prefix string) *UserLocationsTable {
|
||||||
|
return newUserLocationsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSuffix creates new UserLocationsTable with assigned table suffix
|
||||||
|
func (a UserLocationsTable) WithSuffix(suffix string) *UserLocationsTable {
|
||||||
|
return newUserLocationsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserLocationsTable(schemaName, tableName, alias string) *UserLocationsTable {
|
||||||
|
return &UserLocationsTable{
|
||||||
|
userLocationsTable: newUserLocationsTableImpl(schemaName, tableName, alias),
|
||||||
|
EXCLUDED: newUserLocationsTableImpl("", "excluded", ""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUserLocationsTableImpl(schemaName, tableName, alias string) userLocationsTable {
|
||||||
|
var (
|
||||||
|
IDColumn = postgres.StringColumn("id")
|
||||||
|
LocationIDColumn = postgres.StringColumn("location_id")
|
||||||
|
UserIDColumn = postgres.StringColumn("user_id")
|
||||||
|
allColumns = postgres.ColumnList{IDColumn, LocationIDColumn, UserIDColumn}
|
||||||
|
mutableColumns = postgres.ColumnList{LocationIDColumn, UserIDColumn}
|
||||||
|
)
|
||||||
|
|
||||||
|
return userLocationsTable{
|
||||||
|
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
|
||||||
|
|
||||||
|
//Columns
|
||||||
|
ID: IDColumn,
|
||||||
|
LocationID: LocationIDColumn,
|
||||||
|
UserID: UserIDColumn,
|
||||||
|
|
||||||
|
AllColumns: allColumns,
|
||||||
|
MutableColumns: mutableColumns,
|
||||||
|
}
|
||||||
|
}
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
|
"reflect"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
|
|
||||||
@ -21,6 +22,10 @@ Only create an event if you see an event on the image, not all locations have an
|
|||||||
It is possible that there is no location or event on an image.
|
It is possible that there is no location or event on an image.
|
||||||
|
|
||||||
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
|
You should ask for a list of locations, as the user is likely to have this location saved. Reuse existing locations where possible.
|
||||||
|
|
||||||
|
Do not create an event if you don't see any dates, or a name indicating an event.
|
||||||
|
|
||||||
|
Always reuse existing locations from listLocations . Do not create duplicates.
|
||||||
`
|
`
|
||||||
|
|
||||||
// TODO: this should be read directly from a file on load.
|
// TODO: this should be read directly from a file on load.
|
||||||
@ -30,7 +35,7 @@ const TOOLS = `
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "createLocation",
|
"name": "createLocation",
|
||||||
"description": "Creates a location",
|
"description": "Creates a location. No not use if you think an existing location is suitable!",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@ -63,12 +68,13 @@ const TOOLS = `
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "attachImageLocation",
|
"name": "attachImageLocation",
|
||||||
"description": "Add a location to an image",
|
"description": "Add a location to an image. You must use UUID.",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"locationId": {
|
"locationId": {
|
||||||
"type": "string"
|
"type": "string",
|
||||||
|
"description": "UUID of an existing location, you can use listLocations to get values, or use the return value of createLocation"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["locationId"]
|
"required": ["locationId"]
|
||||||
@ -117,8 +123,12 @@ type EventLocationAgent struct {
|
|||||||
|
|
||||||
eventModel models.EventModel
|
eventModel models.EventModel
|
||||||
locationModel models.LocationModel
|
locationModel models.LocationModel
|
||||||
|
|
||||||
|
toolHandler ToolsHandlers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ListLocationArguments struct{}
|
||||||
|
|
||||||
type CreateLocationArguments struct {
|
type CreateLocationArguments struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Address *string `json:"address,omitempty"`
|
Address *string `json:"address,omitempty"`
|
||||||
@ -129,97 +139,6 @@ type AttachImageLocationArguments struct {
|
|||||||
LocationId string `json:"locationId"`
|
LocationId string `json:"locationId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (agent EventLocationAgent) HandleToolCall(userId uuid.UUID, imageId uuid.UUID, request *AgentRequestBody) error {
|
|
||||||
latestMessage := request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]
|
|
||||||
|
|
||||||
toolCall, ok := latestMessage.(AgentAssistantToolCall)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return errors.New("Latest message is not a tool_call")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, call := range toolCall.ToolCalls {
|
|
||||||
if call.Function.Name == "listLocations" {
|
|
||||||
log.Println("Function call: listLocations")
|
|
||||||
|
|
||||||
locations, err := agent.locationModel.List(context.Background(), userId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonLocations, err := json.Marshal(locations)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.AddText(AgentTextMessage{
|
|
||||||
Role: "tool",
|
|
||||||
Name: "listLocations",
|
|
||||||
Content: string(jsonLocations),
|
|
||||||
ToolCallId: call.Id,
|
|
||||||
})
|
|
||||||
} else if call.Function.Name == "createLocation" {
|
|
||||||
log.Println("Function call: createLocation")
|
|
||||||
|
|
||||||
locationArguments := CreateLocationArguments{}
|
|
||||||
err := json.Unmarshal([]byte(call.Function.Arguments), &locationArguments)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
locations, err := agent.locationModel.Save(context.Background(), []model.Locations{{
|
|
||||||
Name: locationArguments.Name,
|
|
||||||
Address: locationArguments.Address,
|
|
||||||
Coordinates: locationArguments.Coordinates,
|
|
||||||
}})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
createdLocation := locations[0]
|
|
||||||
jsonCreatedLocation, err := json.Marshal(createdLocation)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.AddText(AgentTextMessage{
|
|
||||||
Role: "tool",
|
|
||||||
Name: "createLocation",
|
|
||||||
Content: string(jsonCreatedLocation),
|
|
||||||
ToolCallId: call.Id,
|
|
||||||
})
|
|
||||||
} else if call.Function.Name == "attachImageLocation" {
|
|
||||||
log.Println("Function call: attachImageLocation")
|
|
||||||
|
|
||||||
attachLocationArguments := AttachImageLocationArguments{}
|
|
||||||
err := json.Unmarshal([]byte(call.Function.Arguments), &attachLocationArguments)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = agent.locationModel.SaveToImage(context.Background(), imageId, uuid.MustParse(attachLocationArguments.LocationId))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
request.AddText(AgentTextMessage{
|
|
||||||
Role: "tool",
|
|
||||||
Name: "attachImageLocation",
|
|
||||||
Content: "OK",
|
|
||||||
ToolCallId: call.Id,
|
|
||||||
})
|
|
||||||
} else if call.Function.Name == "finish" {
|
|
||||||
log.Println("Finished!")
|
|
||||||
return errors.New("hmmm, this isnt actually an error but hey")
|
|
||||||
} else {
|
|
||||||
return errors.New("Unknown tool_call")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
var tools any
|
var tools any
|
||||||
err := json.Unmarshal([]byte(TOOLS), &tools)
|
err := json.Unmarshal([]byte(TOOLS), &tools)
|
||||||
@ -241,8 +160,6 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println(request)
|
|
||||||
|
|
||||||
request.AddImage(imageName, imageData)
|
request.AddImage(imageName, imageData)
|
||||||
|
|
||||||
_, err = agent.client.Request(&request)
|
_, err = agent.client.Request(&request)
|
||||||
@ -250,7 +167,13 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = agent.HandleToolCall(userId, imageId, &request)
|
toolHandlerInfo := ToolHandlerInfo{
|
||||||
|
imageId: imageId,
|
||||||
|
userId: userId,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = agent.toolHandler.Handle(toolHandlerInfo, &request)
|
||||||
|
|
||||||
for err == nil {
|
for err == nil {
|
||||||
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1])
|
||||||
|
|
||||||
@ -261,21 +184,135 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
|||||||
|
|
||||||
log.Println(response)
|
log.Println(response)
|
||||||
|
|
||||||
err = agent.HandleToolCall(userId, imageId, &request)
|
a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request)
|
||||||
|
|
||||||
|
err = innerErr
|
||||||
|
|
||||||
|
log.Println(a)
|
||||||
|
log.Println("--------------------------")
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) {
|
||||||
|
agentMessage := request.Messages[len(request.Messages)-1]
|
||||||
|
|
||||||
|
toolCall, ok := agentMessage.(AgentAssistantToolCall)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Latest message was not a tool call.")
|
||||||
|
}
|
||||||
|
|
||||||
|
fnName := toolCall.ToolCalls[0].Function.Name
|
||||||
|
arguments := toolCall.ToolCalls[0].Function.Arguments
|
||||||
|
|
||||||
|
if fnName == "finish" {
|
||||||
|
return "", errors.New("This is the end! Maybe we just return a boolean.")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn, exists := handler.Handlers[fnName]
|
||||||
|
if !exists {
|
||||||
|
return "", errors.New("Could not find tool with this name.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// holy jesus what the fuck.
|
||||||
|
parseMethod := reflect.ValueOf(fn).Field(1)
|
||||||
|
if !parseMethod.IsValid() {
|
||||||
|
return "", errors.New("Parse method not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)})
|
||||||
|
if !parsedArgs[1].IsNil() {
|
||||||
|
return "", parsedArgs[1].Interface().(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Calling: %s\n", fnName)
|
||||||
|
|
||||||
|
fnMethod := reflect.ValueOf(fn).Field(2)
|
||||||
|
if !fnMethod.IsValid() {
|
||||||
|
return "", errors.New("Fn method not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])})
|
||||||
|
if !response[1].IsNil() {
|
||||||
|
return "", response[1].Interface().(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
stringResponse, err := json.Marshal(response[0].Interface())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.AddText(AgentTextMessage{
|
||||||
|
Role: "tool",
|
||||||
|
Name: "createLocation",
|
||||||
|
Content: string(stringResponse),
|
||||||
|
ToolCallId: toolCall.ToolCalls[0].Id,
|
||||||
|
})
|
||||||
|
|
||||||
|
return string(stringResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel) (EventLocationAgent, error) {
|
||||||
client, err := CreateAgentClient(eventLocationPrompt)
|
client, err := CreateAgentClient(eventLocationPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return EventLocationAgent{}, err
|
return EventLocationAgent{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return EventLocationAgent{
|
agent := EventLocationAgent{
|
||||||
client: client,
|
client: client,
|
||||||
locationModel: locationModel,
|
locationModel: locationModel,
|
||||||
eventModel: eventModel,
|
eventModel: eventModel,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
toolHandler := ToolsHandlers{
|
||||||
|
Handlers: make(map[string]ToolHandlerInterface),
|
||||||
|
}
|
||||||
|
|
||||||
|
toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{
|
||||||
|
FunctionName: "listLocations",
|
||||||
|
Parse: func(stringArgs string) (ListLocationArguments, error) {
|
||||||
|
args := ListLocationArguments{}
|
||||||
|
err := json.Unmarshal([]byte(stringArgs), &args)
|
||||||
|
|
||||||
|
return args, err
|
||||||
|
},
|
||||||
|
Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) {
|
||||||
|
return agent.locationModel.List(context.Background(), info.userId)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{
|
||||||
|
FunctionName: "createLocation",
|
||||||
|
Parse: func(stringArgs string) (CreateLocationArguments, error) {
|
||||||
|
args := CreateLocationArguments{}
|
||||||
|
err := json.Unmarshal([]byte(stringArgs), &args)
|
||||||
|
|
||||||
|
return args, err
|
||||||
|
},
|
||||||
|
Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) {
|
||||||
|
return agent.locationModel.Save(context.Background(), info.userId, model.Locations{
|
||||||
|
Name: args.Name,
|
||||||
|
Address: args.Address,
|
||||||
|
Coordinates: args.Coordinates,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
toolHandler.Handlers["attachImageLocation"] = ToolHandler[AttachImageLocationArguments, model.ImageLocations]{
|
||||||
|
FunctionName: "attachImageLocation",
|
||||||
|
Parse: func(stringArgs string) (AttachImageLocationArguments, error) {
|
||||||
|
args := AttachImageLocationArguments{}
|
||||||
|
err := json.Unmarshal([]byte(stringArgs), &args)
|
||||||
|
|
||||||
|
return args, err
|
||||||
|
},
|
||||||
|
Fn: func(info ToolHandlerInfo, args AttachImageLocationArguments, call ToolCall) (model.ImageLocations, error) {
|
||||||
|
return agent.locationModel.SaveToImage(context.Background(), info.imageId, uuid.MustParse(args.LocationId))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
agent.toolHandler = toolHandler
|
||||||
|
|
||||||
|
return agent, nil
|
||||||
}
|
}
|
||||||
|
26
backend/agents/tools_handler.go
Normal file
26
backend/agents/tools_handler.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package agents
|
||||||
|
|
||||||
|
import "github.com/google/uuid"
|
||||||
|
|
||||||
|
type ToolHandlerInfo struct {
|
||||||
|
userId uuid.UUID
|
||||||
|
imageId uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolHandler[TArgs any, TResp any] struct {
|
||||||
|
FunctionName string
|
||||||
|
Parse func(args string) (TArgs, error)
|
||||||
|
Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolHandlerInterface interface {
|
||||||
|
GetFunctionName() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string {
|
||||||
|
return handler.FunctionName
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolsHandlers struct {
|
||||||
|
Handlers map[string]ToolHandlerInterface
|
||||||
|
}
|
@ -120,6 +120,13 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userImage, err := imageModel.FinishProcessing(ctx, image.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to FinishProcessing")
|
||||||
|
log.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
log.Println("Calling locationAgent!")
|
log.Println("Calling locationAgent!")
|
||||||
err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
|
|
||||||
@ -137,13 +144,6 @@ func main() {
|
|||||||
log.Println("-----")
|
log.Println("-----")
|
||||||
log.Println(imageInfo)
|
log.Println(imageInfo)
|
||||||
|
|
||||||
userImage, err := imageModel.FinishProcessing(ctx, image.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Failed to FinishProcessing")
|
|
||||||
log.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(userImage)
|
log.Println(userImage)
|
||||||
|
|
||||||
err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags)
|
err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags)
|
||||||
|
@ -3,11 +3,12 @@ package models
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
. "github.com/go-jet/jet/v2/postgres"
|
|
||||||
"log"
|
"log"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
||||||
|
|
||||||
|
. "github.com/go-jet/jet/v2/postgres"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,60 +16,37 @@ type LocationModel struct {
|
|||||||
dbPool *sql.DB
|
dbPool *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// This looks stupid
|
|
||||||
func getValues(location model.Locations) []any {
|
|
||||||
arr := make([]any, 0)
|
|
||||||
|
|
||||||
if location.Address != nil {
|
|
||||||
arr = append(arr, *location.Address)
|
|
||||||
} else {
|
|
||||||
arr = append(arr, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
if location.Coordinates != nil {
|
|
||||||
arr = append(arr, *location.Coordinates)
|
|
||||||
} else {
|
|
||||||
arr = append(arr, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
if location.Description != nil {
|
|
||||||
arr = append(arr, *location.Description)
|
|
||||||
} else {
|
|
||||||
arr = append(arr, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
return arr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) {
|
func (m LocationModel) List(ctx context.Context, userId uuid.UUID) ([]model.Locations, error) {
|
||||||
listLocationsStmt := SELECT(Locations.AllColumns, ImageLocations.AllColumns, UserImages.AllColumns).
|
listLocationsStmt := SELECT(Locations.AllColumns).
|
||||||
FROM(
|
FROM(
|
||||||
Locations.
|
Locations.
|
||||||
INNER_JOIN(ImageLocations, ImageLocations.LocationID.EQ(Locations.ID)).
|
INNER_JOIN(UserLocations, UserLocations.LocationID.EQ(Locations.ID)),
|
||||||
INNER_JOIN(UserImages, UserImages.ImageID.EQ(ImageLocations.ImageID)),
|
).
|
||||||
).WHERE(UserImages.UserID.EQ(UUID(userId)))
|
WHERE(UserLocations.UserID.EQ(UUID(userId)))
|
||||||
|
|
||||||
locations := []model.Locations{}
|
locations := []model.Locations{}
|
||||||
|
|
||||||
err := listLocationsStmt.QueryContext(ctx, m.dbPool, &locations)
|
err := listLocationsStmt.QueryContext(ctx, m.dbPool, &locations)
|
||||||
|
|
||||||
return locations, err
|
return locations, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m LocationModel) Save(ctx context.Context, locations []model.Locations) ([]model.Locations, error) {
|
func (m LocationModel) Save(ctx context.Context, userId uuid.UUID, location model.Locations) (model.Locations, error) {
|
||||||
insertLocationStmt := Locations.
|
insertLocationStmt := Locations.
|
||||||
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description)
|
INSERT(Locations.Name, Locations.Address, Locations.Coordinates, Locations.Description).
|
||||||
|
VALUES(location.Name, location.Address, location.Coordinates, location.Description).
|
||||||
|
RETURNING(Locations.AllColumns)
|
||||||
|
|
||||||
for _, location := range locations {
|
insertedLocation := model.Locations{}
|
||||||
insertLocationStmt = insertLocationStmt.VALUES(location.Name, getValues(location)...)
|
err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation)
|
||||||
|
if err != nil {
|
||||||
|
return model.Locations{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
insertLocationStmt = insertLocationStmt.RETURNING(Locations.AllColumns)
|
insertUserLocationStmt := UserLocations.
|
||||||
|
INSERT(UserLocations.UserID, UserLocations.LocationID).
|
||||||
|
VALUES(userId, insertedLocation.ID)
|
||||||
|
|
||||||
log.Println(insertLocationStmt.DebugSql())
|
_, err = insertUserLocationStmt.ExecContext(ctx, m.dbPool)
|
||||||
|
|
||||||
insertedLocation := []model.Locations{}
|
|
||||||
err := insertLocationStmt.QueryContext(ctx, m.dbPool, &insertedLocation)
|
|
||||||
|
|
||||||
return insertedLocation, err
|
return insertedLocation, err
|
||||||
}
|
}
|
||||||
|
@ -64,6 +64,12 @@ CREATE TABLE haystack.image_locations (
|
|||||||
image_id UUID NOT NULL REFERENCES haystack.image (id)
|
image_id UUID NOT NULL REFERENCES haystack.image (id)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE haystack.user_locations (
|
||||||
|
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
location_id UUID NOT NULL REFERENCES haystack.locations (id),
|
||||||
|
user_id UUID NOT NULL REFERENCES haystack.users (id)
|
||||||
|
);
|
||||||
|
|
||||||
CREATE TABLE haystack.events (
|
CREATE TABLE haystack.events (
|
||||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user