Compare commits
1 Commits
feat/embed
...
main
Author | SHA1 | Date | |
---|---|---|---|
cb4a03015a |
@ -1,102 +0,0 @@
|
|||||||
package client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
// EmbeddingsClient holds what is needed to call an embeddings HTTP endpoint:
// the target URL, the API key, and the function that performs the request.
type EmbeddingsClient struct {
	// url is the endpoint that embedding requests are POSTed to.
	url string
	// apiKey is sent as the bearer token in the Authorization header.
	apiKey string

	// Do executes a prepared HTTP request. It is a plain function field, so
	// whoever constructs the struct decides how requests are sent.
	Do func(req *http.Request) (*http.Response, error)
}
|
|
||||||
|
|
||||||
// EmbeddingsRequest is the JSON body sent to the embeddings endpoint.
type EmbeddingsRequest struct {
	// Model is the name of the embedding model to use.
	Model string `json:"model"`
	// Input is the text to embed.
	Input string `json:"input"`
}
|
|
||||||
|
|
||||||
// EmbeddingsData is one entry of the "data" array in an embeddings response.
type EmbeddingsData struct {
	// Object is decoded from the "object" key of the entry.
	Object string `json:"object"`
	// Index is decoded from the "index" key of the entry.
	Index int `json:"index"`
	// Embeddings is the embedding vector, decoded from the "embedding" key.
	Embeddings []float64 `json:"embedding"`
}
|
|
||||||
|
|
||||||
type EmbeddingsResponse struct {
|
|
||||||
Object string `json:"object"`
|
|
||||||
Data []EmbeddingsData `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateEmbeddingsClient() EmbeddingsClient {
|
|
||||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
|
||||||
|
|
||||||
if len(apiKey) == 0 {
|
|
||||||
panic("No api key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return EmbeddingsClient{
|
|
||||||
apiKey: apiKey,
|
|
||||||
url: "https://api.openai.com/v1/embeddings",
|
|
||||||
Do: func(req *http.Request) (*http.Response, error) {
|
|
||||||
client := &http.Client{}
|
|
||||||
return client.Do(req)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client EmbeddingsClient) getRequest(text string) (*http.Request, error) {
|
|
||||||
embeddingsReq := EmbeddingsRequest{
|
|
||||||
Model: "text-embedding-3-large",
|
|
||||||
Input: text,
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonEmbeddingsBody, err := json.Marshal(embeddingsReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(jsonEmbeddingsBody))
|
|
||||||
if err != nil {
|
|
||||||
return req, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Add("Authorization", "Bearer "+client.apiKey)
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client EmbeddingsClient) Request(text string) (EmbeddingsData, error) {
|
|
||||||
httpRequest, err := client.getRequest(text)
|
|
||||||
if err != nil {
|
|
||||||
return EmbeddingsData{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(httpRequest)
|
|
||||||
if err != nil {
|
|
||||||
return EmbeddingsData{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return EmbeddingsData{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
embeddingsResponse := EmbeddingsResponse{}
|
|
||||||
err = json.Unmarshal(response, &embeddingsResponse)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return EmbeddingsData{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(embeddingsResponse.Data) != 1 {
|
|
||||||
return EmbeddingsData{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
|
||||||
}
|
|
||||||
|
|
||||||
return embeddingsResponse.Data[0], nil
|
|
||||||
}
|
|
@ -1,68 +0,0 @@
|
|||||||
package agents
|
|
||||||
|
|
||||||
import (
	"context"
	"errors"
	"screenmark/screenmark/agents/client"
	"screenmark/screenmark/models"

	"github.com/charmbracelet/log"
	"github.com/google/uuid"
)
|
|
||||||
|
|
||||||
// embeddingsAgentPropmt is the system prompt that asks the model for a
// detailed, content-focused description of an image; that description is the
// text later turned into an embedding.
//
// The identifier keeps its original spelling because other declarations in
// this package refer to it by that name.
const embeddingsAgentPropmt = `
You are an agent whose job it is to describe the contents of an image.
This description should be detailed as it will be used to create embeddings from the image.
You should focus more on the content of the image, rather than its appearance.
`
|
|
||||||
|
|
||||||
type EmbeddingAgent struct {
|
|
||||||
embeddings client.EmbeddingsClient
|
|
||||||
|
|
||||||
client client.AgentClient
|
|
||||||
|
|
||||||
imageModel models.ImageModel
|
|
||||||
|
|
||||||
log *log.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func (agent EmbeddingAgent) GetEmbeddings(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
|
||||||
request := client.AgentRequestBody{
|
|
||||||
Model: "gpt-4.1-mini",
|
|
||||||
Temperature: 0.3,
|
|
||||||
ResponseFormat: client.ResponseFormat{
|
|
||||||
Type: "text",
|
|
||||||
},
|
|
||||||
Chat: &client.Chat{
|
|
||||||
Messages: make([]client.ChatMessage, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Chat.AddSystem(embeddingsAgentPropmt)
|
|
||||||
request.Chat.AddImage(imageName, imageData, nil)
|
|
||||||
|
|
||||||
resp, err := agent.client.Request(&request)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
description := resp.Choices[0].Message.Content
|
|
||||||
log.Info(description)
|
|
||||||
|
|
||||||
embeddings, err := agent.embeddings.Request(description)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return agent.imageModel.UpdateEmbedding(context.Background(), imageId, embeddings.Embeddings)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewEmbeddingsAgent(log *log.Logger, imageModel models.ImageModel) EmbeddingAgent {
|
|
||||||
return EmbeddingAgent{
|
|
||||||
client: client.CreateAgentClient(client.CreateAgentClientOptions{
|
|
||||||
SystemPrompt: embeddingsAgentPropmt,
|
|
||||||
Log: log,
|
|
||||||
}),
|
|
||||||
embeddings: client.CreateEmbeddingsClient(),
|
|
||||||
imageModel: imageModel,
|
|
||||||
}
|
|
||||||
}
|
|
@ -67,19 +67,11 @@ func ListenNewImageEvents(db *sql.DB, notifier *Notifier[Notification]) {
|
|||||||
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍", splitWriter), locationModel)
|
locationAgent := agents.NewLocationAgent(createLogger("Locations 📍", splitWriter), locationModel)
|
||||||
eventAgent := agents.NewEventAgent(createLogger("Events 📅", splitWriter), eventModel, locationModel)
|
eventAgent := agents.NewEventAgent(createLogger("Events 📅", splitWriter), eventModel, locationModel)
|
||||||
|
|
||||||
embeddings := agents.NewEmbeddingsAgent(createLogger("Embeddings 📊", splitWriter), imageModel)
|
|
||||||
|
|
||||||
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
|
||||||
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
|
databaseEventLog.Error("Failed to FinishProcessing", "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = embeddings.GetEmbeddings(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
|
||||||
if err != nil {
|
|
||||||
databaseEventLog.Error("Failed to get embeddings", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼", splitWriter), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼", splitWriter), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
|
||||||
orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
_, err = imageModel.FinishProcessing(ctx, image.ID)
|
||||||
@ -192,6 +184,7 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Sending msg %s\n", msg)
|
||||||
fmt.Fprintf(w, "event: data\ndata: %s\n\n", string(msgString))
|
fmt.Fprintf(w, "event: data\ndata: %s\n\n", string(msgString))
|
||||||
w.(http.Flusher).Flush()
|
w.(http.Flusher).Flush()
|
||||||
}
|
}
|
||||||
|
@ -186,36 +186,6 @@ func main() {
|
|||||||
w.Write(jsonImages)
|
w.Write(jsonImages)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Get("/image-similar/{id}", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// TODO: authentication
|
|
||||||
stringImageId := r.PathValue("id")
|
|
||||||
|
|
||||||
imageId, err := uuid.Parse(stringImageId)
|
|
||||||
if err != nil {
|
|
||||||
w.WriteHeader(http.StatusForbidden)
|
|
||||||
fmt.Fprintf(w, "You cannot read this")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
similarImages, err := imageModel.GetSimilar(r.Context(), imageId)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
fmt.Fprintf(w, "Something went wrong")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonImages, err := json.Marshal(similarImages)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
fmt.Fprintf(w, "Could not create JSON response for this image")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Write(jsonImages)
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
|
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
imageName := r.PathValue("name")
|
imageName := r.PathValue("name")
|
||||||
userId := r.Context().Value(USER_ID).(uuid.UUID)
|
userId := r.Context().Value(USER_ID).(uuid.UUID)
|
||||||
|
@ -7,8 +7,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
. "github.com/go-jet/jet/v2/postgres"
|
. "github.com/go-jet/jet/v2/postgres"
|
||||||
|
|
||||||
@ -182,58 +180,6 @@ func (m ImageModel) IsUserAuthorized(ctx context.Context, imageId uuid.UUID, use
|
|||||||
return err != nil && userImage.UserID.String() == userId.String()
|
return err != nil && userImage.UserID.String() == userId.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m ImageModel) UpdateEmbedding(ctx context.Context, imageId uuid.UUID, embedding []float64) error {
|
|
||||||
stringSlice := make([]string, len(embedding))
|
|
||||||
|
|
||||||
for i, f := range embedding {
|
|
||||||
stringSlice[i] = strconv.FormatFloat(f, 'f', -1, 64)
|
|
||||||
}
|
|
||||||
|
|
||||||
embeddingString := "[" + strings.Join(stringSlice, ",") + "]"
|
|
||||||
|
|
||||||
_, err := m.dbPool.ExecContext(ctx, "UPDATE haystack.image SET embedding = $1::vector WHERE id = $2", embeddingString, imageId)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SIMILARITY_THRESHOLD is the cutoff GetSimilar passes to its query: only
// rows whose cosine_similarity is strictly greater than this value are
// returned.
const SIMILARITY_THRESHOLD = 0.5
|
|
||||||
|
|
||||||
func (m ImageModel) GetSimilar(ctx context.Context, imageId uuid.UUID) ([]model.Image, error) {
|
|
||||||
stmt, err := m.dbPool.PrepareContext(ctx, `
|
|
||||||
WITH similarities AS (
|
|
||||||
SELECT 1 - (embedding <=> (
|
|
||||||
SELECT embedding
|
|
||||||
FROM haystack.image
|
|
||||||
WHERE ID = $1
|
|
||||||
)) AS cosine_similarity, id, image_name FROM haystack.image
|
|
||||||
WHERE embedding IS NOT NULL
|
|
||||||
)
|
|
||||||
SELECT id, image_name
|
|
||||||
FROM similarities
|
|
||||||
WHERE cosine_similarity > $2
|
|
||||||
`)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return []model.Image{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
images := []model.Image{}
|
|
||||||
|
|
||||||
rows, err := stmt.QueryContext(ctx, imageId, SIMILARITY_THRESHOLD)
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
image := model.Image{}
|
|
||||||
err := rows.Scan(&image.ID, &image.ImageName)
|
|
||||||
if err != nil {
|
|
||||||
return []model.Image{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
images = append(images, image)
|
|
||||||
}
|
|
||||||
|
|
||||||
return images, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewImageModel(db *sql.DB) ImageModel {
|
func NewImageModel(db *sql.DB) ImageModel {
|
||||||
return ImageModel{dbPool: db}
|
return ImageModel{dbPool: db}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
DROP SCHEMA IF EXISTS haystack CASCADE;
|
DROP SCHEMA IF EXISTS haystack CASCADE;
|
||||||
|
|
||||||
CREATE SCHEMA haystack;
|
CREATE SCHEMA haystack;
|
||||||
CREATE EXTENSION IF NOT EXISTS vector;
|
|
||||||
|
|
||||||
/* -----| Enums |----- */
|
/* -----| Enums |----- */
|
||||||
|
|
||||||
@ -17,9 +16,7 @@ CREATE TABLE haystack.users (
|
|||||||
CREATE TABLE haystack.image (
|
CREATE TABLE haystack.image (
|
||||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
image_name TEXT NOT NULL,
|
image_name TEXT NOT NULL,
|
||||||
image BYTEA NOT NULL,
|
image BYTEA NOT NULL
|
||||||
|
|
||||||
embedding vector(3072)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE haystack.user_images_to_process (
|
CREATE TABLE haystack.user_images_to_process (
|
||||||
|
4
backup.bash
Normal file
4
backup.bash
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#!/usr/bin/env bash
# Dumps the haystack database to a dated SQL file, copies it to the remote
# backup host with rsync, and removes the local copy.

# Stop at the first failing step so a broken dump is never uploaded.
set -euo pipefail

name="haystack-db-dump-$(date "+%Y-%m-%d").sql"

# Remove the local dump on exit, whether the backup succeeded or failed.
trap 'rm -f -- "$name"' EXIT

pg_dump haystack > "$name"
rsync -avH "$name" zh3586@zh3586.rsync.net:Backups/Haystack/
|
@ -1,6 +1,6 @@
|
|||||||
import { A, useParams } from "@solidjs/router";
|
import { A, useParams } from "@solidjs/router";
|
||||||
import { createResource, For, Show, type Component } from "solid-js";
|
import { For, Show, type Component } from "solid-js";
|
||||||
import { base, getImageSimilar, type UserImage } from "./network";
|
import { base, type UserImage } from "./network";
|
||||||
import { useSearchImageContext } from "./contexts/SearchImageContext";
|
import { useSearchImageContext } from "./contexts/SearchImageContext";
|
||||||
import { SearchCard } from "./components/search-card/SearchCard";
|
import { SearchCard } from "./components/search-card/SearchCard";
|
||||||
import { IconArrowLeft } from "@tabler/icons-solidjs";
|
import { IconArrowLeft } from "@tabler/icons-solidjs";
|
||||||
@ -11,8 +11,6 @@ export const Image: Component = () => {
|
|||||||
|
|
||||||
const { imagesWithProperties } = useSearchImageContext();
|
const { imagesWithProperties } = useSearchImageContext();
|
||||||
|
|
||||||
const [similarImages] = createResource(() => getImageSimilar(imageId));
|
|
||||||
|
|
||||||
const imageProperties = (): UserImage[] | undefined =>
|
const imageProperties = (): UserImage[] | undefined =>
|
||||||
Object.entries(imagesWithProperties()).find(
|
Object.entries(imagesWithProperties()).find(
|
||||||
([id]) => id === imageId,
|
([id]) => id === imageId,
|
||||||
@ -45,21 +43,6 @@ export const Image: Component = () => {
|
|||||||
)}
|
)}
|
||||||
</Show>
|
</Show>
|
||||||
</div>
|
</div>
|
||||||
<div class="w-3/4 grid grid-cols-9 gap-2 grid-flow-row-dense py-4">
|
|
||||||
<Show when={similarImages()}>
|
|
||||||
{(images) => (
|
|
||||||
<For each={images()}>
|
|
||||||
{(image) => (
|
|
||||||
<img
|
|
||||||
alt="similar"
|
|
||||||
class="col-span-3"
|
|
||||||
src={`${base}/image/${image.ID}`}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</For>
|
|
||||||
)}
|
|
||||||
</Show>
|
|
||||||
</div>
|
|
||||||
</main>
|
</main>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -193,25 +193,6 @@ export const getImage = async (imageId: string): Promise<UserImage> => {
|
|||||||
return parse(dataTypeValidator, res);
|
return parse(dataTypeValidator, res);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Shape of one entry returned by the `image-similar` endpoint: the image id
// and name, with the `Image` field always null.
const similarImageSchema = strictObject({
  ID: pipe(string(), uuid()),
  ImageName: string(),
  Image: null_(),
});

// Validator for the full `image-similar` response, an array of such entries.
const similarImageValidator = array(similarImageSchema);
|
|
||||||
|
|
||||||
/**
 * Fetches the images similar to `imageId` from the `image-similar` endpoint.
 *
 * The JSON response is validated against `similarImageValidator`; `parse`
 * throws when the payload does not match.
 */
export const getImageSimilar = async (
  imageId: string,
): Promise<InferOutput<typeof similarImageValidator>> => {
  const request = getBaseAuthorizedRequest({
    path: `image-similar/${imageId}`,
  });

  const response = await fetch(request);
  const payload = await response.json();

  return parse(similarImageValidator, payload);
};
|
|
||||||
|
|
||||||
export const postLogin = async (email: string): Promise<void> => {
|
export const postLogin = async (email: string): Promise<void> => {
|
||||||
const request = getBaseRequest({
|
const request = getBaseRequest({
|
||||||
path: "login",
|
path: "login",
|
||||||
|
Reference in New Issue
Block a user