From b58efc413d8e8b405eec075fae3b572411281a11 Mon Sep 17 00:00:00 2001
From: John Costa
Date: Tue, 13 May 2025 16:34:52 +0100
Subject: [PATCH] feat: asking AI for a description and storing vectors on DB

---
Review notes: line breaks and indentation of this patch were reconstructed
(Go: tabs, SQL: four spaces), so applying it may need --ignore-whitespace.
The blob ids on the "index" lines predate the review changes made to the two
new Go files and to models/image.go.

 backend/agents/client/embeddings.go | 120 ++++++++++++++++++++++++++++
 backend/agents/embeddings_agent.go  |  88 ++++++++++++++++++++
 backend/events.go                   |   9 ++-
 backend/models/image.go             |  18 +++++
 backend/schema.sql                  |   5 +-
 5 files changed, 238 insertions(+), 2 deletions(-)
 create mode 100644 backend/agents/client/embeddings.go
 create mode 100644 backend/agents/embeddings_agent.go

diff --git a/backend/agents/client/embeddings.go b/backend/agents/client/embeddings.go
new file mode 100644
index 0000000..27a00c8
--- /dev/null
+++ b/backend/agents/client/embeddings.go
@@ -0,0 +1,120 @@
+package client
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+	"os"
+)
+
+// EmbeddingsClient sends text to the OpenAI embeddings endpoint.
+type EmbeddingsClient struct {
+	url    string
+	apiKey string
+
+	// Do performs the HTTP round trip for a prepared request.
+	Do func(req *http.Request) (*http.Response, error)
+}
+
+// EmbeddingsRequest is the JSON body posted to the embeddings endpoint.
+type EmbeddingsRequest struct {
+	Model string `json:"model"`
+	Input string `json:"input"`
+}
+
+// EmbeddingsData is one entry of the "data" array in the response.
+type EmbeddingsData struct {
+	Object     string    `json:"object"`
+	Index      int       `json:"index"`
+	Embeddings []float64 `json:"embedding"`
+}
+
+// EmbeddingsResponse is the JSON body returned by the embeddings endpoint.
+type EmbeddingsResponse struct {
+	Object string           `json:"object"`
+	Data   []EmbeddingsData `json:"data"`
+}
+
+// CreateEmbeddingsClient builds a client from the environment variable named
+// by OPENAI_API_KEY. It panics when that variable is empty.
+func CreateEmbeddingsClient() EmbeddingsClient {
+	apiKey := os.Getenv(OPENAI_API_KEY)
+
+	if len(apiKey) == 0 {
+		panic("No api key")
+	}
+
+	return EmbeddingsClient{
+		apiKey: apiKey,
+		url:    "https://api.openai.com/v1/embeddings",
+		Do: func(req *http.Request) (*http.Response, error) {
+			client := &http.Client{}
+			return client.Do(req)
+		},
+	}
+}
+
+// getRequest builds the authenticated POST request that embeds text.
+func (client EmbeddingsClient) getRequest(text string) (*http.Request, error) {
+	embeddingsReq := EmbeddingsRequest{
+		Model: "text-embedding-3-large",
+		Input: text,
+	}
+
+	jsonEmbeddingsBody, err := json.Marshal(embeddingsReq)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(jsonEmbeddingsBody))
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Add("Authorization", "Bearer "+client.apiKey)
+	req.Header.Add("Content-Type", "application/json")
+
+	return req, nil
+}
+
+// Request embeds text and returns the single embedding in the response. It
+// fails on transport errors, HTTP error statuses, undecodable bodies, or a
+// response that does not hold exactly one embedding.
+func (client EmbeddingsClient) Request(text string) (EmbeddingsData, error) {
+	httpRequest, err := client.getRequest(text)
+	if err != nil {
+		return EmbeddingsData{}, err
+	}
+
+	resp, err := client.Do(httpRequest)
+	if err != nil {
+		return EmbeddingsData{}, err
+	}
+	// Release the connection on every path below.
+	defer resp.Body.Close()
+
+	response, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return EmbeddingsData{}, err
+	}
+
+	// Surface API errors instead of decoding them into an empty response.
+	if resp.StatusCode >= http.StatusBadRequest {
+		return EmbeddingsData{}, errors.New("embeddings request failed: " + resp.Status + ": " + string(response))
+	}
+
+	embeddingsResponse := EmbeddingsResponse{}
+	err = json.Unmarshal(response, &embeddingsResponse)
+
+	if err != nil {
+		return EmbeddingsData{}, err
+	}
+
+	if len(embeddingsResponse.Data) != 1 {
+		return EmbeddingsData{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
+	}
+
+	return embeddingsResponse.Data[0], nil
+}
diff --git a/backend/agents/embeddings_agent.go b/backend/agents/embeddings_agent.go
new file mode 100644
index 0000000..446c97d
--- /dev/null
+++ b/backend/agents/embeddings_agent.go
@@ -0,0 +1,88 @@
+package agents
+
+import (
+	"context"
+	"errors"
+	"screenmark/screenmark/agents/client"
+	"screenmark/screenmark/models"
+
+	"github.com/charmbracelet/log"
+	"github.com/google/uuid"
+)
+
+// embeddingsAgentPropmt is the system prompt that asks the model to describe an image.
+const embeddingsAgentPropmt = `
+You are an agent who's job it is to describe the contents of an image.
+This description should be detailed as it will be used to create embeddings from the image.
+You should focus more on the content of the image, rather than it's appearence.
+`
+
+// EmbeddingAgent describes an image with a chat model, embeds that
+// description, and stores the resulting vector through imageModel.
+type EmbeddingAgent struct {
+	embeddings client.EmbeddingsClient
+
+	client client.AgentClient
+
+	imageModel models.ImageModel
+
+	log *log.Logger
+}
+
+// GetEmbeddings asks the chat model to describe the image, requests an
+// embedding for that description, and stores it on the image identified by
+// imageId. userId is not used yet.
+func (agent EmbeddingAgent) GetEmbeddings(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
+	request := client.AgentRequestBody{
+		Model:       "gpt-4.1-mini",
+		Temperature: 0.3,
+		ResponseFormat: client.ResponseFormat{
+			Type: "text",
+		},
+		Chat: &client.Chat{
+			Messages: make([]client.ChatMessage, 0),
+		},
+	}
+
+	request.Chat.AddSystem(embeddingsAgentPropmt)
+	request.Chat.AddImage(imageName, imageData, nil)
+
+	resp, err := agent.client.Request(&request)
+	if err != nil {
+		return err
+	}
+
+	// Indexing an empty choice list would panic, so reject it explicitly.
+	if len(resp.Choices) == 0 {
+		return errors.New("embeddings agent: model returned no choices")
+	}
+
+	description := resp.Choices[0].Message.Content
+	// Log through the logger given to NewEmbeddingsAgent; fall back to the package logger when it is unset.
+	if agent.log != nil {
+		agent.log.Info(description)
+	} else {
+		log.Info(description)
+	}
+
+	embeddings, err := agent.embeddings.Request(description)
+	if err != nil {
+		return err
+	}
+
+	return agent.imageModel.UpdateEmbedding(context.Background(), imageId, embeddings.Embeddings)
+}
+
+// NewEmbeddingsAgent wires together the chat client, the embeddings client and
+// the image model. log is handed to the chat client and kept on the agent.
+func NewEmbeddingsAgent(log *log.Logger, imageModel models.ImageModel) EmbeddingAgent {
+	return EmbeddingAgent{
+		client: client.CreateAgentClient(client.CreateAgentClientOptions{
+			SystemPrompt: embeddingsAgentPropmt,
+			Log:          log,
+		}),
+		embeddings: client.CreateEmbeddingsClient(),
+		imageModel: imageModel,
+		log:        log,
+	}
+}
diff --git a/backend/events.go b/backend/events.go
index c7a93e3..e2819c8 100644
--- a/backend/events.go
+++ b/backend/events.go
@@ -67,11 +67,19 @@ func ListenNewImageEvents(db *sql.DB, notifier *Notifier[Notification]) {
 				locationAgent := agents.NewLocationAgent(createLogger("Locations 📍", splitWriter), locationModel)
 				eventAgent := agents.NewEventAgent(createLogger("Events 📅", splitWriter), eventModel, locationModel)
 
+				embeddings := agents.NewEmbeddingsAgent(createLogger("Embeddings 📊", splitWriter), imageModel)
+
 				if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
 					databaseEventLog.Error("Failed to FinishProcessing", "error", err)
 					return
 				}
 
