// Haystack/backend/main.go
//
// 382 lines
// 9.0 KiB
// Go

package main
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents"
"screenmark/screenmark/models"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
"github.com/joho/godotenv"
"github.com/lib/pq"
)
// TestAiClient is a canned AI client: rather than analysing an image it hands
// back the ImageInfo it was constructed with. GetAiClient returns it when the
// MODE environment variable is "TESTING".
type TestAiClient struct {
	// ImageInfo is the fixed result returned by every GetImageInfo call.
	ImageInfo agents.ImageInfo
}

// GetImageInfo ignores imageName and imageData and returns the preconfigured
// ImageInfo together with a nil error.
func (c TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) {
	canned := c.ImageInfo
	return canned, nil
}
// GetAiClient picks the AI client based on the MODE environment variable.
//
// When MODE is "TESTING" it returns a TestAiClient loaded with fixed sample
// data (one tag, link, text, location and event) and a nil error. For any
// other value it returns whatever agents.CreateAgentClient(agents.PROMPT)
// returns.
func GetAiClient() (agents.AiClient, error) {
	// Anything other than the testing mode goes straight to the real client.
	if os.Getenv("MODE") != "TESTING" {
		return agents.CreateAgentClient(agents.PROMPT)
	}

	// The model structs take pointers for these fields, so the sample strings
	// need addressable homes.
	var (
		sampleAddress     = "10 Downing Street"
		sampleDescription = "Cheese and Crackers"
	)

	sampleLocation := model.Locations{
		ID:      uuid.Nil,
		Name:    "London",
		Address: &sampleAddress,
	}
	sampleEvent := model.Events{
		ID:          uuid.Nil,
		Name:        "Party",
		Description: &sampleDescription,
	}

	stub := TestAiClient{
		ImageInfo: agents.ImageInfo{
			Tags:      []string{"tag"},
			Links:     []string{"links"},
			Text:      []string{"text"},
			Locations: []model.Locations{sampleLocation},
			Events:    []model.Events{sampleEvent},
		},
	}

	return stub, nil
}
// main wires the service together: it loads .env, opens the database, starts a
// background worker that reacts to Postgres "new_image" notifications, and
// serves the HTTP API on port 3040.
//
// Routes:
//   - OPTIONS /*          CORS preflight
//   - GET  /image         locations, events and notes of the images listed for
//     the user named in the "userId" header
//   - GET  /image/{id}    raw bytes of one stored image
//   - POST /image/{name}  upload an image; the body is either base64
//     ("application/base64") or raw bytes ("application/octet-stream")
//
// Failures during start-up (missing .env, database, LISTEN) panic. Failures
// caused by a single notification or a single request are logged / answered
// with an error status and never take the process down.
func main() {
	if err := godotenv.Load(); err != nil {
		panic(err)
	}

	mode := os.Getenv("MODE")
	log.Printf("Mode: %s\n", mode)

	db, err := models.InitDatabase()
	if err != nil {
		panic(err)
	}

	imageModel := models.NewImageModel(db)
	linkModel := models.NewLinkModel(db)
	tagModel := models.NewTagModel(db)
	textModel := models.NewTextModel(db)
	locationModel := models.NewLocationModel(db)
	eventModel := models.NewEventModel(db)
	userModel := models.NewUserModel(db)
	contactModel := models.NewContactModel(db)
	noteModel := models.NewNoteModel(db)

	listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
		// The listener re-establishes a lost connection on its own, so a
		// connection error is logged instead of crashing the whole server.
		if err != nil {
			log.Printf("Listener event %v: %v\n", event, err)
		}
	})
	defer listener.Close()

	// Background worker: one goroutine per "new_image" notification.
	go func() {
		if err := listener.Listen("new_image"); err != nil {
			panic(err)
		}

		// Ranging (rather than an endless select) lets the worker stop once
		// the Notify channel is closed instead of spinning on nil values.
		for parameters := range listener.Notify {
			// lib/pq delivers a nil notification after it has reconnected.
			if parameters == nil {
				log.Println("Listener reconnected, notifications may have been missed")
				continue
			}

			// The payload comes from the database; a malformed id must not
			// bring the process down.
			imageId, err := uuid.Parse(parameters.Extra)
			if err != nil {
				log.Println("Failed to parse image id from notification")
				log.Println(err)
				continue
			}

			ctx := context.Background()

			go func() {
				openAiClient, err := GetAiClient()
				if err != nil {
					log.Println("Failed to create AI client")
					log.Println(err)
					return
				}

				locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel)
				if err != nil {
					log.Println("Failed to create location agent")
					log.Println(err)
					return
				}

				noteAgent, err := agents.NewNoteAgent(noteModel)
				if err != nil {
					log.Println("Failed to create note agent")
					log.Println(err)
					return
				}

				orchestrator, err := agents.NewOrchestratorAgent()
				if err != nil {
					log.Println("Failed to create orchestrator agent")
					log.Println(err)
					return
				}

				image, err := imageModel.GetToProcessWithData(ctx, imageId)
				if err != nil {
					log.Println("Failed to GetToProcessWithData")
					log.Println(err)
					return
				}

				userImage, err := imageModel.FinishProcessing(ctx, image.ID)
				if err != nil {
					log.Println("Failed to FinishProcessing")
					log.Println(err)
					return
				}

				orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
				return

				// NOTE(review): everything below this point is unreachable
				// because of the return above (and the second return further
				// down). It is kept unchanged as the previously used
				// pipelines; either re-enable it deliberately or delete it
				// together with the clients/models that only it uses.

				// TODO: this can very much be parallel
				log.Println("Calling locationAgent!")
				err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
				log.Println(err)

				log.Println("Calling noteAgent!")
				err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
				log.Println(err)
				return

				imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image)
				if err != nil {
					log.Println("Failed to GetImageInfo")
					log.Println(err)
					return
				}

				err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags)
				if err != nil {
					log.Println("Failed to save tags")
					log.Println(err)
					return
				}

				err = linkModel.Save(ctx, userImage.ImageID, imageInfo.Links)
				if err != nil {
					log.Println("Failed to save links")
					log.Println(err)
					return
				}

				err = textModel.Save(ctx, userImage.ImageID, imageInfo.Text)
				if err != nil {
					log.Println("Failed to save text")
					log.Println(err)
					return
				}
			}()
		}
	}()

	// allowCORS adds the permissive CORS headers that every route returns.
	allowCORS := func(w http.ResponseWriter) {
		w.Header().Add("Access-Control-Allow-Origin", "*")
		w.Header().Add("Access-Control-Allow-Credentials", "*")
		w.Header().Add("Access-Control-Allow-Headers", "*")
	}

	r := chi.NewRouter()
	r.Use(middleware.Logger)

	// Responses are JSON unless a handler replaces the Content-Type.
	r.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Add("Content-Type", "application/json")
			next.ServeHTTP(w, r)
		})
	})

	r.Options("/*", func(w http.ResponseWriter, r *http.Request) {
		allowCORS(w)
	})

	r.Get("/image", func(w http.ResponseWriter, r *http.Request) {
		allowCORS(w)

		// The header is client-controlled: answer 400 instead of panicking.
		userID, err := uuid.Parse(r.Header.Get("userId"))
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "You need to provide a valid userId header")
			return
		}

		images, err := userModel.ListWithProperties(r.Context(), userID)
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprintf(w, "Something went wrong")
			return
		}

		log.Println(images)

		// DataType tags each entry so the client can tell the kinds apart.
		type DataType struct {
			Type string `json:"type"`
			Data any    `json:"data"`
		}

		// Non-nil so that "no data" is encoded as [] rather than null.
		dataTypes := make([]DataType, 0)

		for _, image := range images {
			for _, location := range image.Locations {
				dataTypes = append(dataTypes, DataType{
					Type: "location",
					Data: location,
				})
			}

			for _, event := range image.Events {
				dataTypes = append(dataTypes, DataType{
					Type: "event",
					Data: event,
				})
			}

			for _, note := range image.Notes {
				dataTypes = append(dataTypes, DataType{
					Type: "note",
					Data: note,
				})
			}
		}

		jsonImages, err := json.Marshal(dataTypes)
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Could not create JSON response for this image")
			return
		}

		w.Write(jsonImages)
	})

	r.Get("/image/{id}", func(w http.ResponseWriter, r *http.Request) {
		allowCORS(w)

		// The path segment is client-controlled: answer 400 instead of panicking.
		imageID, err := uuid.Parse(r.PathValue("id"))
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Invalid image id")
			return
		}

		// TODO: really need authorization here!
		image, err := imageModel.Get(r.Context(), imageID)
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprintf(w, "Could not get image")
			return
		}

		// TODO: this could be part of the db table
		// Derive the media type from the file extension; names without a
		// usable extension fall back to a generic binary type.
		contentType := "application/octet-stream"
		if extension := filepath.Ext(image.Image.ImageName); len(extension) > 1 {
			contentType = "image/" + extension[1:]
		}

		// Set, not Add: the JSON Content-Type from the middleware must be
		// replaced, otherwise the response carries two Content-Type headers.
		w.Header().Set("Content-Type", contentType)
		w.Write(image.Image.Image)
	})

	r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
		imageName := r.PathValue("name")

		allowCORS(w)

		if len(imageName) == 0 {
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "You need to provide a name in the path")
			return
		}

		// Validate the caller before reading the (possibly large) body.
		userID, err := uuid.Parse(r.Header.Get("userId"))
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "You need to provide a valid userId header")
			return
		}

		contentType := r.Header.Get("Content-Type")
		log.Println(contentType)

		// TODO: length checks on body
		// TODO: extract this shit out
		var image []byte

		switch contentType {
		case "application/base64":
			buf := &bytes.Buffer{}
			if _, err := io.Copy(buf, base64.NewDecoder(base64.StdEncoding, r.Body)); err != nil {
				log.Println(err)
				w.WriteHeader(http.StatusBadRequest)
				fmt.Fprintf(w, "bruh, base64 aint decoding")
				return
			}

			image = buf.Bytes()
		// "application/oclet-stream" is the misspelling this endpoint used to
		// require; it stays accepted so existing clients keep working.
		case "application/octet-stream", "application/oclet-stream":
			bodyData, err := io.ReadAll(r.Body)
			if err != nil {
				log.Println(err)
				w.WriteHeader(http.StatusBadRequest)
				fmt.Fprintf(w, "bruh, binary aint binaring")
				return
			}

			// TODO: check headers
			image = bodyData
		default:
			log.Println("bad stuff?")
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Bruh, you need octet stream or base64")
			return
		}

		userImage, err := imageModel.Process(r.Context(), userID, model.Image{
			Image:     image,
			ImageName: imageName,
		})
		if err != nil {
			log.Println("Second case")
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Could not save image to DB")
			return
		}

		jsonUserImage, err := json.Marshal(userImage)
		if err != nil {
			log.Println("Third case")
			log.Println(err)
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprintf(w, "Could not create JSON response for this image")
			return
		}

		// Headers only take effect when set before WriteHeader.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		w.Write(jsonUserImage)
	})

	log.Println("Listening and serving on port 3040.")

	// ListenAndServe only returns on failure (for example when the port is
	// already taken). Log the reason and return normally so the deferred
	// listener.Close still runs.
	if err := http.ListenAndServe(":3040", r); err != nil {
		log.Println(err)
	}
}