5 Commits

8 changed files with 304 additions and 4 deletions

View File

@ -0,0 +1,102 @@
package client
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"os"
)
// EmbeddingsClient holds what is needed to call an embeddings HTTP endpoint.
type EmbeddingsClient struct {
	// url is the endpoint that requests are POSTed to.
	url string

	// apiKey is sent as a Bearer token in the Authorization header.
	apiKey string

	// Do executes a prepared request. It is an exported field, so the
	// transport can be replaced by whoever holds the client.
	Do func(req *http.Request) (*http.Response, error)
}
// EmbeddingsRequest is the JSON body sent to the embeddings endpoint.
type EmbeddingsRequest struct {
	// Model is the name of the embedding model to use.
	Model string `json:"model"`

	// Input is the text to embed.
	Input string `json:"input"`
}
// EmbeddingsData is a single entry of the "data" array in an embeddings
// response.
type EmbeddingsData struct {
	// Object is decoded from the "object" key.
	Object string `json:"object"`

	// Index is decoded from the "index" key.
	Index int `json:"index"`

	// Embeddings is the vector itself; note the JSON key is the singular
	// "embedding".
	Embeddings []float64 `json:"embedding"`
}
// EmbeddingsResponse is the top-level JSON body decoded from an embeddings
// response.
type EmbeddingsResponse struct {
	// Object is decoded from the "object" key.
	Object string `json:"object"`

	// Data holds the returned embedding entries.
	Data []EmbeddingsData `json:"data"`
}
// CreateEmbeddingsClient builds an EmbeddingsClient pointed at the OpenAI
// embeddings endpoint, reading the API key from the environment variable
// named by OPENAI_API_KEY.
//
// It panics with "No api key" when that variable is unset or empty.
func CreateEmbeddingsClient() EmbeddingsClient {
	apiKey := os.Getenv(OPENAI_API_KEY)
	if apiKey == "" {
		panic("No api key")
	}

	// A single client is shared by every request made through the returned
	// value instead of allocating a new one on each call.
	// NOTE(review): no Timeout is configured, so a stalled connection blocks
	// the caller indefinitely — consider setting one.
	httpClient := &http.Client{}

	return EmbeddingsClient{
		apiKey: apiKey,
		url:    "https://api.openai.com/v1/embeddings",
		Do:     httpClient.Do,
	}
}
// getRequest builds the POST request that asks for an embedding of text.
//
// The body is an EmbeddingsRequest for the "text-embedding-3-large" model,
// and the request carries the client's API key as a Bearer token together
// with a JSON content type.
func (c EmbeddingsClient) getRequest(text string) (*http.Request, error) {
	payload, err := json.Marshal(EmbeddingsRequest{
		Model: "text-embedding-3-large",
		Input: text,
	})
	if err != nil {
		return nil, err
	}

	httpReq, err := http.NewRequest("POST", c.url, bytes.NewBuffer(payload))
	if err != nil {
		return httpReq, err
	}

	headers := [][2]string{
		{"Authorization", "Bearer " + c.apiKey},
		{"Content-Type", "application/json"},
	}
	for _, kv := range headers {
		httpReq.Header.Add(kv[0], kv[1])
	}

	return httpReq, nil
}
// Request embeds text and returns the single embedding entry from the
// response.
//
// It returns an error when the request cannot be built or sent, when the
// server answers with a non-2xx status, when the body cannot be read or
// decoded, or when the response does not contain exactly one entry.
func (c EmbeddingsClient) Request(text string) (EmbeddingsData, error) {
	httpRequest, err := c.getRequest(text)
	if err != nil {
		return EmbeddingsData{}, err
	}

	resp, err := c.Do(httpRequest)
	if err != nil {
		return EmbeddingsData{}, err
	}
	// The body must be closed on every path so the underlying connection is
	// released.
	defer resp.Body.Close()

	// Error responses have a different JSON shape; without this check they
	// would decode to zero entries and be reported as an "only 1 choice"
	// problem instead of the real failure.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return EmbeddingsData{}, errors.New("embeddings request failed: " + resp.Status)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return EmbeddingsData{}, err
	}

	var embeddingsResponse EmbeddingsResponse
	if err := json.Unmarshal(body, &embeddingsResponse); err != nil {
		return EmbeddingsData{}, err
	}

	if len(embeddingsResponse.Data) != 1 {
		return EmbeddingsData{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
	}

	return embeddingsResponse.Data[0], nil
}

