feat: asking AI for a description and storing vectors on DB
This commit is contained in:
102
backend/agents/client/embeddings.go
Normal file
102
backend/agents/client/embeddings.go
Normal file
@ -0,0 +1,102 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
// EmbeddingsClient sends text to an embeddings HTTP endpoint and decodes the
// returned vector.
type EmbeddingsClient struct {
	// url is the endpoint every embeddings request is POSTed to.
	url string
	// apiKey is sent as the Bearer token in the Authorization header.
	apiKey string

	// Do performs the HTTP round trip. It is a field rather than a fixed
	// call so the transport can be swapped out.
	Do func(req *http.Request) (*http.Response, error)
}
|
||||
|
||||
// EmbeddingsRequest is the JSON body sent to the embeddings endpoint.
type EmbeddingsRequest struct {
	// Model is the name of the embedding model to use.
	Model string `json:"model"`
	// Input is the text to embed.
	Input string `json:"input"`
}
|
||||
|
||||
// EmbeddingsData is a single entry of the "data" array in an embeddings
// response.
type EmbeddingsData struct {
	Object string `json:"object"`
	Index  int    `json:"index"`
	// Embeddings holds the vector decoded from the "embedding" JSON key.
	Embeddings []float64 `json:"embedding"`
}
|
||||
|
||||
// EmbeddingsResponse is the top-level JSON document decoded from the
// embeddings endpoint's response body.
type EmbeddingsResponse struct {
	Object string `json:"object"`
	// Data lists the returned embeddings; Request expects exactly one entry.
	Data []EmbeddingsData `json:"data"`
}
|
||||
|
||||
func CreateEmbeddingsClient() EmbeddingsClient {
|
||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||
|
||||
if len(apiKey) == 0 {
|
||||
panic("No api key")
|
||||
}
|
||||
|
||||
return EmbeddingsClient{
|
||||
apiKey: apiKey,
|
||||
url: "https://api.openai.com/v1/embeddings",
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
client := &http.Client{}
|
||||
return client.Do(req)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (client EmbeddingsClient) getRequest(text string) (*http.Request, error) {
|
||||
embeddingsReq := EmbeddingsRequest{
|
||||
Model: "text-embedding-3-large",
|
||||
Input: text,
|
||||
}
|
||||
|
||||
jsonEmbeddingsBody, err := json.Marshal(embeddingsReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(jsonEmbeddingsBody))
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (client EmbeddingsClient) Request(text string) (EmbeddingsData, error) {
|
||||
httpRequest, err := client.getRequest(text)
|
||||
if err != nil {
|
||||
return EmbeddingsData{}, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpRequest)
|
||||
if err != nil {
|
||||
return EmbeddingsData{}, err
|
||||
}
|
||||
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return EmbeddingsData{}, err
|
||||
}
|
||||
|
||||
embeddingsResponse := EmbeddingsResponse{}
|
||||
err = json.Unmarshal(response, &embeddingsResponse)
|
||||
|
||||
if err != nil {
|
||||
return EmbeddingsData{}, err
|
||||
}
|
||||
|
||||
if len(embeddingsResponse.Data) != 1 {
|
||||
return EmbeddingsData{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
||||
}
|
||||
|
||||
return embeddingsResponse.Data[0], nil
|
||||
}
|
68
backend/agents/embeddings_agent.go
Normal file
68
backend/agents/embeddings_agent.go
Normal file
@ -0,0 +1,68 @@
|
||||
package agents
|
||||
|
||||
import (
	"context"
	"errors"
	"screenmark/screenmark/agents/client"
	"screenmark/screenmark/models"

	"github.com/charmbracelet/log"
	"github.com/google/uuid"
)
|
||||
|
||||
// embeddingsAgentPropmt is the system prompt given to the chat model. It
// instructs the model to describe an image's content in detail so that the
// description can then be turned into an embedding.
const embeddingsAgentPropmt = `
You are an agent who's job it is to describe the contents of an image.
This description should be detailed as it will be used to create embeddings from the image.
You should focus more on the content of the image, rather than it's appearence.
`
|
||||
|
||||
// EmbeddingAgent describes an image with a chat model, embeds that
// description, and stores the resulting vector on the image record.
type EmbeddingAgent struct {
	// embeddings turns the generated description into a vector.
	embeddings client.EmbeddingsClient

	// client is the chat client used to obtain the image description.
	client client.AgentClient

	// imageModel persists the embedding for the image.
	imageModel models.ImageModel

	// log is the logger intended for this agent's output.
	log *log.Logger
}
|
||||
|
||||
func (agent EmbeddingAgent) GetEmbeddings(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
request := client.AgentRequestBody{
|
||||
Model: "gpt-4.1-mini",
|
||||
Temperature: 0.3,
|
||||
ResponseFormat: client.ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
Chat: &client.Chat{
|
||||
Messages: make([]client.ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(embeddingsAgentPropmt)
|
||||
request.Chat.AddImage(imageName, imageData, nil)
|
||||
|
||||
resp, err := agent.client.Request(&request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
description := resp.Choices[0].Message.Content
|
||||
log.Info(description)
|
||||
|
||||
embeddings, err := agent.embeddings.Request(description)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return agent.imageModel.UpdateEmbedding(context.Background(), imageId, embeddings.Embeddings)
|
||||
}
|
||||
|
||||
func NewEmbeddingsAgent(log *log.Logger, imageModel models.ImageModel) EmbeddingAgent {
|
||||
return EmbeddingAgent{
|
||||
client: client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: embeddingsAgentPropmt,
|
||||
Log: log,
|
||||
}),
|
||||
embeddings: client.CreateEmbeddingsClient(),
|
||||
imageModel: imageModel,
|
||||
}
|
||||
}
|
@ -67,11 +67,19 @@ func ListenNewImageEvents(db *sql.DB, notifier *Notifier[Notification]) {
|
||||
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍", splitWriter), locationModel)
|
||||
eventAgent := agents.NewEventAgent(createLogger("Events 📅", splitWriter), eventModel, locationModel)
|
||||
|
||||
embeddings := agents.NewEmbeddingsAgent(createLogger("Embeddings 📊", splitWriter), imageModel)
|
||||
|
||||
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
||||
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = embeddings.GetEmbeddings(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
if err != nil {
|
||||
databaseEventLog.Error("Failed to get embeddings", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼", splitWriter), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
||||
orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
||||
@ -184,7 +192,6 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Sending msg %s\n", msg)
|
||||
fmt.Fprintf(w, "event: data\ndata: %s\n\n", string(msgString))
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
. "github.com/go-jet/jet/v2/postgres"
|
||||
|
||||
@ -180,6 +182,20 @@ func (m ImageModel) IsUserAuthorized(ctx context.Context, imageId uuid.UUID, use
|
||||
return err != nil && userImage.UserID.String() == userId.String()
|
||||
}
|
||||
|
||||
func (m ImageModel) UpdateEmbedding(ctx context.Context, imageId uuid.UUID, embedding []float64) error {
|
||||
stringSlice := make([]string, len(embedding))
|
||||
|
||||
for i, f := range embedding {
|
||||
stringSlice[i] = strconv.FormatFloat(f, 'f', -1, 64)
|
||||
}
|
||||
|
||||
embeddingString := "[" + strings.Join(stringSlice, ",") + "]"
|
||||
|
||||
_, err := m.dbPool.ExecContext(ctx, "UPDATE haystack.image SET embedding = $1::vector WHERE id = $2", embeddingString, imageId)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func NewImageModel(db *sql.DB) ImageModel {
|
||||
return ImageModel{dbPool: db}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
DROP SCHEMA IF EXISTS haystack CASCADE;
|
||||
|
||||
CREATE SCHEMA haystack;
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
/* -----| Enums |----- */
|
||||
|
||||
@ -16,7 +17,9 @@ CREATE TABLE haystack.users (
|
||||
CREATE TABLE haystack.image (
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
image_name TEXT NOT NULL,
|
||||
image BYTEA NOT NULL
|
||||
image BYTEA NOT NULL,
|
||||
|
||||
embedding vector(3072)
|
||||
);
|
||||
|
||||
CREATE TABLE haystack.user_images_to_process (
|
||||
|
Reference in New Issue
Block a user