364 lines
8.6 KiB
Go
364 lines
8.6 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
|
"screenmark/screenmark/agents"
|
|
"screenmark/screenmark/models"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/google/uuid"
|
|
"github.com/joho/godotenv"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
type TestAiClient struct {
|
|
ImageInfo agents.ImageInfo
|
|
}
|
|
|
|
func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) {
|
|
return client.ImageInfo, nil
|
|
}
|
|
|
|
func GetAiClient() (agents.AiClient, error) {
|
|
mode := os.Getenv("MODE")
|
|
if mode == "TESTING" {
|
|
address := "10 Downing Street"
|
|
description := "Cheese and Crackers"
|
|
|
|
return TestAiClient{
|
|
ImageInfo: agents.ImageInfo{
|
|
Tags: []string{"tag"},
|
|
Links: []string{"links"},
|
|
Text: []string{"text"},
|
|
Locations: []model.Locations{{
|
|
ID: uuid.Nil,
|
|
Name: "London",
|
|
Address: &address,
|
|
}},
|
|
Events: []model.Events{{
|
|
ID: uuid.Nil,
|
|
Name: "Party",
|
|
Description: &description,
|
|
}},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
return agents.CreateAgentClient(agents.PROMPT)
|
|
}
|
|
|
|
func main() {
|
|
err := godotenv.Load()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
mode := os.Getenv("MODE")
|
|
log.Printf("Mode: %s\n", mode)
|
|
|
|
db, err := models.InitDatabase()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
imageModel := models.NewImageModel(db)
|
|
linkModel := models.NewLinkModel(db)
|
|
tagModel := models.NewTagModel(db)
|
|
textModel := models.NewTextModel(db)
|
|
locationModel := models.NewLocationModel(db)
|
|
eventModel := models.NewEventModel(db)
|
|
userModel := models.NewUserModel(db)
|
|
contactModel := models.NewContactModel(db)
|
|
noteModel := models.NewNoteModel(db)
|
|
|
|
listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
})
|
|
defer listener.Close()
|
|
|
|
go func() {
|
|
err := listener.Listen("new_image")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
for {
|
|
|
|
select {
|
|
case parameters := <-listener.Notify:
|
|
imageId := uuid.MustParse(parameters.Extra)
|
|
|
|
ctx := context.Background()
|
|
|
|
go func() {
|
|
openAiClient, err := GetAiClient()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
_, err = agents.NewLocationEventAgent(locationModel, eventModel, contactModel)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
noteAgent, err := agents.NewNoteAgent(noteModel)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
image, err := imageModel.GetToProcessWithData(ctx, imageId)
|
|
if err != nil {
|
|
log.Println("Failed to GetToProcessWithData")
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
userImage, err := imageModel.FinishProcessing(ctx, image.ID)
|
|
if err != nil {
|
|
log.Println("Failed to FinishProcessing")
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// log.Println("Calling locationAgent!")
|
|
// err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
|
// log.Println(err)
|
|
|
|
log.Println("Calling noteAgent!")
|
|
err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
|
log.Println(err)
|
|
|
|
return
|
|
|
|
imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
|
|
if err != nil {
|
|
log.Println("Failed to GetImageInfo")
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags)
|
|
if err != nil {
|
|
log.Println("Failed to save tags")
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
err = linkModel.Save(ctx, userImage.ImageID, imageInfo.Links)
|
|
if err != nil {
|
|
log.Println("Failed to save links")
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
err = textModel.Save(ctx, userImage.ImageID, imageInfo.Text)
|
|
if err != nil {
|
|
log.Println("Failed to save text")
|
|
log.Println(err)
|
|
return
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
}()
|
|
|
|
r := chi.NewRouter()
|
|
|
|
r.Use(middleware.Logger)
|
|
r.Use(func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Add("Content-Type", "application/json")
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
})
|
|
|
|
r.Options("/*", func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Add("Access-Control-Allow-Origin", "*")
|
|
w.Header().Add("Access-Control-Allow-Credentials", "*")
|
|
w.Header().Add("Access-Control-Allow-Headers", "*")
|
|
})
|
|
|
|
r.Get("/image", func(w http.ResponseWriter, r *http.Request) {
|
|
userId := r.Header.Get("userId")
|
|
|
|
w.Header().Add("Access-Control-Allow-Origin", "*")
|
|
w.Header().Add("Access-Control-Allow-Credentials", "*")
|
|
w.Header().Add("Access-Control-Allow-Headers", "*")
|
|
|
|
images, err := userModel.ListWithProperties(r.Context(), uuid.MustParse(userId))
|
|
if err != nil {
|
|
log.Println(err)
|
|
w.WriteHeader(http.StatusNotFound)
|
|
fmt.Fprintf(w, "Something went wrong")
|
|
return
|
|
}
|
|
|
|
log.Println(images)
|
|
|
|
type DataType struct {
|
|
Type string `json:"type"`
|
|
Data any `json:"data"`
|
|
}
|
|
|
|
dataTypes := make([]DataType, 0)
|
|
for _, image := range images {
|
|
for _, location := range image.Locations {
|
|
dataTypes = append(dataTypes, DataType{
|
|
Type: "location",
|
|
Data: location,
|
|
})
|
|
}
|
|
|
|
for _, event := range image.Events {
|
|
dataTypes = append(dataTypes, DataType{
|
|
Type: "event",
|
|
Data: event,
|
|
})
|
|
}
|
|
}
|
|
|
|
jsonImages, err := json.Marshal(dataTypes)
|
|
if err != nil {
|
|
log.Println(err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "Could not create JSON response for this image")
|
|
return
|
|
}
|
|
|
|
w.Write(jsonImages)
|
|
})
|
|
|
|
r.Get("/image/{id}", func(w http.ResponseWriter, r *http.Request) {
|
|
imageId := r.PathValue("id")
|
|
|
|
w.Header().Add("Access-Control-Allow-Origin", "*")
|
|
w.Header().Add("Access-Control-Allow-Credentials", "*")
|
|
w.Header().Add("Access-Control-Allow-Headers", "*")
|
|
|
|
// TODO: really need authorization here!
|
|
image, err := imageModel.Get(r.Context(), uuid.MustParse(imageId))
|
|
if err != nil {
|
|
log.Println(err)
|
|
w.WriteHeader(http.StatusNotFound)
|
|
fmt.Fprintf(w, "Could not get image")
|
|
return
|
|
}
|
|
|
|
// TODO: this could be part of the db table
|
|
extension := filepath.Ext(image.Image.ImageName)
|
|
extension = extension[1:]
|
|
|
|
w.Header().Add("Content-Type", "image/"+extension)
|
|
w.Write(image.Image.Image)
|
|
})
|
|
|
|
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
|
|
imageName := r.PathValue("name")
|
|
|
|
userId := r.Header.Get("userId")
|
|
|
|
w.Header().Add("Access-Control-Allow-Origin", "*")
|
|
w.Header().Add("Access-Control-Allow-Credentials", "*")
|
|
w.Header().Add("Access-Control-Allow-Headers", "*")
|
|
|
|
if len(imageName) == 0 {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "You need to provide a name in the path")
|
|
return
|
|
}
|
|
|
|
contentType := r.Header.Get("Content-Type")
|
|
|
|
log.Println(contentType)
|
|
|
|
// TODO: length checks on body
|
|
// TODO: extract this shit out
|
|
image := make([]byte, 0)
|
|
if contentType == "application/base64" {
|
|
decoder := base64.NewDecoder(base64.StdEncoding, r.Body)
|
|
buf := &bytes.Buffer{}
|
|
|
|
decodedIamge, err := io.Copy(buf, decoder)
|
|
if err != nil {
|
|
log.Println(err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "bruh, base64 aint decoding")
|
|
return
|
|
}
|
|
|
|
fmt.Println(string(image))
|
|
fmt.Println(decodedIamge)
|
|
|
|
image = buf.Bytes()
|
|
} else if contentType == "application/oclet-stream" {
|
|
bodyData, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
log.Println(err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "bruh, binary aint binaring")
|
|
return
|
|
}
|
|
// TODO: check headers
|
|
|
|
image = bodyData
|
|
} else {
|
|
log.Println("bad stuff?")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "Bruh, you need oclet stream or base64")
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
log.Println("First case")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "Couldnt read the image from the request body")
|
|
return
|
|
}
|
|
|
|
userImage, err := imageModel.Process(r.Context(), uuid.MustParse(userId), model.Image{
|
|
Image: image,
|
|
ImageName: imageName,
|
|
})
|
|
if err != nil {
|
|
log.Println("Second case")
|
|
log.Println(err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "Could not save image to DB")
|
|
return
|
|
}
|
|
|
|
jsonUserImage, err := json.Marshal(userImage)
|
|
if err != nil {
|
|
log.Println("Third case")
|
|
log.Println(err)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
fmt.Fprintf(w, "Could not create JSON response for this image")
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusCreated)
|
|
|
|
fmt.Fprint(w, string(jsonUserImage))
|
|
w.Header().Add("Content-Type", "application/json")
|
|
})
|
|
|
|
log.Println("Listening and serving on port 3040.")
|
|
|
|
http.ListenAndServe(":3040", r)
|
|
}
|