View File

@ -0,0 +1,68 @@
package agents
import (
	"context"
	"errors"

	"screenmark/screenmark/agents/client"
	"screenmark/screenmark/models"

	"github.com/charmbracelet/log"
	"github.com/google/uuid"
)
// embeddingsAgentPropmt is the system prompt added to the chat request that
// asks the model to describe an image. The description the model returns is
// the text that is subsequently embedded.
const embeddingsAgentPropmt = `
You are an agent who's job it is to describe the contents of an image.
This description should be detailed as it will be used to create embeddings from the image.
You should focus more on the content of the image, rather than it's appearence.
`
// EmbeddingAgent describes an image with a chat model, embeds that
// description, and stores the resulting vector.
type EmbeddingAgent struct {
	// embeddings turns the textual description into an embedding.
	embeddings client.EmbeddingsClient

	// client sends the image-description chat request.
	client client.AgentClient

	// imageModel persists the embedding for an image.
	imageModel models.ImageModel

	// log is a logger for this agent.
	log *log.Logger
}
// GetEmbeddings asks the chat model to describe the image, embeds that
// description, and stores the resulting vector for imageId.
//
// userId is currently unused. An error is returned when the chat request
// fails or yields no choices, when the embedding request fails, or when the
// vector cannot be stored.
func (agent EmbeddingAgent) GetEmbeddings(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
	request := client.AgentRequestBody{
		Model:       "gpt-4.1-mini",
		Temperature: 0.3,
		ResponseFormat: client.ResponseFormat{
			Type: "text",
		},
		Chat: &client.Chat{
			Messages: make([]client.ChatMessage, 0),
		},
	}
	request.Chat.AddSystem(embeddingsAgentPropmt)
	request.Chat.AddImage(imageName, imageData, nil)

	resp, err := agent.client.Request(&request)
	if err != nil {
		return err
	}

	// Indexing an empty choice list would panic; report it as an error.
	if len(resp.Choices) == 0 {
		return errors.New("embedding agent: model returned no choices")
	}
	description := resp.Choices[0].Message.Content

	// Prefer the agent's own logger; fall back to the package default when
	// the field was left nil.
	logger := agent.log
	if logger == nil {
		logger = log.Default()
	}
	logger.Info(description)

	embeddings, err := agent.embeddings.Request(description)
	if err != nil {
		return err
	}

	// NOTE(review): the method takes no context, so the write cannot be
	// cancelled by the caller.
	return agent.imageModel.UpdateEmbedding(context.Background(), imageId, embeddings.Embeddings)
}
// NewEmbeddingsAgent wires an EmbeddingAgent together: a chat client
// configured with the embeddings system prompt and the given logger, a new
// embeddings client, and the image model used for persistence.
//
// The logger is also stored on the agent itself so that it is not left nil.
func NewEmbeddingsAgent(log *log.Logger, imageModel models.ImageModel) EmbeddingAgent {
	return EmbeddingAgent{
		client: client.CreateAgentClient(client.CreateAgentClientOptions{
			SystemPrompt: embeddingsAgentPropmt,
			Log:          log,
		}),
		embeddings: client.CreateEmbeddingsClient(),
		imageModel: imageModel,
		log:        log,
	}
}

View File

@ -67,11 +67,19 @@ func ListenNewImageEvents(db *sql.DB, notifier *Notifier[Notification]) {
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍", splitWriter), locationModel)
eventAgent := agents.NewEventAgent(createLogger("Events 📅", splitWriter), eventModel, locationModel)
embeddings := agents.NewEmbeddingsAgent(createLogger("Embeddings 📊", splitWriter), imageModel)
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
return
}
err = embeddings.GetEmbeddings(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil {
databaseEventLog.Error("Failed to get embeddings", "error", err)
return
}
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼", splitWriter), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
_, err = imageModel.FinishProcessing(ctx, image.ID)
@ -184,7 +192,6 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc {
return
}
fmt.Printf("Sending msg %s\n", msg)
fmt.Fprintf(w, "event: data\ndata: %s\n\n", string(msgString))
w.(http.Flusher).Flush()
}

