diff --git a/backend/go.mod b/backend/go.mod index c5dca0b..fcdc38a 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -4,6 +4,7 @@ go 1.24.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-chi/chi/v5 v5.2.1 // indirect github.com/go-jet/jet/v2 v2.12.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/joho/godotenv v1.5.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index 7ccdb24..ecd9db0 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= +github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-jet/jet/v2 v2.12.0 h1:z2JfvBAZgsfxlQz6NXBYdZTXc7ep3jhbszTLtETv1JE= github.com/go-jet/jet/v2 v2.12.0/go.mod h1:ufQVRQeI1mbcO5R8uCEVcVf3Foej9kReBdwDx7YMWUM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/backend/main.go b/backend/main.go index a6cd088..7674819 100644 --- a/backend/main.go +++ b/backend/main.go @@ -11,16 +11,38 @@ import ( "screenmark/screenmark/models" "time" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" "github.com/joho/godotenv" "github.com/lib/pq" ) +type TestAiClient struct { + ImageInfo ImageInfo +} + +func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { + return client.ImageInfo, nil +} + +func GetAiClient() (AiClient, error) { + mode := os.Getenv("MODE") + if mode == "TESTING" { + return TestAiClient{}, nil + } + + return CreateOpenAiClient() +} + func main() { err := godotenv.Load() if err != nil { panic(err) } + mode := os.Getenv("MODE") + log.Printf("Mode: %s\n", mode) + err = models.InitDatabase() if err != nil { panic(err) @@ -39,46 +61,51 @@ func main() { panic(err) } - select { - case parameters := <-listener.Notify: - log.Println("received notification, new image available: " + parameters.Extra) + for { - go func() { - openAiClient, err := CreateOpenAiClient() - if err != nil { - panic(err) - } + select { + case parameters := <-listener.Notify: + log.Println("received notification, new image available: " + parameters.Extra) - image, err := models.GetImage(parameters.Extra) - if err != nil { - log.Println(err) - return - } + go func() { + openAiClient, err := GetAiClient() + if err != nil { + panic(err) + } - imageInfo, err := openAiClient.GetImageInfo(image.ImageName, image.Image) - if err != nil { - log.Println(err) - return - } + image, err := models.GetImage(parameters.Extra) + if err != nil { + log.Println(err) + return + } - log.Println("Finished processing image " + parameters.Extra) + imageInfo, err := openAiClient.GetImageInfo(image.ImageName, image.Image) + if err != nil { + log.Println(err) + return + } - models.SaveImageTags(parameters.Extra, imageInfo.Tags) - models.SaveImageLinks(parameters.Extra, imageInfo.Links) - models.SaveImageTexts(parameters.Extra, imageInfo.Tags) - }() + log.Println("Finished processing image " + parameters.Extra) + log.Printf("Image attributes: %+v\n", imageInfo) + + models.SaveImageTags(parameters.Extra, imageInfo.Tags) + models.SaveImageLinks(parameters.Extra, imageInfo.Links) + models.SaveImageTexts(parameters.Extra, imageInfo.Text) + }() + } } }() - mux := http.NewServeMux() + r := chi.NewRouter() + r.Use(middleware.Logger) - mux.HandleFunc("OPTIONS /image/{name}", func(w http.ResponseWriter, r *http.Request) { + r.Options("/*", func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Access-Control-Allow-Origin", "*") w.Header().Add("Access-Control-Allow-Credentials", "*") w.Header().Add("Access-Control-Allow-Headers", "*") }) - mux.HandleFunc("GET /image", func(w http.ResponseWriter, r *http.Request) { + r.Get("/image", func(w http.ResponseWriter, r *http.Request) { userId := r.Header.Get("userId") images, err := models.GetUserImages(userId) @@ -100,7 +127,7 @@ func main() { w.Write(jsonImages) }) - mux.HandleFunc("GET /image/{id}", func(w http.ResponseWriter, r *http.Request) { + r.Get("/image/{id}", func(w http.ResponseWriter, r *http.Request) { imageId := r.PathValue("id") // TODO: really need authorization here! @@ -120,7 +147,7 @@ func main() { w.Write(image.Image) }) - mux.HandleFunc("POST /image/{name}", func(w http.ResponseWriter, r *http.Request) { + r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) { imageName := r.PathValue("name") userId := r.Header.Get("userId") @@ -167,5 +194,5 @@ func main() { log.Println("Listening and serving on port 3040.") - http.ListenAndServe(":3040", mux) + http.ListenAndServe(":3040", r) } diff --git a/backend/models/image.go b/backend/models/image.go index 733a607..9fdca5a 100644 --- a/backend/models/image.go +++ b/backend/models/image.go @@ -46,8 +46,6 @@ func GetUserImages(userId string) ([]UserImagesWithInfo, error) { id := uuid.MustParse(userId) stmt := SELECT(UserImages.ID, UserImages.ImageName, ImageTags.AllColumns, ImageText.AllColumns, ImageLinks.AllColumns).FROM(UserImages.LEFT_JOIN(ImageTags, ImageTags.ImageID.EQ(UserImages.ID)).LEFT_JOIN(ImageText, ImageText.ImageID.EQ(UserImages.ID)).LEFT_JOIN(ImageLinks, ImageLinks.ImageID.EQ(UserImages.ID))).WHERE(UserImages.UserID.EQ(UUID(id))) - fmt.Println(stmt.DebugSql()) - images := []UserImagesWithInfo{} err := stmt.Query(db, &images) diff --git a/backend/openai.go b/backend/openai.go index 01ce639..f7a114e 100644 --- a/backend/openai.go +++ b/backend/openai.go @@ -113,6 +113,10 @@ func (imageMessage OpenAiImage) ToJson() ([]byte, error) { return json.Marshal(imageMessage) } +type AiClient interface { + GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) +} + type OpenAiClient struct { url string apiKey string