package models import ( "context" "database/sql" "errors" "fmt" "screenmark/screenmark/.gen/haystack/haystack/model" . "screenmark/screenmark/.gen/haystack/haystack/table" "strconv" "strings" . "github.com/go-jet/jet/v2/postgres" "github.com/google/uuid" ) type ImageModel struct { dbPool *sql.DB } type ImageData struct { model.UserImages Image model.Image } type ProcessingImageData struct { model.UserImagesToProcess Image model.Image } func (m ImageModel) Process(ctx context.Context, userId uuid.UUID, image model.Image) (model.UserImagesToProcess, error) { tx, err := m.dbPool.BeginTx(ctx, nil) if err != nil { return model.UserImagesToProcess{}, err } insertImageStmt := Image. INSERT(Image.ImageName, Image.Image). VALUES(image.ImageName, image.Image). RETURNING(Image.ID) insertedImage := model.Image{} err = insertImageStmt.QueryContext(ctx, tx, &insertedImage) if err != nil { return model.UserImagesToProcess{}, err } stmt := UserImagesToProcess. INSERT(UserImagesToProcess.UserID, UserImagesToProcess.ImageID). VALUES(userId, insertedImage.ID). RETURNING(UserImagesToProcess.AllColumns) userImage := model.UserImagesToProcess{} err = stmt.QueryContext(ctx, tx, &userImage) if err != nil { return model.UserImagesToProcess{}, err } err = tx.Commit() return userImage, err } func (m ImageModel) GetToProcess(ctx context.Context, imageId uuid.UUID) (model.UserImagesToProcess, error) { getToProcessStmt := UserImagesToProcess. SELECT(UserImagesToProcess.AllColumns). WHERE(UserImagesToProcess.ID.EQ(UUID(imageId))) images := []model.UserImagesToProcess{} err := getToProcessStmt.QueryContext(ctx, m.dbPool, &images) if len(images) != 1 { return model.UserImagesToProcess{}, errors.New(fmt.Sprintf("Expected 1, got %d\n", len(images))) } return images[0], err } func (m ImageModel) GetToProcessWithData(ctx context.Context, imageId uuid.UUID) (ProcessingImageData, error) { stmt := SELECT(UserImagesToProcess.AllColumns, Image.AllColumns). FROM( UserImagesToProcess.INNER_JOIN( Image, Image.ID.EQ(UserImagesToProcess.ImageID), ), ).WHERE(UserImagesToProcess.ID.EQ(UUID(imageId))) images := []ProcessingImageData{} err := stmt.QueryContext(ctx, m.dbPool, &images) if len(images) != 1 { return ProcessingImageData{}, errors.New(fmt.Sprintf("Expected 1, got %d\n", len(images))) } return images[0], err } func (m ImageModel) FinishProcessing(ctx context.Context, imageId uuid.UUID) (model.UserImages, error) { imageToProcess, err := m.GetToProcess(ctx, imageId) if err != nil { return model.UserImages{}, err } tx, err := m.dbPool.Begin() if err != nil { return model.UserImages{}, err } insertImageStmt := UserImages. INSERT(UserImages.UserID, UserImages.ImageID). VALUES(imageToProcess.UserID, imageToProcess.ImageID). RETURNING(UserImages.ID, UserImages.UserID, UserImages.ImageID) userImage := model.UserImages{} err = insertImageStmt.QueryContext(ctx, tx, &userImage) if err != nil { return model.UserImages{}, err } // Hacky. Update the status before removing so we can get our regular triggers // to work. updateStatusStmt := UserImagesToProcess. UPDATE(UserImagesToProcess.Status). SET(model.Progress_Complete). WHERE(UserImagesToProcess.ID.EQ(UUID(imageToProcess.ID))) _, err = updateStatusStmt.ExecContext(ctx, tx) if err != nil { return model.UserImages{}, err } // TODO: // We cannot delete the image to process because our events rely on it. // This indicates our DB structure with the two tables might need some adjusting. // Or re-doing all together perhaps. // (switching to a one table (user_images) could work) // But for now, we can just not delete the images to process and set them to complete // removeProcessingStmt := UserImagesToProcess. // DELETE(). // WHERE(UserImagesToProcess.ID.EQ(UUID(imageToProcess.ID))) // // _, err = removeProcessingStmt.ExecContext(ctx, tx) // if err != nil { // return model.UserImages{}, err // } err = tx.Commit() return userImage, err } func (m ImageModel) StartProcessing(ctx context.Context, processingImageId uuid.UUID) error { startProcessingStmt := UserImagesToProcess. UPDATE(UserImagesToProcess.Status). SET(model.Progress_InProgress). WHERE(UserImagesToProcess.ID.EQ(UUID(processingImageId))) _, err := startProcessingStmt.ExecContext(ctx, m.dbPool) return err } func (m ImageModel) Get(ctx context.Context, imageId uuid.UUID) (model.Image, error) { getImageStmt := Image.SELECT(Image.AllColumns). WHERE(Image.ID.EQ(UUID(imageId))) image := model.Image{} err := getImageStmt.QueryContext(ctx, m.dbPool, &image) return image, err } func (m ImageModel) IsUserAuthorized(ctx context.Context, imageId uuid.UUID, userId uuid.UUID) bool { getImageUserId := UserImages.SELECT(UserImages.UserID).WHERE(UserImages.ImageID.EQ(UUID(imageId))) userImage := model.UserImages{} err := getImageUserId.QueryContext(ctx, m.dbPool, &userImage) return err != nil && userImage.UserID.String() == userId.String() } func (m ImageModel) UpdateEmbedding(ctx context.Context, imageId uuid.UUID, embedding []float64) error { stringSlice := make([]string, len(embedding)) for i, f := range embedding { stringSlice[i] = strconv.FormatFloat(f, 'f', -1, 64) } embeddingString := "[" + strings.Join(stringSlice, ",") + "]" _, err := m.dbPool.ExecContext(ctx, "UPDATE haystack.image SET embedding = $1::vector WHERE id = $2", embeddingString, imageId) return err } const SIMILARITY_THRESHOLD = 0.5 func (m ImageModel) GetSimilar(ctx context.Context, imageId uuid.UUID) ([]model.Image, error) { stmt, err := m.dbPool.PrepareContext(ctx, ` WITH similarities AS ( SELECT 1 - (embedding <=> ( SELECT embedding FROM haystack.image WHERE ID = $1 )) AS cosine_similarity, id, image_name FROM haystack.image WHERE embedding IS NOT NULL ) SELECT id, image_name FROM similarities WHERE cosine_similarity > $2 `) if err != nil { return []model.Image{}, err } images := []model.Image{} rows, err := stmt.QueryContext(ctx, imageId, SIMILARITY_THRESHOLD) for rows.Next() { image := model.Image{} err := rows.Scan(&image.ID, &image.ImageName) if err != nil { return []model.Image{}, err } images = append(images, image) } return images, nil } func NewImageModel(db *sql.DB) ImageModel { return ImageModel{dbPool: db} }