Haystack/backend/main.go

276 lines
6.3 KiB
Go

package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"screenmark/screenmark/models"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/joho/godotenv"
"github.com/lib/pq"
)
// TestAiClient is a canned AI client: its GetImageInfo method hands back the
// ImageInfo stored on the value instead of inspecting the image it is given.
// GetAiClient returns one when the MODE environment variable is "TESTING".
type TestAiClient struct {
// ImageInfo is the fixed result GetImageInfo returns for every call.
ImageInfo ImageInfo
}
// GetImageInfo returns the ImageInfo preloaded on the client. Both imageName
// and imageData are ignored, and the returned error is always nil.
func (c TestAiClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) {
	canned := c.ImageInfo
	return canned, nil
}
// GetAiClient picks the AiClient implementation for this process.
//
// If the MODE environment variable is exactly "TESTING", it returns a
// TestAiClient carrying fixed placeholder tags, links and text, with a nil
// error. For any other value it returns the result of CreateOpenAiClient
// unchanged, including its error.
func GetAiClient() (AiClient, error) {
	if os.Getenv("MODE") != "TESTING" {
		return CreateOpenAiClient()
	}

	placeholder := ImageInfo{
		Tags:  []string{"tag"},
		Links: []string{"links"},
		Text:  []string{"text"},
	}
	return TestAiClient{ImageInfo: placeholder}, nil
}
// main wires the service together:
//
//  1. loads environment variables via godotenv and initialises the database,
//  2. starts a background goroutine that LISTENs on the Postgres "new_image"
//     channel and processes every announced image with the configured AI
//     client,
//  3. serves the HTTP API (GET /image, GET /image/{id}, POST /image/{name})
//     on port 3040.
//
// Startup failures (env file, database initialisation, LISTEN, binding the
// port) panic. Failures while handling a single notification or request are
// logged and leave the process running.
func main() {
	if err := godotenv.Load(); err != nil {
		panic(err)
	}

	mode := os.Getenv("MODE")
	log.Printf("Mode: %s\n", mode)

	if err := models.InitDatabase(); err != nil {
		panic(err)
	}

	// The callback fires for connection state changes. A lost connection is
	// reported here with a non-nil error, and pq then reconnects on its own, so
	// the error is logged rather than used to bring the process down.
	listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
		if err != nil {
			log.Printf("database listener event %d: %v", event, err)
		}
	})
	defer listener.Close()

	// processImage runs the AI client over one pending image and persists the
	// result. Every failure is logged and abandons only this image.
	processImage := func(imageID string) {
		aiClient, err := GetAiClient()
		if err != nil {
			// A bad client configuration must not crash the HTTP server from
			// inside a worker goroutine.
			log.Printf("creating AI client for image %s: %v", imageID, err)
			return
		}
		image, err := models.GetImageToProcessWithData(imageID)
		if err != nil {
			log.Println("1")
			log.Println(err)
			return
		}
		imageInfo, err := aiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
		if err != nil {
			log.Println("2")
			log.Println(err)
			return
		}
		savedImage, err := models.SaveImage(image.ID)
		if err != nil {
			log.Println("3")
			log.Println(err)
			return
		}
		log.Println("Finished processing image " + imageID)
		log.Printf("Image attributes: %+v\n", imageInfo)
		// NOTE(review): the results of these three calls are not inspected; if
		// they return errors, those should be checked — confirm against models.
		savedID := savedImage.ID.String()
		models.SaveImageTags(savedID, imageInfo.Tags)
		models.SaveImageLinks(savedID, imageInfo.Links)
		models.SaveImageTexts(savedID, imageInfo.Text)
	}

	go func() {
		if err := listener.Listen("new_image"); err != nil {
			panic(err)
		}
		// Ranging ends cleanly if the channel is ever closed.
		for notification := range listener.Notify {
			// pq delivers a nil notification after re-establishing a lost
			// connection; there is no payload to read in that case.
			if notification == nil {
				log.Println("database listener reconnected; notifications sent while disconnected may have been missed")
				continue
			}
			imageID := notification.Extra
			log.Println("received notification, new image available: " + imageID)
			go processImage(imageID)
		}
	}()

	// setCORSHeaders adds the permissive CORS headers every route sends.
	// NOTE(review): "*" is not a valid Access-Control-Allow-Credentials value
	// (only "true" is), and credentials cannot be combined with a wildcard
	// origin. Left unchanged pending a decision on the CORS policy.
	setCORSHeaders := func(w http.ResponseWriter) {
		w.Header().Add("Access-Control-Allow-Origin", "*")
		w.Header().Add("Access-Control-Allow-Credentials", "*")
		w.Header().Add("Access-Control-Allow-Headers", "*")
	}

	router := chi.NewRouter()
	router.Use(middleware.Logger)
	// Default every response to JSON; handlers that send something else
	// replace the header with Set.
	router.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			next.ServeHTTP(w, r)
		})
	})

	// CORS preflight for any path.
	router.Options("/*", func(w http.ResponseWriter, r *http.Request) {
		setCORSHeaders(w)
	})

	// GET /image lists the images of the user named in the "userId" header.
	router.Get("/image", func(w http.ResponseWriter, r *http.Request) {
		userID := r.Header.Get("userId")
		setCORSHeaders(w)
		images, err := models.GetUserImages(userID)
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprintf(w, "Something went wrong")
			return
		}
		jsonImages, err := json.Marshal(images)
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Could not create JSON response for this image")
			return
		}
		w.Write(jsonImages)
	})

	// GET /image/{id} streams the raw bytes of one stored image.
	router.Get("/image/{id}", func(w http.ResponseWriter, r *http.Request) {
		imageID := r.PathValue("id")
		setCORSHeaders(w)
		// TODO: really need authorization here!
		image, err := models.GetImage(imageID)
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprintf(w, "Could not get image")
			return
		}
		// TODO: this could be part of the db table
		// Derive the type from the file extension when there is one; a name
		// without an extension falls back to sniffing the bytes instead of
		// slicing an empty string.
		var contentType string
		if extension := filepath.Ext(image.Image.ImageName); len(extension) > 1 {
			contentType = "image/" + extension[1:]
		} else {
			contentType = http.DetectContentType(image.Image.Image)
		}
		// Set, not Add: the middleware already put application/json here and
		// the response must carry exactly one Content-Type.
		w.Header().Set("Content-Type", contentType)
		w.Write(image.Image.Image)
	})

	// POST /image/{name} stores an uploaded image for later processing. The
	// body is either base64 text or raw bytes, selected by Content-Type.
	router.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
		imageName := r.PathValue("name")
		userID := r.Header.Get("userId")
		setCORSHeaders(w)
		if len(imageName) == 0 {
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "You need to provide a name in the path")
			return
		}
		contentType := r.Header.Get("Content-Type")
		log.Println(contentType)
		// TODO: length checks on body
		// TODO: extract the body decoding into a helper
		var image []byte
		switch contentType {
		case "application/base64":
			decoder := base64.NewDecoder(base64.StdEncoding, r.Body)
			var buf bytes.Buffer
			decodedBytes, err := io.Copy(&buf, decoder)
			if err != nil {
				log.Println(err)
				w.WriteHeader(http.StatusBadRequest)
				fmt.Fprintf(w, "bruh, base64 aint decoding")
				return
			}
			fmt.Println(decodedBytes)
			image = buf.Bytes()
		// "application/octet-stream" is the registered type; the misspelled
		// "oclet" form is still accepted so existing clients keep working.
		case "application/octet-stream", "application/oclet-stream":
			bodyData, err := io.ReadAll(r.Body)
			if err != nil {
				log.Println(err)
				w.WriteHeader(http.StatusBadRequest)
				fmt.Fprintf(w, "bruh, binary aint binaring")
				return
			}
			// TODO: check headers
			image = bodyData
		default:
			log.Println("bad stuff?")
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Bruh, you need octet stream or base64")
			return
		}
		userImage, err := models.SaveImageToProcess(userID, imageName, image)
		if err != nil {
			log.Println("Second case")
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Could not save image to DB")
			return
		}
		jsonUserImage, err := json.Marshal(userImage)
		if err != nil {
			log.Println("Third case")
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Could not create JSON response for this image")
			return
		}
		// Content-Type is already application/json from the middleware.
		w.WriteHeader(http.StatusCreated)
		w.Write(jsonUserImage)
	})

	log.Println("Listening and serving on port 3040.")
	// ListenAndServe only returns on failure. Panicking matches the other
	// startup errors and, unlike log.Fatal, still runs the deferred Close.
	if err := http.ListenAndServe(":3040", router); err != nil {
		panic(err)
	}
}