+				err = embeddings.GetEmbeddings(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
+				if err != nil {
+					databaseEventLog.Error("Failed to get embeddings", "error", err)
+					return
+				}
+
 				orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼", splitWriter), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
 				orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
 				_, err = imageModel.FinishProcessing(ctx, image.ID)
@@ -184,7 +192,6 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc {
 					return
 				}
 
-				fmt.Printf("Sending msg %s\n", msg)
 				fmt.Fprintf(w, "event: data\ndata: %s\n\n", string(msgString))
 				w.(http.Flusher).Flush()
 			}
diff --git a/backend/models/image.go b/backend/models/image.go
index 1406ba2..dc04dbf 100644
--- a/backend/models/image.go
+++ b/backend/models/image.go
@@ -7,6 +7,8 @@ import (
 	"fmt"
 	"screenmark/screenmark/.gen/haystack/haystack/model"
 	. "screenmark/screenmark/.gen/haystack/haystack/table"
+	"strconv"
+	"strings"
 
 	. "github.com/go-jet/jet/v2/postgres"
 
@@ -180,6 +182,22 @@ func (m ImageModel) IsUserAuthorized(ctx context.Context, imageId uuid.UUID, use
 	return err != nil && userImage.UserID.String() == userId.String()
 }
 
+// UpdateEmbedding stores embedding on the image row with the given id,
+// serialising it as a "[v1,v2,...]" literal that is cast to vector in SQL.
+func (m ImageModel) UpdateEmbedding(ctx context.Context, imageId uuid.UUID, embedding []float64) error {
+	stringSlice := make([]string, len(embedding))
+
+	for i, f := range embedding {
+		stringSlice[i] = strconv.FormatFloat(f, 'f', -1, 64)
+	}
+
+	embeddingString := "[" + strings.Join(stringSlice, ",") + "]"
+
+	_, err := m.dbPool.ExecContext(ctx, "UPDATE haystack.image SET embedding = $1::vector WHERE id = $2", embeddingString, imageId)
+
+	return err
+}
+
 func NewImageModel(db *sql.DB) ImageModel {
 	return ImageModel{dbPool: db}
 }
diff --git a/backend/schema.sql b/backend/schema.sql
index cf4d16f..da2210e 100644
--- a/backend/schema.sql
+++ b/backend/schema.sql
@@ -1,6 +1,7 @@
 DROP SCHEMA IF EXISTS haystack CASCADE;
 
 CREATE SCHEMA haystack;
+CREATE EXTENSION IF NOT EXISTS vector;
 
 /* -----| Enums |----- */
 
@@ -16,7 +17,9 @@ CREATE TABLE haystack.users (
 CREATE TABLE haystack.image (
     id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
     image_name TEXT NOT NULL,
-    image BYTEA NOT NULL
+    image BYTEA NOT NULL,
+
+    embedding vector(3072)
 );
 
 CREATE TABLE haystack.user_images_to_process (