Compare commits
1 Commits
feat/embed
...
main
Author | SHA1 | Date | |
---|---|---|---|
cb4a03015a |
@ -1,102 +0,0 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
// EmbeddingsClient bundles the endpoint, credentials and transport hook used
// to call an embeddings HTTP API.
type EmbeddingsClient struct {
|
||||
// url is the endpoint the embeddings request is POSTed to.
url string
|
||||
// apiKey is sent as a bearer token in the Authorization header.
apiKey string
|
||||
|
||||
// Do executes a prepared request. It is a field rather than a fixed
// *http.Client, presumably so it can be replaced (e.g. in tests) — confirm.
Do func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// EmbeddingsRequest is the JSON body sent to the embeddings endpoint.
type EmbeddingsRequest struct {
|
||||
// Model is the name of the embedding model to use.
Model string `json:"model"`
|
||||
// Input is the text to embed.
Input string `json:"input"`
|
||||
}
|
||||
|
||||
// EmbeddingsData is a single entry of the "data" array in an embeddings
// response.
type EmbeddingsData struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
// Embeddings is decoded from the singular "embedding" JSON key.
Embeddings []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
// EmbeddingsResponse is the decoded JSON body returned by the embeddings
// endpoint.
type EmbeddingsResponse struct {
|
||||
Object string `json:"object"`
|
||||
// Data holds the returned embeddings; Request only accepts exactly one entry.
Data []EmbeddingsData `json:"data"`
|
||||
}
|
||||
|
||||
func CreateEmbeddingsClient() EmbeddingsClient {
|
||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||
|
||||
if len(apiKey) == 0 {
|
||||
panic("No api key")
|
||||
}
|
||||
|
||||
return EmbeddingsClient{
|
||||
apiKey: apiKey,
|
||||
url: "https://api.openai.com/v1/embeddings",
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
client := &http.Client{}
|
||||
return client.Do(req)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (client EmbeddingsClient) getRequest(text string) (*http.Request, error) {
|
||||
embeddingsReq := EmbeddingsRequest{
|
||||
Model: "text-embedding-3-large",
|
||||
Input: text,
|
||||
}
|
||||
|
||||
jsonEmbeddingsBody, err := json.Marshal(embeddingsReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(jsonEmbeddingsBody))
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (client EmbeddingsClient) Request(text string) (EmbeddingsData, error) {
|
||||
httpRequest, err := client.getRequest(text)
|
||||
if err != nil {
|
||||
return EmbeddingsData{}, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpRequest)
|
||||
if err != nil {
|
||||
return EmbeddingsData{}, err
|
||||
}
|
||||
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return EmbeddingsData{}, err
|
||||
}
|
||||
|
||||
embeddingsResponse := EmbeddingsResponse{}
|
||||
err = json.Unmarshal(response, &embeddingsResponse)
|
||||
|
||||
if err != nil {
|
||||
return EmbeddingsData{}, err
|
||||
}
|
||||
|
||||
if len(embeddingsResponse.Data) != 1 {
|
||||
return EmbeddingsData{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
||||
}
|
||||
|
||||
return embeddingsResponse.Data[0], nil
|
||||
}
|
@ -1,68 +0,0 @@
|
||||
package agents
|
||||
|
||||
import (
|
||||
"context"
|
||||
"screenmark/screenmark/agents/client"
|
||||
"screenmark/screenmark/models"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// embeddingsAgentPropmt is the system prompt asking the model to describe an
// image; GetEmbeddings turns the resulting description into an embedding.
const embeddingsAgentPropmt = `
|
||||
You are an agent who's job it is to describe the contents of an image.
|
||||
This description should be detailed as it will be used to create embeddings from the image.
|
||||
You should focus more on the content of the image, rather than it's appearence.
|
||||
`
|
||||
|
||||
// EmbeddingAgent describes an image with a chat model and stores an embedding
// of that description for the image.
type EmbeddingAgent struct {
|
||||
// embeddings turns the textual description into an embedding vector.
embeddings client.EmbeddingsClient
|
||||
|
||||
// client is the chat client used to obtain the image description.
client client.AgentClient
|
||||
|
||||
// imageModel persists the embedding for the image.
imageModel models.ImageModel
|
||||
|
||||
// NOTE(review): log is not assigned by NewEmbeddingsAgent and is not read by
// GetEmbeddings in the visible code — confirm whether it is still needed.
log *log.Logger
|
||||
}
|
||||
|
||||
func (agent EmbeddingAgent) GetEmbeddings(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
request := client.AgentRequestBody{
|
||||
Model: "gpt-4.1-mini",
|
||||
Temperature: 0.3,
|
||||
ResponseFormat: client.ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
Chat: &client.Chat{
|
||||
Messages: make([]client.ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(embeddingsAgentPropmt)
|
||||
request.Chat.AddImage(imageName, imageData, nil)
|
||||
|
||||
resp, err := agent.client.Request(&request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
description := resp.Choices[0].Message.Content
|
||||
log.Info(description)
|
||||
|
||||
embeddings, err := agent.embeddings.Request(description)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return agent.imageModel.UpdateEmbedding(context.Background(), imageId, embeddings.Embeddings)
|
||||
}
|
||||
|
||||
func NewEmbeddingsAgent(log *log.Logger, imageModel models.ImageModel) EmbeddingAgent {
|
||||
return EmbeddingAgent{
|
||||
client: client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: embeddingsAgentPropmt,
|
||||
Log: log,
|
||||
}),
|
||||
embeddings: client.CreateEmbeddingsClient(),
|
||||
imageModel: imageModel,
|
||||
}
|
||||
}
|
@ -67,19 +67,11 @@ func ListenNewImageEvents(db *sql.DB, notifier *Notifier[Notification]) {
|
||||
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍", splitWriter), locationModel)
|
||||
eventAgent := agents.NewEventAgent(createLogger("Events 📅", splitWriter), eventModel, locationModel)
|
||||
|
||||
embeddings := agents.NewEmbeddingsAgent(createLogger("Embeddings 📊", splitWriter), imageModel)
|
||||
|
||||
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
||||
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = embeddings.GetEmbeddings(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
if err != nil {
|
||||
databaseEventLog.Error("Failed to get embeddings", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼", splitWriter), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
||||
orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
||||
@ -192,6 +184,7 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Sending msg %s\n", msg)
|
||||
fmt.Fprintf(w, "event: data\ndata: %s\n\n", string(msgString))
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
|
@ -186,36 +186,6 @@ func main() {
|
||||
w.Write(jsonImages)
|
||||
})
|
||||
|
||||
r.Get("/image-similar/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: authentication
|
||||
stringImageId := r.PathValue("id")
|
||||
|
||||
imageId, err := uuid.Parse(stringImageId)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
fmt.Fprintf(w, "You cannot read this")
|
||||
return
|
||||
}
|
||||
|
||||
similarImages, err := imageModel.GetSimilar(r.Context(), imageId)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
fmt.Fprintf(w, "Something went wrong")
|
||||
return
|
||||
}
|
||||
|
||||
jsonImages, err := json.Marshal(similarImages)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, "Could not create JSON response for this image")
|
||||
return
|
||||
}
|
||||
|
||||
w.Write(jsonImages)
|
||||
})
|
||||
|
||||
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||
imageName := r.PathValue("name")
|
||||
userId := r.Context().Value(USER_ID).(uuid.UUID)
|
||||
|
@ -7,8 +7,6 @@ import (
|
||||
"fmt"
|
||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
. "github.com/go-jet/jet/v2/postgres"
|
||||
|
||||
@ -182,58 +180,6 @@ func (m ImageModel) IsUserAuthorized(ctx context.Context, imageId uuid.UUID, use
|
||||
return err != nil && userImage.UserID.String() == userId.String()
|
||||
}
|
||||
|
||||
func (m ImageModel) UpdateEmbedding(ctx context.Context, imageId uuid.UUID, embedding []float64) error {
|
||||
stringSlice := make([]string, len(embedding))
|
||||
|
||||
for i, f := range embedding {
|
||||
stringSlice[i] = strconv.FormatFloat(f, 'f', -1, 64)
|
||||
}
|
||||
|
||||
embeddingString := "[" + strings.Join(stringSlice, ",") + "]"
|
||||
|
||||
_, err := m.dbPool.ExecContext(ctx, "UPDATE haystack.image SET embedding = $1::vector WHERE id = $2", embeddingString, imageId)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SIMILARITY_THRESHOLD is the cosine similarity an image must exceed to be
// returned by GetSimilar.
const SIMILARITY_THRESHOLD = 0.5
|
||||
|
||||
func (m ImageModel) GetSimilar(ctx context.Context, imageId uuid.UUID) ([]model.Image, error) {
|
||||
stmt, err := m.dbPool.PrepareContext(ctx, `
|
||||
WITH similarities AS (
|
||||
SELECT 1 - (embedding <=> (
|
||||
SELECT embedding
|
||||
FROM haystack.image
|
||||
WHERE ID = $1
|
||||
)) AS cosine_similarity, id, image_name FROM haystack.image
|
||||
WHERE embedding IS NOT NULL
|
||||
)
|
||||
SELECT id, image_name
|
||||
FROM similarities
|
||||
WHERE cosine_similarity > $2
|
||||
`)
|
||||
|
||||
if err != nil {
|
||||
return []model.Image{}, err
|
||||
}
|
||||
|
||||
images := []model.Image{}
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, imageId, SIMILARITY_THRESHOLD)
|
||||
|
||||
for rows.Next() {
|
||||
image := model.Image{}
|
||||
err := rows.Scan(&image.ID, &image.ImageName)
|
||||
if err != nil {
|
||||
return []model.Image{}, err
|
||||
}
|
||||
|
||||
images = append(images, image)
|
||||
}
|
||||
|
||||
return images, nil
|
||||
}
|
||||
|
||||
func NewImageModel(db *sql.DB) ImageModel {
|
||||
return ImageModel{dbPool: db}
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
DROP SCHEMA IF EXISTS haystack CASCADE;
|
||||
|
||||
CREATE SCHEMA haystack;
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
/* -----| Enums |----- */
|
||||
|
||||
@ -17,9 +16,7 @@ CREATE TABLE haystack.users (
|
||||
CREATE TABLE haystack.image (
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
image_name TEXT NOT NULL,
|
||||
image BYTEA NOT NULL,
|
||||
|
||||
embedding vector(3072)
|
||||
image BYTEA NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE haystack.user_images_to_process (
|
||||
|
4
backup.bash
Normal file
4
backup.bash
Normal file
@ -0,0 +1,4 @@
|
||||
#!/usr/bin/env bash
# Dumps the haystack database to a dated SQL file, copies it to the remote
# backup host with rsync, and removes the local copy afterwards.

# Stop at the first failure so a broken dump is never uploaded and the local
# file is only deleted after a successful transfer.
set -euo pipefail

name="haystack-db-dump-$(date "+%Y-%m-%d").sql"
pg_dump haystack > "$name"
rsync -avH "$name" zh3586@zh3586.rsync.net:Backups/Haystack/
rm -- "$name"
|
@ -1,6 +1,6 @@
|
||||
import { A, useParams } from "@solidjs/router";
|
||||
import { createResource, For, Show, type Component } from "solid-js";
|
||||
import { base, getImageSimilar, type UserImage } from "./network";
|
||||
import { For, Show, type Component } from "solid-js";
|
||||
import { base, type UserImage } from "./network";
|
||||
import { useSearchImageContext } from "./contexts/SearchImageContext";
|
||||
import { SearchCard } from "./components/search-card/SearchCard";
|
||||
import { IconArrowLeft } from "@tabler/icons-solidjs";
|
||||
@ -11,8 +11,6 @@ export const Image: Component = () => {
|
||||
|
||||
const { imagesWithProperties } = useSearchImageContext();
|
||||
|
||||
const [similarImages] = createResource(() => getImageSimilar(imageId));
|
||||
|
||||
const imageProperties = (): UserImage[] | undefined =>
|
||||
Object.entries(imagesWithProperties()).find(
|
||||
([id]) => id === imageId,
|
||||
@ -45,21 +43,6 @@ export const Image: Component = () => {
|
||||
)}
|
||||
</Show>
|
||||
</div>
|
||||
<div class="w-3/4 grid grid-cols-9 gap-2 grid-flow-row-dense py-4">
|
||||
<Show when={similarImages()}>
|
||||
{(images) => (
|
||||
<For each={images()}>
|
||||
{(image) => (
|
||||
<img
|
||||
alt="similar"
|
||||
class="col-span-3"
|
||||
src={`${base}/image/${image.ID}`}
|
||||
/>
|
||||
)}
|
||||
</For>
|
||||
)}
|
||||
</Show>
|
||||
</div>
|
||||
</main>
|
||||
);
|
||||
};
|
||||
|
@ -193,25 +193,6 @@ export const getImage = async (imageId: string): Promise<UserImage> => {
|
||||
return parse(dataTypeValidator, res);
|
||||
};
|
||||
|
||||
// Schema for the response of the `image-similar/:id` endpoint: an array of
// objects carrying only the image id and name; `Image` must be null.
const similarImageValidator = array(
|
||||
strictObject({
|
||||
ID: pipe(string(), uuid()),
|
||||
ImageName: string(),
|
||||
Image: null_(),
|
||||
}),
|
||||
);
|
||||
|
||||
/**
 * Fetches the images similar to `imageId` from `image-similar/{imageId}` and
 * validates the JSON response against `similarImageValidator`.
 *
 * Rejects when the request fails, the body is not JSON, or validation fails.
 */
export const getImageSimilar = async (
  imageId: string,
): Promise<InferOutput<typeof similarImageValidator>> => {
  const response = await fetch(
    getBaseAuthorizedRequest({ path: `image-similar/${imageId}` }),
  );
  const payload = await response.json();

  return parse(similarImageValidator, payload);
};
|
||||
|
||||
export const postLogin = async (email: string): Promise<void> => {
|
||||
const request = getBaseRequest({
|
||||
path: "login",
|
||||
|
Reference in New Issue
Block a user