From 8acf25a2a7b67ba9ecd888b57378c3c1cfd6f879 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sun, 16 Mar 2025 18:13:30 +0000 Subject: [PATCH] refactor(models): using more organised structure --- backend/main.go | 51 +++--- backend/models/database.go | 12 +- backend/models/image.go | 316 +++++++++++++++---------------------- backend/models/links.go | 29 ++++ backend/models/tags.go | 85 ++++++---- backend/models/text.go | 31 ++++ backend/models/user.go | 32 ++++ backend/openai_test.go | 4 +- 8 files changed, 310 insertions(+), 250 deletions(-) create mode 100644 backend/models/links.go create mode 100644 backend/models/text.go create mode 100644 backend/models/user.go diff --git a/backend/main.go b/backend/main.go index 5f28569..108b29a 100644 --- a/backend/main.go +++ b/backend/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "encoding/base64" "encoding/json" "fmt" @@ -10,11 +11,13 @@ import ( "net/http" "os" "path/filepath" + "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/models" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/google/uuid" "github.com/joho/godotenv" "github.com/lib/pq" ) @@ -51,11 +54,16 @@ func main() { mode := os.Getenv("MODE") log.Printf("Mode: %s\n", mode) - err = models.InitDatabase() + db, err := models.InitDatabase() if err != nil { panic(err) } + imageModel := models.NewImageModel(db) + linkModel := models.NewLinkModel(db) + tagModel := models.NewTagModel(db) + textModel := models.NewTextModel(db) + listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) { if err != nil { panic(err) @@ -73,9 +81,9 @@ func main() { select { case parameters := <-listener.Notify: - imageId := parameters.Extra + imageId := uuid.MustParse(parameters.Extra) - log.Println("received notification, new image available: " + imageId) + ctx := context.Background() go func() { openAiClient, err := GetAiClient() @@ -83,46 +91,34 @@ func main() { panic(err) } - image, err := models.GetImageToProcessWithData(imageId) + image, err := imageModel.GetToProcessWithData(ctx, imageId) if err != nil { - log.Println("1") - log.Println(err) return } imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image) if err != nil { - log.Println("2") - log.Println(err) return } - savedImage, err := models.SaveImage(image.ID) + savedImage, err := imageModel.FinishProcessing(ctx, image.ID) if err != nil { - log.Println("3") - log.Println(err) return } - log.Println("Finished processing image " + imageId) - log.Printf("Image attributes: %+v\n", imageInfo) - - _, err = models.SaveImageTags(savedImage.ID.String(), imageInfo.Tags) + err = tagModel.SaveToImage(ctx, savedImage.ID, imageInfo.Tags) if err != nil { - log.Println("1") - log.Println(err) + return } - _, err = models.SaveImageLinks(savedImage.ID.String(), imageInfo.Links) + err = linkModel.Save(ctx, savedImage.ID, imageInfo.Links) if err != nil { - log.Println("2") - log.Println(err) + return } - _, err = models.SaveImageTexts(savedImage.ID.String(), imageInfo.Text) + err = textModel.Save(ctx, savedImage.ID, imageInfo.Text) if err != nil { - log.Println("3") - log.Println(err) + return } }() } @@ -153,7 +149,7 @@ func main() { w.Header().Add("Access-Control-Allow-Credentials", "*") w.Header().Add("Access-Control-Allow-Headers", "*") - images, err := models.GetUserImages(userId) + images, err := imageModel.ListWithProperties(r.Context(), uuid.MustParse(userId)) if err != nil { log.Println(err) w.WriteHeader(http.StatusNotFound) @@ -180,7 +176,7 @@ func main() { w.Header().Add("Access-Control-Allow-Headers", "*") // TODO: really need authorization here! - image, err := models.GetImage(imageId) + image, err := imageModel.Get(r.Context(), uuid.MustParse(imageId)) if err != nil { log.Println(err) w.WriteHeader(http.StatusNotFound) @@ -259,7 +255,10 @@ func main() { return } - userImage, err := models.SaveImageToProcess(userId, imageName, image) + userImage, err := imageModel.Process(r.Context(), uuid.MustParse(userId), model.Image{ + Image: image, + ImageName: imageName, + }) if err != nil { log.Println("Second case") log.Println(err) diff --git a/backend/models/database.go b/backend/models/database.go index fb6f08b..2f5a0ca 100644 --- a/backend/models/database.go +++ b/backend/models/database.go @@ -8,18 +8,12 @@ import ( _ "github.com/lib/pq" ) -var db *sql.DB - -func InitDatabase() error { +func InitDatabase() (*sql.DB, error) { connection := os.Getenv("DB_CONNECTION") if len(connection) == 0 { - return errors.New("DB_CONNECTION env was not found.") + return nil, errors.New("DB_CONNECTION env was not found.") } - database, err := sql.Open("postgres", connection) - - db = database - - return err + return sql.Open("postgres", connection) } diff --git a/backend/models/image.go b/backend/models/image.go index d0d5f61..5fc098a 100644 --- a/backend/models/image.go +++ b/backend/models/image.go @@ -1,6 +1,8 @@ package models import ( + "context" + "database/sql" "errors" "fmt" "screenmark/screenmark/.gen/haystack/haystack/model" @@ -11,76 +13,8 @@ import ( "github.com/google/uuid" ) -func SaveImageToProcess(userId string, imageName string, imageData []byte) (model.UserImagesToProcess, error) { - insertImageStmt := Image.INSERT(Image.ImageName, Image.Image).VALUES(imageName, imageData).RETURNING(Image.ID) - - // TODO: should be a transaction - - image := model.Image{} - err := insertImageStmt.Query(db, &image) - if err != nil { - return model.UserImagesToProcess{}, err - } - - stmt := UserImagesToProcess.INSERT(UserImagesToProcess.UserID, UserImagesToProcess.ImageID).VALUES(userId, image.ID).RETURNING(UserImagesToProcess.AllColumns) - - fmt.Println(stmt.DebugSql()) - - userImage := model.UserImagesToProcess{} - err = stmt.Query(db, &userImage) - - return userImage, err -} - -func removeImageToProcess(imageId string) error { - id := uuid.MustParse(imageId) - - stmt := UserImagesToProcess.DELETE().WHERE(UserImagesToProcess.ID.EQ(UUID(id))) - - _, err := stmt.Exec(db) - - return err -} - -func getUserId(imageId uuid.UUID) (uuid.UUID, error) { - stmt := UserImages.SELECT(UserImages.UserID).WHERE(UserImages.ID.EQ(UUID(imageId))) - - fmt.Println(stmt.DebugSql()) - - userIds := make([]string, 0) - - err := stmt.Query(db, &userIds) - if err != nil { - return uuid.Nil, err - } - - if len(userIds) != 1 { - return uuid.Nil, errors.New("expect only one user id per image id") - } - - return uuid.Parse(userIds[0]) -} - -func SaveImage(imageId uuid.UUID) (model.UserImages, error) { - imageToProcess, err := GetImageToProcess(imageId.String()) - if err != nil { - return model.UserImages{}, err - } - - stmt := UserImages.INSERT(UserImages.UserID, UserImages.ImageID).VALUES(imageToProcess.UserID, imageToProcess.ImageID).RETURNING(UserImages.ID, UserImages.UserID, UserImages.ImageID) - - userImage := model.UserImages{} - err = stmt.Query(db, &userImage) - if err != nil { - return model.UserImages{}, err - } - - err = removeImageToProcess(imageId.String()) - if err != nil { - return model.UserImages{}, err - } - - return userImage, err +type ImageModel struct { + dbPool *sql.DB } type ImageData struct { @@ -89,49 +23,62 @@ type ImageData struct { Image model.Image } -func GetImage(imageId string) (ImageData, error) { - id := uuid.MustParse(imageId) - stmt := SELECT(UserImages.AllColumns, Image.AllColumns).FROM(UserImages.INNER_JOIN(Image, Image.ID.EQ(UserImages.ImageID))).WHERE(UserImages.ID.EQ(UUID(id))) - - images := []ImageData{} - err := stmt.Query(db, &images) - - if len(images) != 1 { - return ImageData{}, errors.New(fmt.Sprintf("Expected 1, got %d\n", len(images))) - } - - return images[0], err -} - -type ImageToProcessData struct { +type ProcessingImageData struct { model.UserImagesToProcess Image model.Image } -func GetImageToProcessWithData(imageId string) (ImageToProcessData, error) { - id := uuid.MustParse(imageId) - // stmt := UserImagesToProcess.SELECT(UserImages.AllColumns).WHERE(UserImages.ID.EQ(UUID(id))) +type ImageWithProperties struct { + ID uuid.UUID - // TODO: Image should be `Images` - stmt := SELECT(UserImagesToProcess.AllColumns, Image.AllColumns).FROM(UserImagesToProcess.INNER_JOIN(Image, Image.ID.EQ(UserImagesToProcess.ImageID))).WHERE(UserImagesToProcess.ID.EQ(UUID(id))) + Image model.UserImages - images := []ImageToProcessData{} - err := stmt.Query(db, &images) - - if len(images) != 1 { - return ImageToProcessData{}, errors.New(fmt.Sprintf("Expected 1, got %d\n", len(images))) - } - - return images[0], err + Tags []model.ImageTags + Links []model.ImageLinks + Text []model.ImageText } -func GetImageToProcess(imageId string) (model.UserImagesToProcess, error) { - id := uuid.MustParse(imageId) - stmt := UserImagesToProcess.SELECT(UserImagesToProcess.AllColumns).WHERE(UserImagesToProcess.ID.EQ(UUID(id))) +func (m ImageModel) Process(ctx context.Context, userId uuid.UUID, image model.Image) (model.UserImagesToProcess, error) { + tx, err := m.dbPool.BeginTx(ctx, nil) + if err != nil { + return model.UserImagesToProcess{}, err + } + + insertImageStmt := Image. + INSERT(Image.ImageName, Image.Image). + VALUES(image.ImageName, image.Image). + RETURNING(Image.ID) + + insertedImage := model.Image{} + err = insertImageStmt.QueryContext(ctx, tx, &insertedImage) + if err != nil { + return model.UserImagesToProcess{}, err + } + + stmt := UserImagesToProcess. + INSERT(UserImagesToProcess.UserID, UserImagesToProcess.ImageID). + VALUES(userId, insertedImage.ID). + RETURNING(UserImagesToProcess.AllColumns) + + userImage := model.UserImagesToProcess{} + err = stmt.QueryContext(ctx, tx, &userImage) + if err != nil { + return model.UserImagesToProcess{}, err + } + + err = tx.Commit() + + return userImage, err +} + +func (m ImageModel) GetToProcess(ctx context.Context, imageId uuid.UUID) (model.UserImagesToProcess, error) { + getToProcessStmt := UserImagesToProcess. + SELECT(UserImagesToProcess.AllColumns). + WHERE(UserImagesToProcess.ID.EQ(UUID(imageId))) images := []model.UserImagesToProcess{} - err := stmt.Query(db, &images) + err := getToProcessStmt.QueryContext(ctx, m.dbPool, &images) if len(images) != 1 { return model.UserImagesToProcess{}, errors.New(fmt.Sprintf("Expected 1, got %d\n", len(images))) @@ -140,97 +87,94 @@ func GetImageToProcess(imageId string) (model.UserImagesToProcess, error) { return images[0], err } -type UserImagesWithInfo struct { - ID uuid.UUID +func (m ImageModel) GetToProcessWithData(ctx context.Context, imageId uuid.UUID) (ProcessingImageData, error) { + stmt := SELECT(UserImagesToProcess.AllColumns, Image.AllColumns). + FROM( + UserImagesToProcess.INNER_JOIN( + Image, Image.ID.EQ(UserImagesToProcess.ImageID), + ), + ).WHERE(UserImagesToProcess.ID.EQ(UUID(imageId))) - // TODO: this shit - Image model.Image + images := []ProcessingImageData{} + err := stmt.QueryContext(ctx, m.dbPool, &images) - Tags []model.ImageTags - Links []model.ImageLinks - Text []model.ImageText + if len(images) != 1 { + return ProcessingImageData{}, errors.New(fmt.Sprintf("Expected 1, got %d\n", len(images))) + } + + return images[0], err } -func GetUserImages(userId string) ([]UserImagesWithInfo, error) { - id := uuid.MustParse(userId) - stmt := SELECT(UserImages.ID.AS("UserImagesWithInfo.ID"), Image.ID, Image.ImageName, ImageTags.AllColumns, ImageText.AllColumns, ImageLinks.AllColumns).FROM(UserImages.INNER_JOIN(Image, Image.ID.EQ(UserImages.ImageID)).LEFT_JOIN(ImageTags, ImageTags.ImageID.EQ(UserImages.ID)).LEFT_JOIN(ImageText, ImageText.ImageID.EQ(UserImages.ID)).LEFT_JOIN(ImageLinks, ImageLinks.ImageID.EQ(UserImages.ID))).WHERE(UserImages.UserID.EQ(UUID(id))) +func (m ImageModel) FinishProcessing(ctx context.Context, imageId uuid.UUID) (model.UserImages, error) { + imageToProcess, err := m.GetToProcess(ctx, imageId) + if err != nil { + return model.UserImages{}, err + } - images := []UserImagesWithInfo{} - err := stmt.Query(db, &images) + tx, err := m.dbPool.Begin() + if err != nil { + return model.UserImages{}, err + } + + insertImageStmt := UserImages. + INSERT(UserImages.UserID, UserImages.ImageID). + VALUES(imageToProcess.UserID, imageToProcess.ImageID). + RETURNING(UserImages.ID, UserImages.UserID, UserImages.ImageID) + + userImage := model.UserImages{} + err = insertImageStmt.QueryContext(ctx, tx, &userImage) + if err != nil { + return model.UserImages{}, err + } + + removeProcessingStmt := UserImagesToProcess. + DELETE(). + WHERE(UserImagesToProcess.ID.EQ(UUID(imageToProcess.ID))) + + _, err = removeProcessingStmt.ExecContext(ctx, tx) + + return userImage, err +} + +func (m ImageModel) Get(ctx context.Context, imageId uuid.UUID) (ImageData, error) { + getImageStmt := SELECT(UserImages.AllColumns, Image.AllColumns). + FROM( + UserImages.INNER_JOIN(Image, Image.ID.EQ(UserImages.ImageID)), + ). + WHERE(UserImages.ID.EQ(UUID(imageId))) + + images := []ImageData{} + err := getImageStmt.QueryContext(ctx, m.dbPool, &images) + + if len(images) != 1 { + return ImageData{}, errors.New(fmt.Sprintf("Expected 1, got %d\n", len(images))) + } + + return images[0], err +} + +// TODO: move this to `user.go` model file +func (m ImageModel) ListWithProperties(ctx context.Context, userId uuid.UUID) ([]ImageWithProperties, error) { + listWithPropertiesStmt := SELECT( + UserImages.ID.AS("UserImagesWithInfo.ID"), + Image.ID, + Image.ImageName, + ImageTags.AllColumns, + ImageText.AllColumns, + ImageLinks.AllColumns). + FROM( + UserImages.INNER_JOIN(Image, Image.ID.EQ(UserImages.ImageID)). + LEFT_JOIN(ImageTags, ImageTags.ImageID.EQ(UserImages.ID)). + LEFT_JOIN(ImageText, ImageText.ImageID.EQ(UserImages.ID)). + LEFT_JOIN(ImageLinks, ImageLinks.ImageID.EQ(UserImages.ID))). + WHERE(UserImages.UserID.EQ(UUID(userId))) + + images := []ImageWithProperties{} + err := listWithPropertiesStmt.QueryContext(ctx, m.dbPool, &images) return images, err } -func SaveImageTags(imageId string, tags []string) ([]model.ImageTags, error) { - id := uuid.MustParse(imageId) - - userId, err := getUserId(id) - if err != nil { - return []model.ImageTags{}, err - } - - err = CreateTags(userId, tags) - if err != nil { - return []model.ImageTags{}, err - } - - userTagsExpression := make([]Expression, 0) - for _, tag := range tags { - userTagsExpression = append(userTagsExpression, String(tag)) - } - - userTags := make([]model.UserTags, 0) - - getTagsStmt := UserTags.SELECT(UserTags.ID, UserTags.Tag).WHERE(UserTags.Tag.IN(userTagsExpression...)) - err = getTagsStmt.Query(db, &userTags) - if err != nil { - return []model.ImageTags{}, err - } - - stmt := ImageTags.INSERT(ImageTags.ImageID, ImageTags.TagID) - - for _, t := range userTags { - stmt = stmt.VALUES(id, t.ID) - } - - stmt.RETURNING(ImageTags.AllColumns) - - imageTags := make([]model.ImageTags, 0) - err = stmt.Query(db, &imageTags) - - return imageTags, err -} - -func SaveImageLinks(imageId string, links []string) ([]model.ImageLinks, error) { - id := uuid.MustParse(imageId) - - stmt := ImageLinks.INSERT(ImageLinks.ImageID, ImageLinks.Link) - - for _, t := range links { - stmt = stmt.VALUES(id, t) - } - - stmt.RETURNING(ImageLinks.AllColumns) - - imageLinks := []model.ImageLinks{} - err := stmt.Query(db, &imageLinks) - - return imageLinks, err -} - -func SaveImageTexts(imageId string, texts []string) ([]model.ImageText, error) { - id := uuid.MustParse(imageId) - - stmt := ImageText.INSERT(ImageText.ImageID, ImageText.ImageText) - - for _, t := range texts { - stmt = stmt.VALUES(id, t) - } - - stmt.RETURNING(ImageText.AllColumns) - - imageTags := []model.ImageText{} - err := stmt.Query(db, &imageTags) - - return imageTags, err +func NewImageModel(db *sql.DB) ImageModel { + return ImageModel{dbPool: db} } diff --git a/backend/models/links.go b/backend/models/links.go new file mode 100644 index 0000000..e598b10 --- /dev/null +++ b/backend/models/links.go @@ -0,0 +1,29 @@ +package models + +import ( + "context" + "database/sql" + . "screenmark/screenmark/.gen/haystack/haystack/table" + + "github.com/google/uuid" +) + +type LinkModel struct { + dbPool *sql.DB +} + +func (m LinkModel) Save(ctx context.Context, imageId uuid.UUID, links []string) error { + stmt := ImageLinks.INSERT(ImageLinks.ImageID, ImageLinks.Link) + + for _, link := range links { + stmt = stmt.VALUES(imageId, link) + } + + _, err := stmt.ExecContext(ctx, m.dbPool) + + return err +} + +func NewLinkModel(db *sql.DB) LinkModel { + return LinkModel{dbPool: db} +} diff --git a/backend/models/tags.go b/backend/models/tags.go index e61b977..a6d3a93 100644 --- a/backend/models/tags.go +++ b/backend/models/tags.go @@ -1,6 +1,8 @@ package models import ( + "context" + "database/sql" "fmt" "screenmark/screenmark/.gen/haystack/haystack/model" . "screenmark/screenmark/.gen/haystack/haystack/table" @@ -9,6 +11,10 @@ import ( "github.com/google/uuid" ) +type TagModel struct { + dbPool *sql.DB +} + // Raw dogging SQL is kinda based though? // // | nO, usE OrM!! @@ -20,7 +26,7 @@ import ( // | -- -- // | -- -- // | ---- IQ ---- -func getNonExistantTags(userId uuid.UUID, tags []string) ([]string, error) { +func (m TagModel) getNonExistantTags(ctx context.Context, userId uuid.UUID, tags []string) ([]string, error) { values := "" counter := 1 // big big SQL injection problem here? @@ -29,20 +35,7 @@ func getNonExistantTags(userId uuid.UUID, tags []string) ([]string, error) { } values = values[0 : len(values)-1] - /* - WITH given_tags - AS (SELECT given_tags.tag FROM (VALUES ('c')) AS given_tags (tag)), - this_user_tags as ( - SELECT id, tag - FROM haystack.user_tags - where user_tags.user_id = 'fcc22dbb-7792-4595-be8e-d0439e13990a' - ) - select given_tags.tag from given_tags - LEFT OUTER JOIN this_user_tags ON this_user_tags.tag = given_tags.tag - where this_user_tags.tag is null; - */ - - withStuff := fmt.Sprintf(`WITH given_tags + getNonExistingTags := fmt.Sprintf(`WITH given_tags AS (SELECT given_tags.tag FROM (VALUES `+values+`) AS given_tags (tag)), this_user_tags AS (SELECT id, tag FROM haystack.user_tags WHERE user_tags.user_id = $%d) @@ -51,23 +44,20 @@ func getNonExistantTags(userId uuid.UUID, tags []string) ([]string, error) { LEFT OUTER JOIN haystack.user_tags ON haystack.user_tags.tag = given_tags.tag where user_tags.tag is null`, counter) - stmt, err := db.Prepare(withStuff) - fmt.Println(withStuff) + getNonExistingTagsStmt, err := m.dbPool.PrepareContext(ctx, getNonExistingTags) + defer getNonExistingTagsStmt.Close() if err != nil { - fmt.Println("failing to prepare stmt") return []string{}, err } - defer stmt.Close() - args := make([]any, counter) for i, v := range tags { args[i] = v } args[counter-1] = userId.String() - rows, err := stmt.Query(args...) + rows, err := getNonExistingTagsStmt.QueryContext(ctx, args...) if err != nil { return []string{}, err } @@ -84,8 +74,8 @@ func getNonExistantTags(userId uuid.UUID, tags []string) ([]string, error) { return nonExistantTags, nil } -func CreateTags(userId uuid.UUID, tags []string) error { - tagsToInsert, err := getNonExistantTags(userId, tags) +func (m TagModel) Save(ctx context.Context, userId uuid.UUID, tags []string) error { + tagsToInsert, err := m.getNonExistantTags(ctx, userId, tags) if err != nil { return err } @@ -100,17 +90,58 @@ func CreateTags(userId uuid.UUID, tags []string) error { stmt = stmt.VALUES(UUID(userId), tag) } - _, err = stmt.Exec(db) + _, err = stmt.ExecContext(ctx, m.dbPool) return err } -func GetTags(userId uuid.UUID) ([]model.UserTags, error) { - stmt := UserTags.SELECT(UserTags.AllColumns).WHERE(UserTags.UserID.EQ(UUID(userId))) +func (m TagModel) List(ctx context.Context, userId uuid.UUID) ([]model.UserTags, error) { + listTagsStmt := UserTags.SELECT(UserTags.AllColumns).WHERE(UserTags.UserID.EQ(UUID(userId))) userTags := []model.UserTags{} - err := stmt.Query(db, &userTags) + err := listTagsStmt.QueryContext(ctx, m.dbPool, &userTags) return userTags, err } + +func (m TagModel) SaveToImage(ctx context.Context, imageId uuid.UUID, tags []string) error { + userId, err := getUserIdFromImage(ctx, m.dbPool, imageId) + if err != nil { + return err + } + + err = m.Save(ctx, userId, tags) + if err != nil { + return err + } + + userTagsExpression := make([]Expression, 0) + for _, tag := range tags { + userTagsExpression = append(userTagsExpression, String(tag)) + } + + userTags := make([]model.UserTags, 0) + + getTagsStmt := UserTags.SELECT( + UserTags.ID, UserTags.Tag, + ).WHERE(UserTags.Tag.IN(userTagsExpression...)) + err = getTagsStmt.Query(m.dbPool, &userTags) + if err != nil { + return err + } + + stmt := ImageTags.INSERT(ImageTags.ImageID, ImageTags.TagID) + + for _, t := range userTags { + stmt = stmt.VALUES(imageId, t.ID) + } + + _, err = stmt.ExecContext(ctx, m.dbPool) + + return err +} + +func NewTagModel(db *sql.DB) TagModel { + return TagModel{dbPool: db} +} diff --git a/backend/models/text.go b/backend/models/text.go new file mode 100644 index 0000000..031d01b --- /dev/null +++ b/backend/models/text.go @@ -0,0 +1,31 @@ +package models + +import ( + "context" + "database/sql" + . "screenmark/screenmark/.gen/haystack/haystack/table" + + "github.com/google/uuid" +) + +type TextModel struct { + dbPool *sql.DB +} + +func (m TextModel) Save(ctx context.Context, imageId uuid.UUID, texts []string) error { + saveImageTextStmt := ImageText.INSERT(ImageText.ImageID, ImageText.ImageText) + + for _, t := range texts { + saveImageTextStmt = saveImageTextStmt.VALUES(imageId, t) + } + + saveImageTextStmt.RETURNING(ImageText.AllColumns) + + _, err := saveImageTextStmt.ExecContext(ctx, m.dbPool) + + return err +} + +func NewTextModel(db *sql.DB) TextModel { + return TextModel{dbPool: db} +} diff --git a/backend/models/user.go b/backend/models/user.go new file mode 100644 index 0000000..2aacc9e --- /dev/null +++ b/backend/models/user.go @@ -0,0 +1,32 @@ +package models + +import ( + "context" + "database/sql" + "errors" + "screenmark/screenmark/.gen/haystack/haystack/model" + . "screenmark/screenmark/.gen/haystack/haystack/table" + + . "github.com/go-jet/jet/v2/postgres" + "github.com/google/uuid" +) + +type UserModek struct { + dbPool *sql.DB +} + +func getUserIdFromImage(ctx context.Context, dbPool *sql.DB, imageId uuid.UUID) (uuid.UUID, error) { + getUserIdStmt := UserImages.SELECT(UserImages.UserID).WHERE(UserImages.ID.EQ(UUID(imageId))) + + user := []model.Users{} + err := getUserIdStmt.QueryContext(ctx, dbPool, &user) + if err != nil { + return uuid.Nil, err + } + + if len(user) != 1 { + return uuid.Nil, errors.New("Expected exactly one choice.") + } + + return user[0].ID, nil +} diff --git a/backend/openai_test.go b/backend/openai_test.go index 945f507..387e82f 100644 --- a/backend/openai_test.go +++ b/backend/openai_test.go @@ -82,8 +82,8 @@ func TestMessageBuilderImage(t *testing.T) { base64data := base64.StdEncoding.EncodeToString(data) url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data) - if imageContent.ImageUrl.Url != url { - t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl.Url) + if imageContent.ImageUrl != url { + t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl) t.FailNow() } }