View File

@ -186,6 +186,36 @@ func main() {
w.Write(jsonImages)
})
// GET /image-similar/{id} responds with a JSON array of the images that
// imageModel.GetSimilar reports for the given image ID.
r.Get("/image-similar/{id}", func(w http.ResponseWriter, r *http.Request) {
	// TODO: authentication

	imageId, err := uuid.Parse(r.PathValue("id"))
	if err != nil {
		// A malformed ID is a client error, not an authorization failure.
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "You cannot read this")
		return
	}

	similarImages, err := imageModel.GetSimilar(r.Context(), imageId)
	if err != nil {
		// A failed lookup is a server-side problem, not a missing resource.
		log.Println(err)
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprintf(w, "Something went wrong")
		return
	}

	jsonImages, err := json.Marshal(similarImages)
	if err != nil {
		// Failing to encode our own data is not the client's fault.
		log.Println(err)
		w.WriteHeader(http.StatusInternalServerError)
		fmt.Fprintf(w, "Could not create JSON response for this image")
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.Write(jsonImages)
})
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
imageName := r.PathValue("name")
userId := r.Context().Value(USER_ID).(uuid.UUID)

View File

@ -7,6 +7,8 @@ import (
"fmt"
"screenmark/screenmark/.gen/haystack/haystack/model"
. "screenmark/screenmark/.gen/haystack/haystack/table"
"strconv"
"strings"
. "github.com/go-jet/jet/v2/postgres"
@ -180,6 +182,58 @@ func (m ImageModel) IsUserAuthorized(ctx context.Context, imageId uuid.UUID, use
return err != nil && userImage.UserID.String() == userId.String()
}
// UpdateEmbedding stores embedding on the haystack.image row identified by
// imageId.
//
// The vector is sent as a bracketed, comma-separated text literal such as
// "[0.1,0.2]" and cast to the vector type inside the statement.
func (m ImageModel) UpdateEmbedding(ctx context.Context, imageId uuid.UUID, embedding []float64) error {
	var literal strings.Builder

	literal.WriteByte('[')
	for idx, value := range embedding {
		if idx > 0 {
			literal.WriteByte(',')
		}
		literal.WriteString(strconv.FormatFloat(value, 'f', -1, 64))
	}
	literal.WriteByte(']')

	_, err := m.dbPool.ExecContext(ctx, "UPDATE haystack.image SET embedding = $1::vector WHERE id = $2", literal.String(), imageId)
	return err
}
// SIMILARITY_THRESHOLD is the cosine similarity an image must exceed to be
// returned by GetSimilar.
const SIMILARITY_THRESHOLD = 0.5
// GetSimilar returns the ID and name of every image whose embedding has a
// cosine similarity greater than SIMILARITY_THRESHOLD to the embedding of
// imageId. Only those two fields of each returned model.Image are populated.
//
// Images without an embedding are ignored. If imageId itself has no
// embedding the comparison yields NULL for every row and the result is
// empty. On success the returned slice is never nil.
//
// NOTE(review): the source image satisfies the filter itself (similarity 1),
// and results are neither ordered nor limited nor scoped to a user — confirm
// whether that is intended.
func (m ImageModel) GetSimilar(ctx context.Context, imageId uuid.UUID) ([]model.Image, error) {
	stmt, err := m.dbPool.PrepareContext(ctx, `
		WITH similarities AS (
			SELECT 1 - (embedding <=> (
				SELECT embedding
				FROM haystack.image
				WHERE ID = $1
			)) AS cosine_similarity, id, image_name FROM haystack.image
			WHERE embedding IS NOT NULL
		)
		SELECT id, image_name
		FROM similarities
		WHERE cosine_similarity > $2
	`)
	if err != nil {
		return nil, err
	}
	// Release the prepared statement once the query has been consumed.
	defer stmt.Close()

	rows, err := stmt.QueryContext(ctx, imageId, SIMILARITY_THRESHOLD)
	if err != nil {
		// rows is nil here; iterating it would panic.
		return nil, err
	}
	defer rows.Close()

	// Non-nil so that an empty result marshals to a JSON array, not null.
	images := []model.Image{}
	for rows.Next() {
		var image model.Image
		if err := rows.Scan(&image.ID, &image.ImageName); err != nil {
			return nil, err
		}
		images = append(images, image)
	}
	// Next returning false can also mean iteration failed part-way.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return images, nil
}
// NewImageModel returns an ImageModel that runs its queries against db.
func NewImageModel(db *sql.DB) ImageModel {
	var m ImageModel
	m.dbPool = db
	return m
}

View File

@ -1,6 +1,7 @@
-- Rebuild the haystack schema from scratch. CASCADE also drops every object
-- inside it, so all existing data in the schema is lost.
DROP SCHEMA IF EXISTS haystack CASCADE;
CREATE SCHEMA haystack;
-- Needed for the vector column type used by haystack.image.embedding.
CREATE EXTENSION IF NOT EXISTS vector;
/* -----| Enums |----- */
@ -16,7 +17,9 @@ CREATE TABLE haystack.users (
CREATE TABLE haystack.image (
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
image_name TEXT NOT NULL,
image BYTEA NOT NULL
image BYTEA NOT NULL,
-- Nullable: a row has no embedding until one is written for it.
-- NOTE(review): 3072 presumably matches the output size of the embedding
-- model in use — confirm before changing either.
embedding vector(3072)
);
CREATE TABLE haystack.user_images_to_process (

View File

@ -1,6 +1,6 @@
import { A, useParams } from "@solidjs/router";
import { For, Show, type Component } from "solid-js";
import { base, type UserImage } from "./network";
import { createResource, For, Show, type Component } from "solid-js";
import { base, getImageSimilar, type UserImage } from "./network";
import { useSearchImageContext } from "./contexts/SearchImageContext";
import { SearchCard } from "./components/search-card/SearchCard";
import { IconArrowLeft } from "@tabler/icons-solidjs";
@ -11,6 +11,8 @@ export const Image: Component = () => {
const { imagesWithProperties } = useSearchImageContext();
const [similarImages] = createResource(() => getImageSimilar(imageId));
const imageProperties = (): UserImage[] | undefined =>
Object.entries(imagesWithProperties()).find(
([id]) => id === imageId,
@ -43,6 +45,21 @@ export const Image: Component = () => {
)}
</Show>
</div>
<div class="w-3/4 grid grid-cols-9 gap-2 grid-flow-row-dense py-4">
<Show when={similarImages()}>
{(images) => (
<For each={images()}>
{(image) => (
<img
alt="similar"
class="col-span-3"
src={`${base}/image/${image.ID}`}
/>
)}
</For>
)}
</Show>
</div>
</main>
);
};

View File

@ -193,6 +193,25 @@ export const getImage = async (imageId: string): Promise<UserImage> => {
return parse(dataTypeValidator, res);
};
// Shape of one entry returned by the image-similar endpoint. Unknown keys are
// rejected because the object schema is strict.
const similarImageEntryValidator = strictObject({
  ID: pipe(string(), uuid()),
  ImageName: string(),
  // Only null is accepted here; the image bytes are not expected in this
  // response.
  Image: null_(),
});

// The endpoint responds with an array of such entries.
const similarImageValidator = array(similarImageEntryValidator);
/**
 * Fetches the images similar to `imageId` from `image-similar/{imageId}` and
 * validates the response against `similarImageValidator`.
 *
 * @param imageId ID of the image to find neighbours for.
 * @returns The validated list of similar images.
 * @throws Error when the server answers with a non-2xx status, or when the
 *   body is not valid JSON or does not match the validator.
 */
export const getImageSimilar = async (
  imageId: string,
): Promise<InferOutput<typeof similarImageValidator>> => {
  const request = getBaseAuthorizedRequest({
    path: `image-similar/${imageId}`,
  });

  const response = await fetch(request);

  // Error responses carry a plain-text body, so parsing them as JSON would
  // only produce an unhelpful syntax error.
  if (!response.ok) {
    const detail = await response.text();
    throw new Error(
      `Failed to fetch similar images (${response.status}): ${detail}`,
    );
  }

  return parse(similarImageValidator, await response.json());
};
export const postLogin = async (email: string): Promise<void> => {
const request = getBaseRequest({
path: "login",