Compare commits
5 Commits
main
...
feat/embed
Author | SHA1 | Date | |
---|---|---|---|
936855fe78 | |||
8fde207233 | |||
9ae7fe3077 | |||
85ff91a11b | |||
b58efc413d |
102
backend/agents/client/embeddings.go
Normal file
102
backend/agents/client/embeddings.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EmbeddingsClient struct {
|
||||||
|
url string
|
||||||
|
apiKey string
|
||||||
|
|
||||||
|
Do func(req *http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingsRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input string `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingsData struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
Embeddings []float64 `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingsResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []EmbeddingsData `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateEmbeddingsClient() EmbeddingsClient {
|
||||||
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||||
|
|
||||||
|
if len(apiKey) == 0 {
|
||||||
|
panic("No api key")
|
||||||
|
}
|
||||||
|
|
||||||
|
return EmbeddingsClient{
|
||||||
|
apiKey: apiKey,
|
||||||
|
url: "https://api.openai.com/v1/embeddings",
|
||||||
|
Do: func(req *http.Request) (*http.Response, error) {
|
||||||
|
client := &http.Client{}
|
||||||
|
return client.Do(req)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client EmbeddingsClient) getRequest(text string) (*http.Request, error) {
|
||||||
|
embeddingsReq := EmbeddingsRequest{
|
||||||
|
Model: "text-embedding-3-large",
|
||||||
|
Input: text,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonEmbeddingsBody, err := json.Marshal(embeddingsReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(jsonEmbeddingsBody))
|
||||||
|
if err != nil {
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
||||||
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client EmbeddingsClient) Request(text string) (EmbeddingsData, error) {
|
||||||
|
httpRequest, err := client.getRequest(text)
|
||||||
|
if err != nil {
|
||||||
|
return EmbeddingsData{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(httpRequest)
|
||||||
|
if err != nil {
|
||||||
|
return EmbeddingsData{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return EmbeddingsData{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddingsResponse := EmbeddingsResponse{}
|
||||||
|
err = json.Unmarshal(response, &embeddingsResponse)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return EmbeddingsData{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(embeddingsResponse.Data) != 1 {
|
||||||
|
return EmbeddingsData{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return embeddingsResponse.Data[0], nil
|
||||||
|
}
|
68
backend/agents/embeddings_agent.go
Normal file
68
backend/agents/embeddings_agent.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package agents
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"screenmark/screenmark/agents/client"
|
||||||
|
"screenmark/screenmark/models"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/log"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const embeddingsAgentPropmt = `
|
||||||
|
You are an agent who's job it is to describe the contents of an image.
|
||||||
|
This description should be detailed as it will be used to create embeddings from the image.
|
||||||
|
You should focus more on the content of the image, rather than it's appearence.
|
||||||
|
`
|
||||||
|
|
||||||
|
type EmbeddingAgent struct {
|
||||||
|
embeddings client.EmbeddingsClient
|
||||||
|
|
||||||
|
client client.AgentClient
|
||||||
|
|
||||||
|
imageModel models.ImageModel
|
||||||
|
|
||||||
|
log *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (agent EmbeddingAgent) GetEmbeddings(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
|
request := client.AgentRequestBody{
|
||||||
|
Model: "gpt-4.1-mini",
|
||||||
|
Temperature: 0.3,
|
||||||
|
ResponseFormat: client.ResponseFormat{
|
||||||
|
Type: "text",
|
||||||
|
},
|
||||||
|
Chat: &client.Chat{
|
||||||
|
Messages: make([]client.ChatMessage, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Chat.AddSystem(embeddingsAgentPropmt)
|
||||||
|
request.Chat.AddImage(imageName, imageData, nil)
|
||||||
|
|
||||||
|
resp, err := agent.client.Request(&request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
description := resp.Choices[0].Message.Content
|
||||||
|
log.Info(description)
|
||||||
|
|
||||||
|
embeddings, err := agent.embeddings.Request(description)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent.imageModel.UpdateEmbedding(context.Background(), imageId, embeddings.Embeddings)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEmbeddingsAgent(log *log.Logger, imageModel models.ImageModel) EmbeddingAgent {
|
||||||
|
return EmbeddingAgent{
|
||||||
|
client: client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
|
SystemPrompt: embeddingsAgentPropmt,
|
||||||
|
Log: log,
|
||||||
|
}),
|
||||||
|
embeddings: client.CreateEmbeddingsClient(),
|
||||||
|
imageModel: imageModel,
|
||||||
|
}
|
||||||
|
}
|
@ -67,11 +67,19 @@ func ListenNewImageEvents(db *sql.DB, notifier *Notifier[Notification]) {
|
|||||||
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍", splitWriter), locationModel)
|
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍", splitWriter), locationModel)
|
||||||
eventAgent := agents.NewEventAgent(createLogger("Events 📅", splitWriter), eventModel, locationModel)
|
eventAgent := agents.NewEventAgent(createLogger("Events 📅", splitWriter), eventModel, locationModel)
|
||||||
|
|
||||||
|
embeddings := agents.NewEmbeddingsAgent(createLogger("Embeddings 📊", splitWriter), imageModel)
|
||||||
|
|
||||||
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
||||||
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
|
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = embeddings.GetEmbeddings(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
|
if err != nil {
|
||||||
|
databaseEventLog.Error("Failed to get embeddings", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼", splitWriter), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼", splitWriter), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
||||||
orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
||||||
@ -184,7 +192,6 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Sending msg %s\n", msg)
|
|
||||||
fmt.Fprintf(w, "event: data\ndata: %s\n\n", string(msgString))
|
fmt.Fprintf(w, "event: data\ndata: %s\n\n", string(msgString))
|
||||||
w.(http.Flusher).Flush()
|
w.(http.Flusher).Flush()
|
||||||
}
|
}
|
||||||
|
@ -186,6 +186,36 @@ func main() {
|
|||||||
w.Write(jsonImages)
|
w.Write(jsonImages)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.Get("/image-similar/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// TODO: authentication
|
||||||
|
stringImageId := r.PathValue("id")
|
||||||
|
|
||||||
|
imageId, err := uuid.Parse(stringImageId)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
fmt.Fprintf(w, "You cannot read this")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
similarImages, err := imageModel.GetSimilar(r.Context(), imageId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
fmt.Fprintf(w, "Something went wrong")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonImages, err := json.Marshal(similarImages)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintf(w, "Could not create JSON response for this image")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write(jsonImages)
|
||||||
|
})
|
||||||
|
|
||||||
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
|
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
imageName := r.PathValue("name")
|
imageName := r.PathValue("name")
|
||||||
userId := r.Context().Value(USER_ID).(uuid.UUID)
|
userId := r.Context().Value(USER_ID).(uuid.UUID)
|
||||||
|
@ -7,6 +7,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
. "github.com/go-jet/jet/v2/postgres"
|
. "github.com/go-jet/jet/v2/postgres"
|
||||||
|
|
||||||
@ -180,6 +182,58 @@ func (m ImageModel) IsUserAuthorized(ctx context.Context, imageId uuid.UUID, use
|
|||||||
return err != nil && userImage.UserID.String() == userId.String()
|
return err != nil && userImage.UserID.String() == userId.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m ImageModel) UpdateEmbedding(ctx context.Context, imageId uuid.UUID, embedding []float64) error {
|
||||||
|
stringSlice := make([]string, len(embedding))
|
||||||
|
|
||||||
|
for i, f := range embedding {
|
||||||
|
stringSlice[i] = strconv.FormatFloat(f, 'f', -1, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddingString := "[" + strings.Join(stringSlice, ",") + "]"
|
||||||
|
|
||||||
|
_, err := m.dbPool.ExecContext(ctx, "UPDATE haystack.image SET embedding = $1::vector WHERE id = $2", embeddingString, imageId)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const SIMILARITY_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
func (m ImageModel) GetSimilar(ctx context.Context, imageId uuid.UUID) ([]model.Image, error) {
|
||||||
|
stmt, err := m.dbPool.PrepareContext(ctx, `
|
||||||
|
WITH similarities AS (
|
||||||
|
SELECT 1 - (embedding <=> (
|
||||||
|
SELECT embedding
|
||||||
|
FROM haystack.image
|
||||||
|
WHERE ID = $1
|
||||||
|
)) AS cosine_similarity, id, image_name FROM haystack.image
|
||||||
|
WHERE embedding IS NOT NULL
|
||||||
|
)
|
||||||
|
SELECT id, image_name
|
||||||
|
FROM similarities
|
||||||
|
WHERE cosine_similarity > $2
|
||||||
|
`)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return []model.Image{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
images := []model.Image{}
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, imageId, SIMILARITY_THRESHOLD)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
image := model.Image{}
|
||||||
|
err := rows.Scan(&image.ID, &image.ImageName)
|
||||||
|
if err != nil {
|
||||||
|
return []model.Image{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
images = append(images, image)
|
||||||
|
}
|
||||||
|
|
||||||
|
return images, nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewImageModel(db *sql.DB) ImageModel {
|
func NewImageModel(db *sql.DB) ImageModel {
|
||||||
return ImageModel{dbPool: db}
|
return ImageModel{dbPool: db}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
DROP SCHEMA IF EXISTS haystack CASCADE;
|
DROP SCHEMA IF EXISTS haystack CASCADE;
|
||||||
|
|
||||||
CREATE SCHEMA haystack;
|
CREATE SCHEMA haystack;
|
||||||
|
CREATE EXTENSION IF NOT EXISTS vector;
|
||||||
|
|
||||||
/* -----| Enums |----- */
|
/* -----| Enums |----- */
|
||||||
|
|
||||||
@ -16,7 +17,9 @@ CREATE TABLE haystack.users (
|
|||||||
CREATE TABLE haystack.image (
|
CREATE TABLE haystack.image (
|
||||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
image_name TEXT NOT NULL,
|
image_name TEXT NOT NULL,
|
||||||
image BYTEA NOT NULL
|
image BYTEA NOT NULL,
|
||||||
|
|
||||||
|
embedding vector(3072)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE haystack.user_images_to_process (
|
CREATE TABLE haystack.user_images_to_process (
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
name=haystack-db-dump-$(date "+%Y-%m-%d").sql
|
|
||||||
pg_dump haystack > $name
|
|
||||||
rsync -avH $name zh3586@zh3586.rsync.net:Backups/Haystack/
|
|
||||||
rm $name
|
|
@ -1,6 +1,6 @@
|
|||||||
import { A, useParams } from "@solidjs/router";
|
import { A, useParams } from "@solidjs/router";
|
||||||
import { For, Show, type Component } from "solid-js";
|
import { createResource, For, Show, type Component } from "solid-js";
|
||||||
import { base, type UserImage } from "./network";
|
import { base, getImageSimilar, type UserImage } from "./network";
|
||||||
import { useSearchImageContext } from "./contexts/SearchImageContext";
|
import { useSearchImageContext } from "./contexts/SearchImageContext";
|
||||||
import { SearchCard } from "./components/search-card/SearchCard";
|
import { SearchCard } from "./components/search-card/SearchCard";
|
||||||
import { IconArrowLeft } from "@tabler/icons-solidjs";
|
import { IconArrowLeft } from "@tabler/icons-solidjs";
|
||||||
@ -11,6 +11,8 @@ export const Image: Component = () => {
|
|||||||
|
|
||||||
const { imagesWithProperties } = useSearchImageContext();
|
const { imagesWithProperties } = useSearchImageContext();
|
||||||
|
|
||||||
|
const [similarImages] = createResource(() => getImageSimilar(imageId));
|
||||||
|
|
||||||
const imageProperties = (): UserImage[] | undefined =>
|
const imageProperties = (): UserImage[] | undefined =>
|
||||||
Object.entries(imagesWithProperties()).find(
|
Object.entries(imagesWithProperties()).find(
|
||||||
([id]) => id === imageId,
|
([id]) => id === imageId,
|
||||||
@ -43,6 +45,21 @@ export const Image: Component = () => {
|
|||||||
)}
|
)}
|
||||||
</Show>
|
</Show>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="w-3/4 grid grid-cols-9 gap-2 grid-flow-row-dense py-4">
|
||||||
|
<Show when={similarImages()}>
|
||||||
|
{(images) => (
|
||||||
|
<For each={images()}>
|
||||||
|
{(image) => (
|
||||||
|
<img
|
||||||
|
alt="similar"
|
||||||
|
class="col-span-3"
|
||||||
|
src={`${base}/image/${image.ID}`}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</For>
|
||||||
|
)}
|
||||||
|
</Show>
|
||||||
|
</div>
|
||||||
</main>
|
</main>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -193,6 +193,25 @@ export const getImage = async (imageId: string): Promise<UserImage> => {
|
|||||||
return parse(dataTypeValidator, res);
|
return parse(dataTypeValidator, res);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const similarImageValidator = array(
|
||||||
|
strictObject({
|
||||||
|
ID: pipe(string(), uuid()),
|
||||||
|
ImageName: string(),
|
||||||
|
Image: null_(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
export const getImageSimilar = async (
|
||||||
|
imageId: string,
|
||||||
|
): Promise<InferOutput<typeof similarImageValidator>> => {
|
||||||
|
const request = getBaseAuthorizedRequest({
|
||||||
|
path: `image-similar/${imageId}`,
|
||||||
|
});
|
||||||
|
|
||||||
|
const res = await fetch(request).then((res) => res.json());
|
||||||
|
return parse(similarImageValidator, res);
|
||||||
|
};
|
||||||
|
|
||||||
export const postLogin = async (email: string): Promise<void> => {
|
export const postLogin = async (email: string): Promise<void> => {
|
||||||
const request = getBaseRequest({
|
const request = getBaseRequest({
|
||||||
path: "login",
|
path: "login",
|
||||||
|
Reference in New Issue
Block a user