diff --git a/backend/agents/agent.go b/backend/agents/agent.go deleted file mode 100644 index 231037b..0000000 --- a/backend/agents/agent.go +++ /dev/null @@ -1,479 +0,0 @@ -package agents - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "os" - "path/filepath" - "screenmark/screenmark/.gen/haystack/haystack/model" -) - -type ImageInfo struct { - Tags []string `json:"tags"` - Text []string `json:"text"` - Links []string `json:"links"` - - Locations []model.Locations `json:"locations"` - Events []model.Events `json:"events"` -} - -type ResponseFormat struct { - Type string `json:"type"` - JsonSchema any `json:"json_schema"` -} - -type AgentRequestBody struct { - Model string `json:"model"` - Temperature float64 `json:"temperature"` - ResponseFormat ResponseFormat `json:"response_format"` - - Tools *any `json:"tools,omitempty"` - ToolChoice *string `json:"tool_choice,omitempty"` - - AgentMessages -} - -type AgentMessages struct { - Messages []AgentMessage `json:"messages"` -} - -type AgentMessage interface { - MessageToJson() ([]byte, error) -} - -type AgentTextMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCallId string `json:"tool_call_id,omitempty"` - Name string `json:"name,omitempty"` -} - -func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { - // TODO: Validate the `Role`. - return json.Marshal(textContent) -} - -type AgentAssistantToolCall struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) { - return json.Marshal(toolCall) -} - -type AgentArrayMessage struct { - Role string `json:"role"` - Content []AgentContent `json:"content"` -} - -func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { - return json.Marshal(arrayContent) -} - -func (content *AgentMessages) AddText(message AgentTextMessage) { - content.Messages = append(content.Messages, message) -} - -func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) { - content.Messages = append(content.Messages, toolCall) -} - -func (content *AgentMessages) AddImage(imageName string, image []byte) error { - extension := filepath.Ext(imageName) - if len(extension) == 0 { - // TODO: could also validate for image types we support. - return errors.New("Image does not have extension") - } - - extension = extension[1:] - - encodedString := base64.StdEncoding.EncodeToString(image) - - arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)} - arrayMessage.Content[0] = AgentImage{ - ImageType: IMAGE_TYPE, - ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), - } - - content.Messages = append(content.Messages, arrayMessage) - - return nil -} - -func (content *AgentMessages) AddSystem(prompt string) error { - if len(content.Messages) != 0 { - return errors.New("You can only add a system prompt at the beginning") - } - - content.Messages = append(content.Messages, AgentTextMessage{ - Role: ROLE_SYSTEM, - Content: prompt, - }) - - return nil -} - -type AgentContent interface { - ToJson() ([]byte, error) -} - -type ImageUrl struct { - Url string `json:"url"` -} - -type AgentImage struct { - ImageType string `json:"type"` - ImageUrl string `json:"image_url"` -} - -func (imageMessage AgentImage) ToJson() ([]byte, error) { - imageMessage.ImageType = IMAGE_TYPE - return json.Marshal(imageMessage) -} - -type AiClient interface { - GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) -} - -type AgentClient struct { - url string - apiKey string - systemPrompt string - responseFormat string - - Do func(req *http.Request) (*http.Response, error) -} - -// func (client AgentClient) Do(req *http.Request) () { -// httpClient := http.Client{} -// return httpClient.Do(req) -// } - -const OPENAI_API_KEY = "OPENAI_API_KEY" -const ROLE_USER = "user" -const ROLE_SYSTEM = "system" -const IMAGE_TYPE = "image_url" - -// TODO: extract to text file probably -const PROMPT = ` -You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text -that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags. -Be sure to extract every link (URL) that you find. -Use generic tags. -` - -const RESPONSE_FORMAT = ` -{ - "name": "image_info", - "strict": true, - "schema": { - "type": "object", - "title": "image", - "required": ["tags", "text", "links"], - "additionalProperties": false, - "properties": { - "tags": { - "type": "array", - "title": "tags", - "description": "A list of tags you think the image is relevant to.", - "items": { - "type": "string" - } - }, - "text": { - "type": "array", - "title": "text", - "description": "A list of sentences the image contains.", - "items": { - "type": "string" - } - }, - "links": { - "type": "array", - "title": "links", - "description": "A list of all the links you can find in the image.", - "items": { - "type": "string" - } - }, - "locations": { - "title": "locations", - "type": "array", - "description": "A list of locations you can find on the image, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "title": "name", - "type": "string" - }, - "coordinates": { - "title": "coordinates", - "type": "string" - }, - "address": { - "title": "address", - "type": "string" - }, - "description": { - "title": "description", - "type": "string" - } - } - } - }, - "events": { - "title": "events", - "type": "array", - "description": "A list of events you find on the image, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "type": "string", - "title": "name" - }, - "locations": { - "title": "locations", - "type": "array", - "description": "A list of locations on this event, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "title": "name", - "type": "string" - }, - "coordinates": { - "title": "coordinates", - "type": "string" - }, - "address": { - "title": "address", - "type": "string" - }, - "description": { - "title": "description", - "type": "string" - } - } - } - } - } - } - } - } - } -} -` - -func CreateAgentClient(prompt string) (AgentClient, error) { - apiKey := os.Getenv(OPENAI_API_KEY) - - if len(apiKey) == 0 { - return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") - } - - return AgentClient{ - apiKey: apiKey, - url: "https://api.mistral.ai/v1/chat/completions", - systemPrompt: prompt, - Do: func(req *http.Request) (*http.Response, error) { - client := &http.Client{} - return client.Do(req) - }, - }, nil -} - -func (client AgentClient) getRequest(body []byte) (*http.Request, error) { - req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body)) - if err != nil { - return req, err - } - - req.Header.Add("Authorization", "Bearer "+client.apiKey) - req.Header.Add("Content-Type", "application/json") - - return req, nil -} - -func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) { - request := AgentRequestBody{ - Model: model, - Temperature: temperature, - ResponseFormat: ResponseFormat{ - Type: "json_schema", - JsonSchema: jsonSchema, - }, - } - - // TODO: Add build pattern here that deals with errors in some internal state? - // I want a monad!!! - err := request.AddSystem(prompt) - if err != nil { - return request, err - } - - log.Println(request) - - err = request.AddImage(imageName, imageData) - if err != nil { - return request, err - } - - request.Tools = nil - - return request, nil -} - -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type ToolCall struct { - Index int `json:"index"` - Id string `json:"id"` - Function FunctionCall `json:"function"` -} - -type ResponseChoiceMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -type ResponseChoice struct { - Index int `json:"index"` - Message ResponseChoiceMessage `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type AgentResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Choices []ResponseChoice `json:"choices"` - Created int `json:"created"` -} - -// TODO: add usage parsing -func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) { - response := AgentResponse{} - - err := json.Unmarshal(jsonResponse, &response) - if err != nil { - return ImageInfo{}, err - } - - if len(response.Choices) != 1 { - log.Println(string(jsonResponse)) - return ImageInfo{}, errors.New("Expected exactly one choice.") - } - - imageInfo := ImageInfo{} - err = json.Unmarshal([]byte(response.Choices[0].Message.Content), &imageInfo) - if err != nil { - return ImageInfo{}, errors.New("Could not parse content into image type.") - } - - return imageInfo, nil -} - -func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) { - jsonAiRequest, err := json.Marshal(request) - if err != nil { - return AgentResponse{}, err - } - - httpRequest, err := client.getRequest(jsonAiRequest) - if err != nil { - return AgentResponse{}, err - } - - resp, err := client.Do(httpRequest) - if err != nil { - return AgentResponse{}, err - } - - response, err := io.ReadAll(resp.Body) - if err != nil { - return AgentResponse{}, err - } - - agentResponse := AgentResponse{} - err = json.Unmarshal(response, &agentResponse) - - if err != nil { - return AgentResponse{}, err - } - - log.Println(string(response)) - - toolCalls := agentResponse.Choices[0].Message.ToolCalls - if len(toolCalls) > 0 { - // Should for sure be more flexible. - request.AddToolCall(AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: toolCalls, - }) - } - - return agentResponse, nil -} - -func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { - aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData) - if err != nil { - return ImageInfo{}, err - } - - var jsonSchema any - err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema) - if err != nil { - return ImageInfo{}, err - } - - aiRequest.ResponseFormat = ResponseFormat{ - Type: "json_schema", - JsonSchema: jsonSchema, - } - - jsonAiRequest, err := json.Marshal(aiRequest) - if err != nil { - return ImageInfo{}, err - } - - request, err := client.getRequest(jsonAiRequest) - if err != nil { - return ImageInfo{}, err - } - - resp, err := client.Do(request) - if err != nil { - return ImageInfo{}, err - } - - response, err := io.ReadAll(resp.Body) - if err != nil { - return ImageInfo{}, err - } - - log.Println(string(response)) - - return parseAgentResponse(response) -} diff --git a/backend/agents/agent_test.go b/backend/agents/agent_test.go deleted file mode 100644 index 770e0ab..0000000 --- a/backend/agents/agent_test.go +++ /dev/null @@ -1,213 +0,0 @@ -package agents - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "testing" -) - -func TestMessageBuilder(t *testing.T) { - content := AgentMessages{} - - err := content.AddSystem("Some prompt") - - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(content.Messages) != 1 { - t.Logf("Expected length 1, got %d.\n", len(content.Messages)) - t.FailNow() - } -} - -func TestMessageBuilderImage(t *testing.T) { - content := AgentMessages{} - - prompt := "some prompt" - imageTitle := "image.png" - data := []byte("some data") - - content.AddSystem(prompt) - content.AddImage(imageTitle, data) - - if len(content.Messages) != 2 { - t.Logf("Expected length 2, got %d.\n", len(content.Messages)) - t.FailNow() - } - - promptMessage, ok := content.Messages[0].(AgentTextMessage) - if !ok { - t.Logf("Expected text content message, got %T\n", content.Messages[0]) - t.FailNow() - } - - if promptMessage.Role != ROLE_SYSTEM { - t.Log("Prompt message role is incorrect.") - t.FailNow() - } - - if promptMessage.Content != prompt { - t.Log("Prompt message content is incorrect.") - t.FailNow() - } - - arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage) - if !ok { - t.Logf("Expected text content message, got %T\n", content.Messages[1]) - t.FailNow() - } - - if arrayContentMessage.Role != ROLE_USER { - t.Log("Array content message role is incorrect.") - t.FailNow() - } - - if len(arrayContentMessage.Content) != 1 { - t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content)) - t.FailNow() - } - - imageContent, ok := arrayContentMessage.Content[0].(AgentImage) - if !ok { - t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0]) - t.FailNow() - } - - base64data := base64.StdEncoding.EncodeToString(data) - url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data) - - if imageContent.ImageUrl != url { - t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl) - t.FailNow() - } -} - -func TestFullImageRequest(t *testing.T) { - request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data")) - if err != nil { - t.Log(request) - t.FailNow() - } - - jsonData, err := json.Marshal(request) - if err != nil { - t.Log(err) - t.FailNow() - } - - expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":""}}]}]}` - - if string(jsonData) != expectedJson { - t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData)) - t.FailNow() - } -} - -func TestResponse(t *testing.T) { - testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}` - buffer := bytes.NewReader([]byte(testResponse)) - - body := io.NopCloser(buffer) - - client := AgentClient{ - url: "http://localhost:1234", - apiKey: "some-key", - Do: func(_req *http.Request) (*http.Response, error) { - return &http.Response{Body: body}, nil - }, - } - - info, err := client.GetImageInfo("image.png", []byte("some data")) - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 { - t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links)) - t.FailNow() - } - - if info.Tags[0] != "tag1" { - t.Log("0th tag is wrong.") - t.FailNow() - } - - if info.Tags[1] != "tag2" { - t.Log("1th tag is wrong.") - t.FailNow() - } - - if info.Text[0] != "text1" { - t.Log("0th text is wrong.") - t.FailNow() - } -} - -func TestResponseParsing(t *testing.T) { - response := `{ - "id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ", - "object": "chat.completion", - "created": 1740422508, - "model": "gpt-4o-mini-2024-07-18", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}", - "refusal": null - }, - "logprobs": null, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 775, - "completion_tokens": 33, - "total_tokens": 808, - "prompt_tokens_details": { - "cached_tokens": 0, - "audio_tokens": 0 - }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "audio_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "service_tier": "default", - "system_fingerprint": "fp_7fcd609668" -}` - - imageParsed, err := parseAgentResponse([]byte(response)) - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" { - t.Log(imageParsed) - t.Log("Should have one link called 'link'.") - t.FailNow() - } - - if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" { - t.Log(imageParsed) - t.Log("Should have one tag called 'tag'.") - t.FailNow() - } - - if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" { - t.Log(imageParsed) - t.Log("Should have one text called 'text'.") - t.FailNow() - } -} diff --git a/backend/agents/client/chat.go b/backend/agents/client/chat.go new file mode 100644 index 0000000..6628e56 --- /dev/null +++ b/backend/agents/client/chat.go @@ -0,0 +1,208 @@ +package client + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "path/filepath" +) + +type Chat struct { + Messages []ChatMessage `json:"messages"` +} + +type ChatMessage interface { + IsResponse() bool +} + +// TODO: the role could be inferred from the type. +// This would solve some bugs. + +/* + +Is there a world where this actually becomes the product? +Where we build such a resilient system of AI calls that we +can build some app builder, or even just an API system, +with a fancy UI? + +Manage all the complexity for the user? + +*/ + +// ============================================= +// Messages from us to the AI. +// ============================================= + +type UserRole = string + +const ( + User UserRole = "user" + System UserRole = "system" +) + +type ToolRole = string + +const ( + Tool ToolRole = "tool" +) + +type ChatUserMessage struct { + Role UserRole `json:"role"` + + MessageContent `json:"MessageContent"` +} + +func (m ChatUserMessage) MarshalJSON() ([]byte, error) { + switch t := m.MessageContent.(type) { + case SingleMessage: + return json.Marshal(&struct { + Role UserRole `json:"role"` + Content string `json:"content"` + }{ + Role: User, + Content: t.Content, + }) + case ArrayMessage: + return json.Marshal(&struct { + Role UserRole `json:"role"` + Content []ImageMessageContent `json:"content"` + }{ + Role: User, + Content: t.Content, + }) + } + + return []byte{}, errors.New("Unreachable") +} + +func (r ChatUserMessage) IsResponse() bool { + return false +} + +type ChatUserToolResponse struct { + Role ToolRole `json:"role"` + + // The name of the function we are responding to. + Name string `json:"name"` + Content string `json:"content"` + ToolCallId string `json:"tool_call_id"` +} + +func (r ChatUserToolResponse) IsResponse() bool { + return false +} + +type ChatAiMessage struct { + Role string `json:"role"` + ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` + + Content string `json:"content"` +} + +func (m ChatAiMessage) IsResponse() bool { + return true +} + +// ============================================= +// Unique interface for message content. +// ============================================= + +type MessageContent interface { + IsSingleMessage() bool +} + +type SingleMessage struct { + Content string `json:"content"` +} + +func (m SingleMessage) IsSingleMessage() bool { + return true +} + +type ArrayMessage struct { + Content []ImageMessageContent `json:"content"` +} + +func (m ArrayMessage) IsSingleMessage() bool { + return false +} + +type ImageMessageContent struct { + ImageType string `json:"type"` + ImageUrl string `json:"image_url"` +} + +type ImageContentUrl struct { + Url string `json:"url"` +} + +// ============================================= +// Adjacent interfaces. +// ============================================= + +type ToolCall struct { + Index int `json:"index"` + Id string `json:"id"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ============================================= +// Chat methods +// ============================================= + +func (chat *Chat) AddSystem(prompt string) { + chat.Messages = append(chat.Messages, ChatUserMessage{ + Role: System, + MessageContent: SingleMessage{ + Content: prompt, + }, + }) +} + +func (chat *Chat) AddImage(imageName string, image []byte) error { + extension := filepath.Ext(imageName) + if len(extension) == 0 { + // TODO: could also validate for image types we support. + return errors.New("Image does not have extension") + } + + extension = extension[1:] + + encodedString := base64.StdEncoding.EncodeToString(image) + + messageContent := ArrayMessage{ + Content: make([]ImageMessageContent, 1), + } + + messageContent.Content[0] = ImageMessageContent{ + ImageType: "image_url", + ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), + } + + arrayMessage := ChatUserMessage{Role: User, MessageContent: messageContent} + chat.Messages = append(chat.Messages, arrayMessage) + + return nil +} + +func (chat *Chat) AddAiResponse(res ChatAiMessage) { + chat.Messages = append(chat.Messages, res) +} + +func (chat *Chat) AddToolResponse(res ChatUserToolResponse) { + chat.Messages = append(chat.Messages, res) +} + +func (chat Chat) GetLatest() (ChatMessage, error) { + if len(chat.Messages) == 0 { + return nil, errors.New("Not enough messages") + } + + return chat.Messages[len(chat.Messages)-1], nil +} diff --git a/backend/agents/client/chat_test.go b/backend/agents/client/chat_test.go new file mode 100644 index 0000000..09ca5ba --- /dev/null +++ b/backend/agents/client/chat_test.go @@ -0,0 +1,24 @@ +package client + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFlatMarshallSingleMessage(t *testing.T) { + require := require.New(t) + + message := ChatUserMessage{ + Role: User, + MessageContent: SingleMessage{ + Content: "Hello", + }, + } + + json, err := json.Marshal(message) + require.NoError(err) + + require.Equal(string(json), "{\"role\":\"user\",\"content\":\"Hello\"}") +} diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go new file mode 100644 index 0000000..18af351 --- /dev/null +++ b/backend/agents/client/client.go @@ -0,0 +1,194 @@ +package client + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" +) + +type ResponseFormat struct { + Type string `json:"type"` + JsonSchema any `json:"json_schema"` +} + +type AgentRequestBody struct { + Model string `json:"model"` + Temperature float64 `json:"temperature"` + ResponseFormat ResponseFormat `json:"response_format"` + + Tools *any `json:"tools,omitempty"` + ToolChoice *string `json:"tool_choice,omitempty"` + + EndToolCall string `json:"-"` + + Chat *Chat `json:"messages"` +} + +func (req AgentRequestBody) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + Model string `json:"model"` + Temperature float64 `json:"temperature"` + ResponseFormat ResponseFormat `json:"response_format"` + + Tools *any `json:"tools,omitempty"` + ToolChoice *string `json:"tool_choice,omitempty"` + Messages []ChatMessage `json:"messages"` + }{ + Model: req.Model, + Temperature: req.Temperature, + ResponseFormat: req.ResponseFormat, + + Tools: req.Tools, + ToolChoice: req.ToolChoice, + + Messages: req.Chat.Messages, + }) +} + +type ResponseChoice struct { + Index int `json:"index"` + Message ChatAiMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type AgentResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Choices []ResponseChoice `json:"choices"` + Created int `json:"created"` +} + +type AgentClient struct { + url string + apiKey string + responseFormat string + + ToolHandler ToolsHandlers + + Do func(req *http.Request) (*http.Response, error) +} + +const OPENAI_API_KEY = "OPENAI_API_KEY" + +func CreateAgentClient() (AgentClient, error) { + apiKey := os.Getenv(OPENAI_API_KEY) + + if len(apiKey) == 0 { + return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") + } + + return AgentClient{ + apiKey: apiKey, + url: "https://api.mistral.ai/v1/chat/completions", + Do: func(req *http.Request) (*http.Response, error) { + client := &http.Client{} + return client.Do(req) + }, + + ToolHandler: ToolsHandlers{ + handlers: map[string]ToolHandler{}, + }, + }, nil +} + +func (client AgentClient) getRequest(body []byte) (*http.Request, error) { + req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body)) + if err != nil { + return req, err + } + + req.Header.Add("Authorization", "Bearer "+client.apiKey) + req.Header.Add("Content-Type", "application/json") + + return req, nil +} + +func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) { + jsonAiRequest, err := json.Marshal(req) + if err != nil { + return AgentResponse{}, err + } + + httpRequest, err := client.getRequest(jsonAiRequest) + if err != nil { + return AgentResponse{}, err + } + + resp, err := client.Do(httpRequest) + if err != nil { + return AgentResponse{}, err + } + + response, err := io.ReadAll(resp.Body) + if err != nil { + return AgentResponse{}, err + } + + fmt.Println(string(response)) + + agentResponse := AgentResponse{} + err = json.Unmarshal(response, &agentResponse) + + if err != nil { + return AgentResponse{}, err + } + + if len(agentResponse.Choices) != 1 { + return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.") + } + + req.Chat.AddAiResponse(agentResponse.Choices[0].Message) + + return agentResponse, nil +} + +func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error { + for { + err := client.Process(info, req) + if err != nil { + return err + } + + _, err = client.Request(req) + if err != nil { + return err + } + } +} + +var FinishedCall = errors.New("Last tool tool was called") + +func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error { + var err error + + message, err := req.Chat.GetLatest() + if err != nil { + return err + } + + aiMessage, ok := message.(ChatAiMessage) + if !ok { + return errors.New("Latest message isnt an AI message") + } + + if aiMessage.ToolCalls == nil { + // Not an error, we just dont have any tool calls to process. + return nil + } + + for _, toolCall := range *aiMessage.ToolCalls { + if toolCall.Function.Name == req.EndToolCall { + return FinishedCall + } + + toolResponse := client.ToolHandler.Handle(info, toolCall) + + req.Chat.AddToolResponse(toolResponse) + } + + return err +} diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go new file mode 100644 index 0000000..8dc9244 --- /dev/null +++ b/backend/agents/client/tools.go @@ -0,0 +1,70 @@ +package client + +import ( + "encoding/json" + "errors" + + "github.com/google/uuid" +) + +type ToolHandlerInfo struct { + UserId uuid.UUID + ImageId uuid.UUID +} + +type ToolHandler struct { + Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error) +} + +type ToolsHandlers struct { + handlers map[string]ToolHandler +} + +var NoToolCallError = errors.New("An assistant tool call with no tool calls was provided.") + +const NonExistantTool = "This tool does not exist" + +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) ChatUserToolResponse { + fnName := toolCallMessage.Function.Name + arguments := toolCallMessage.Function.Arguments + + responseMessage := ChatUserToolResponse{ + Role: "tool", + Name: fnName, + ToolCallId: toolCallMessage.Id, + } + + fnHandler, exists := handler.handlers[fnName] + if !exists { + responseMessage.Content = NonExistantTool + return responseMessage + } + + res, err := fnHandler.Fn(info, arguments, toolCallMessage) + + if err != nil { + responseMessage.Content = err.Error() + } else { + responseMessage.Content = res + } + + return responseMessage +} + +func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { + handler.handlers[name] = ToolHandler{ + Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) { + res, err := fn(info, args, call) + if err != nil { + return "", err + } + + marshalledRes, err := json.Marshal(res) + if err != nil { + return "", err + } + + return string(marshalledRes), nil + }, + } +} diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go new file mode 100644 index 0000000..912d94b --- /dev/null +++ b/backend/agents/client/tools_test.go @@ -0,0 +1,188 @@ +package client + +import ( + "errors" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/suite" +) + +type ToolTestSuite struct { + suite.Suite + + handler ToolsHandlers + client AgentClient +} + +func (suite *ToolTestSuite) SetupTest() { + suite.handler = ToolsHandlers{ + handlers: map[string]ToolHandler{}, + } + + suite.handler.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { + return args, nil + }) + + suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { + return false, errors.New("I will always error") + }) + + suite.client.ToolHandler = suite.handler +} + +func (suite *ToolTestSuite) TestSingleToolCall() { + require := suite.Require() + + response := suite.handler.Handle( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + ToolCall{ + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "a", + Arguments: "return", + }, + }) + + require.EqualValues(response, ChatUserToolResponse{ + Role: "tool", + Content: "\"return\"", + ToolCallId: "1", + Name: "a", + }) +} + +func (suite *ToolTestSuite) TestMultipleToolCalls() { + assert := suite.Assert() + require := suite.Require() + + chat := Chat{ + Messages: []ChatMessage{ChatAiMessage{ + Role: "assistant", + Content: "", + ToolCalls: &[]ToolCall{ + { + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "a", + Arguments: "first-call", + }, + }, + { + Index: 1, + Id: "2", + Function: FunctionCall{ + Name: "a", + Arguments: "second-call", + }, + }, + }, + }}, + } + + err := suite.client.Process( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + &AgentRequestBody{ + Chat: &chat, + }) + + require.NoError(err, "Tool call shouldnt return an error") + assert.EqualValues(chat.Messages[1:], []ChatMessage{ + ChatUserToolResponse{ + Role: "tool", + Content: "\"first-call\"", + ToolCallId: "1", + Name: "a", + }, + ChatUserToolResponse{ + Role: "tool", + Content: "\"second-call\"", + ToolCallId: "2", + Name: "a", + }, + }) +} + +func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { + assert := suite.Assert() + require := suite.Require() + + chat := Chat{ + Messages: []ChatMessage{ChatAiMessage{ + Role: "assistant", + Content: "", + ToolCalls: &[]ToolCall{ + { + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "error", + Arguments: "", + }, + }, + { + Index: 1, + Id: "2", + Function: FunctionCall{ + Name: "non-existant", + Arguments: "", + }, + }, + { + Index: 2, + Id: "3", + Function: FunctionCall{ + Name: "a", + Arguments: "no-error", + }, + }, + }, + }}, + } + + err := suite.client.Process( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + &AgentRequestBody{ + Chat: &chat, + }) + + require.NoError(err, "Tool call shouldnt return an error") + + assert.EqualValues(chat.Messages[1:], []ChatMessage{ + ChatUserToolResponse{ + Role: "tool", + Content: "I will always error", + ToolCallId: "1", + Name: "error", + }, + ChatUserToolResponse{ + Role: "tool", + Content: "This tool does not exist", + ToolCallId: "2", + Name: "non-existant", + }, + ChatUserToolResponse{ + Role: "tool", + Content: "\"no-error\"", + ToolCallId: "3", + Name: "a", + }, + }) +} + +func TestToolSuite(t *testing.T) { + suite.Run(t, &ToolTestSuite{ + client: AgentClient{}, + }) +} diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 72757ad..22672f6 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -3,10 +3,8 @@ package agents import ( "context" "encoding/json" - "errors" - "log" - "reflect" "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "time" @@ -101,28 +99,25 @@ const TOOLS = ` } } }, - { - "type": "function", - "function": { - "name": "finish", - "description": "Nothing else to do, call this function.", - "parameters": { - "type": "object", - "properties": {} - } - } - } + { + "type": "function", + "function": { + "name": "finish", + "description": "Nothing else to do. call this function.", + "parameters": {} + } + } ] ` type EventLocationAgent struct { - client AgentClient + client client.AgentClient eventModel models.EventModel locationModel models.LocationModel contactModel models.ContactModel - toolHandler ToolsHandlers + toolHandler client.ToolsHandlers } type ListLocationArguments struct{} @@ -158,157 +153,66 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID toolChoice := "any" - request := AgentRequestBody{ + request := client.AgentRequestBody{ Tools: &tools, ToolChoice: &toolChoice, Model: "pixtral-12b-2409", Temperature: 0.3, - ResponseFormat: ResponseFormat{ + EndToolCall: "finish", + ResponseFormat: client.ResponseFormat{ Type: "text", }, + Chat: &client.Chat{ + Messages: make([]client.ChatMessage, 0), + }, } - err = request.AddSystem(eventLocationPrompt) - if err != nil { - return err - } - - request.AddImage(imageName, imageData) + request.Chat.AddSystem(eventLocationPrompt) + request.Chat.AddImage(imageName, imageData) _, err = agent.client.Request(&request) if err != nil { return err } - toolHandlerInfo := ToolHandlerInfo{ - imageId: imageId, - userId: userId, + toolHandlerInfo := client.ToolHandlerInfo{ + ImageId: imageId, + UserId: userId, } - _, err = agent.toolHandler.Handle(toolHandlerInfo, &request) - - for err == nil { - log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) - - response, requestError := agent.client.Request(&request) - if requestError != nil { - return requestError - } - - log.Println(response) - - a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request) - - err = innerErr - - log.Println(a) - log.Println("--------------------------") - } - - return err -} - -// TODO: extract this into a more general tool handler package. -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) { - agentMessage := request.Messages[len(request.Messages)-1] - - toolCall, ok := agentMessage.(AgentAssistantToolCall) - if !ok { - return "", errors.New("Latest message was not a tool call.") - } - - fnName := toolCall.ToolCalls[0].Function.Name - arguments := toolCall.ToolCalls[0].Function.Arguments - - if fnName == "finish" { - return "", errors.New("This is the end! Maybe we just return a boolean.") - } - - fn, exists := handler.Handlers[fnName] - if !exists { - return "", errors.New("Could not find tool with this name.") - } - - // holy jesus what the fuck. - parseMethod := reflect.ValueOf(fn).Field(1) - if !parseMethod.IsValid() { - return "", errors.New("Parse method not found") - } - - parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)}) - if !parsedArgs[1].IsNil() { - return "", parsedArgs[1].Interface().(error) - } - - log.Printf("Calling: %s\n", fnName) - - fnMethod := reflect.ValueOf(fn).Field(2) - if !fnMethod.IsValid() { - return "", errors.New("Fn method not found") - } - - response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])}) - if !response[1].IsNil() { - return "", response[1].Interface().(error) - } - - stringResponse, err := json.Marshal(response[0].Interface()) - if err != nil { - return "", err - } - - request.AddText(AgentTextMessage{ - Role: "tool", - Name: "createLocation", - Content: string(stringResponse), - ToolCallId: toolCall.ToolCalls[0].Id, - }) - - return string(stringResponse), nil + return agent.client.ToolLoop(toolHandlerInfo, &request) } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { - client, err := CreateAgentClient(eventLocationPrompt) + agentClient, err := client.CreateAgentClient() if err != nil { return EventLocationAgent{}, err } agent := EventLocationAgent{ - client: client, + client: agentClient, locationModel: locationModel, eventModel: eventModel, contactModel: contactModel, } - toolHandler := ToolsHandlers{ - Handlers: make(map[string]ToolHandlerInterface), - } - - toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{ - FunctionName: "listLocations", - Parse: func(stringArgs string) (ListLocationArguments, error) { - args := ListLocationArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) - - return args, err + agentClient.ToolHandler.AddTool("listLocations", + func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { + return agent.locationModel.List(context.Background(), info.UserId) }, - Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) { - return agent.locationModel.List(context.Background(), info.userId) - }, - } + ) - toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{ - FunctionName: "createLocation", - Parse: func(stringArgs string) (CreateLocationArguments, error) { + agentClient.ToolHandler.AddTool("createLocation", + func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { args := CreateLocationArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) + err := json.Unmarshal([]byte(_args), &args) + if err != nil { + return model.Locations{}, err + } - return args, err - }, - Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) { ctx := context.Background() - location, err := agent.locationModel.Save(ctx, info.userId, model.Locations{ + location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{ Name: args.Name, Address: args.Address, }) @@ -317,21 +221,20 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return location, err } - _, err = agent.locationModel.SaveToImage(ctx, info.imageId, location.ID) + _, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID) return location, err }, - } + ) - toolHandler.Handlers["createEvent"] = ToolHandler[CreateEventArguments, model.Events]{ - FunctionName: "createEvent", - Parse: func(stringArgs string) (CreateEventArguments, error) { + agentClient.ToolHandler.AddTool("createEvent", + func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { args := CreateEventArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) + err := json.Unmarshal([]byte(_args), &args) + if err != nil { + return model.Locations{}, err + } - return args, err - }, - Fn: func(info ToolHandlerInfo, args CreateEventArguments, call ToolCall) (model.Events, error) { ctx := context.Background() layout := "2006-01-02T15:04:05Z" @@ -346,7 +249,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return model.Events{}, err } - event, err := agent.eventModel.Save(ctx, info.userId, model.Events{ + event, err := agent.eventModel.Save(ctx, info.UserId, model.Events{ Name: args.Name, StartDateTime: &startTime, EndDateTime: &endTime, @@ -356,7 +259,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return event, err } - organizer, err := agent.contactModel.Save(ctx, info.userId, model.Contacts{ + organizer, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{ Name: args.Name, }) @@ -364,12 +267,12 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return event, err } - _, err = agent.eventModel.SaveToImage(ctx, info.imageId, event.ID) + _, err = agent.eventModel.SaveToImage(ctx, info.ImageId, event.ID) if err != nil { return event, err } - _, err = agent.contactModel.SaveToImage(ctx, info.imageId, organizer.ID) + _, err = agent.contactModel.SaveToImage(ctx, info.ImageId, organizer.ID) if err != nil { return event, err } @@ -386,9 +289,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return agent.eventModel.UpdateOrganizer(ctx, event.ID, organizer.ID) }, - } - - agent.toolHandler = toolHandler + ) return agent, nil } diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 1616ea6..59b2fb0 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -3,6 +3,7 @@ package agents import ( "context" "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "github.com/google/uuid" @@ -19,26 +20,26 @@ Do not return anything except markdown. ` type NoteAgent struct { - client AgentClient + client client.AgentClient noteModel models.NoteModel } func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { - request := AgentRequestBody{ + request := client.AgentRequestBody{ Model: "pixtral-12b-2409", Temperature: 0.3, - ResponseFormat: ResponseFormat{ + ResponseFormat: client.ResponseFormat{ Type: "text", }, + Chat: &client.Chat{ + Messages: make([]client.ChatMessage, 0), + }, } - err := request.AddSystem(noteAgentPrompt) - if err != nil { - return err - } + request.Chat.AddSystem(noteAgentPrompt) + request.Chat.AddImage(imageName, imageData) - request.AddImage(imageName, imageData) resp, err := agent.client.Request(&request) if err != nil { return err @@ -65,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s } func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) { - client, err := CreateAgentClient(noteAgentPrompt) + client, err := client.CreateAgentClient() if err != nil { return NoteAgent{}, err } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go new file mode 100644 index 0000000..430699c --- /dev/null +++ b/backend/agents/orchestrator.go @@ -0,0 +1,167 @@ +package agents + +import ( + "encoding/json" + "errors" + "fmt" + "screenmark/screenmark/agents/client" + + "github.com/google/uuid" +) + +const orchestratorPrompt = ` +You are an Orchestrator for various AI agents. + +The user will send you images and you have to determine which agents you have to call, in order to best help the user. + +You might decide no agent needs to be called. + +The agents are available as tool calls. + +Agents available: + +eventLocationAgent + +Use it when you think the image contains an event or a location of any sort. This can be an event page, a map, an address or a date. + +noteAgent + +Use it when there is text on the screen. Any text, always use this. Use me! + +defaultAgent + +When none of the above apply. + +Always call agents in parallel if you need to call more than 1. +` + +const MY_TOOLS = ` +[ + { + "type": "function", + "function": { + "name": "eventLocationAgent", + "description": "Uses the event location agent", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "noteAgent", + "description": "Uses the note agent", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "defaultAgent", + "description": "Used when you dont think its a good idea to call other agents", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + } +]` + +type OrchestratorAgent struct { + client client.AgentClient +} + +type Status struct { + Ok bool `json:"ok"` +} + +// TODO: the primary function of the agent could be extracted outwards. +// This is basically the same function as we have in the `event_location_agent.go` +func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { + toolChoice := "any" + + var tools any + err := json.Unmarshal([]byte(MY_TOOLS), &tools) + if err != nil { + return err + } + + request := client.AgentRequestBody{ + Model: "pixtral-12b-2409", + Temperature: 0.3, + ResponseFormat: client.ResponseFormat{ + Type: "text", + }, + ToolChoice: &toolChoice, + Tools: &tools, + + EndToolCall: "defaultAgent", + + Chat: &client.Chat{ + Messages: make([]client.ChatMessage, 0), + }, + } + + request.Chat.AddSystem(orchestratorPrompt) + request.Chat.AddImage(imageName, imageData) + + res, err := agent.client.Request(&request) + if err != nil { + return err + } + + fmt.Println(res) + + toolHandlerInfo := client.ToolHandlerInfo{ + ImageId: imageId, + UserId: userId, + } + + return agent.client.ToolLoop(toolHandlerInfo, &request) +} + +func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { + agent, err := client.CreateAgentClient() + if err != nil { + return OrchestratorAgent{}, err + } + + agent.ToolHandler.AddTool("eventLocationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { + // We need a way to keep track of this async? + // Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish. + + go eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData) + + return Status{ + Ok: true, + }, nil + }) + + agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { + go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData) + + return Status{ + Ok: true, + }, nil + }) + + agent.ToolHandler.AddTool("defaultAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { + // To nothing + + return Status{ + Ok: true, + }, errors.New("Finished! Kinda bad return type but...") + }) + + return OrchestratorAgent{ + client: agent, + }, nil +} diff --git a/backend/agents/tools_handler.go b/backend/agents/tools_handler.go deleted file mode 100644 index 5981d50..0000000 --- a/backend/agents/tools_handler.go +++ /dev/null @@ -1,26 +0,0 @@ -package agents - -import "github.com/google/uuid" - -type ToolHandlerInfo struct { - userId uuid.UUID - imageId uuid.UUID -} - -type ToolHandler[TArgs any, TResp any] struct { - FunctionName string - Parse func(args string) (TArgs, error) - Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error) -} - -type ToolHandlerInterface interface { - GetFunctionName() string -} - -func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string { - return handler.FunctionName -} - -type ToolsHandlers struct { - Handlers map[string]ToolHandlerInterface -} diff --git a/backend/go.mod b/backend/go.mod index fcdc38a..df64294 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -10,6 +10,6 @@ require ( github.com/joho/godotenv v1.5.1 // indirect github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.9.0 // indirect + github.com/stretchr/testify v1.10.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index ecd9db0..b982fe0 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -14,6 +14,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/backend/main.go b/backend/main.go index 60d8079..1f395e0 100644 --- a/backend/main.go +++ b/backend/main.go @@ -13,6 +13,7 @@ import ( "path/filepath" "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/agents" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "time" @@ -24,41 +25,13 @@ import ( ) type TestAiClient struct { - ImageInfo agents.ImageInfo + ImageInfo client.ImageMessageContent } -func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) { +func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (client.ImageMessageContent, error) { return client.ImageInfo, nil } -func GetAiClient() (agents.AiClient, error) { - mode := os.Getenv("MODE") - if mode == "TESTING" { - address := "10 Downing Street" - description := "Cheese and Crackers" - - return TestAiClient{ - ImageInfo: agents.ImageInfo{ - Tags: []string{"tag"}, - Links: []string{"links"}, - Text: []string{"text"}, - Locations: []model.Locations{{ - ID: uuid.Nil, - Name: "London", - Address: &address, - }}, - Events: []model.Events{{ - ID: uuid.Nil, - Name: "Party", - Description: &description, - }}, - }, - }, nil - } - - return agents.CreateAgentClient(agents.PROMPT) -} - func main() { err := godotenv.Load() if err != nil { @@ -74,9 +47,6 @@ func main() { } imageModel := models.NewImageModel(db) - linkModel := models.NewLinkModel(db) - tagModel := models.NewTagModel(db) - textModel := models.NewTextModel(db) locationModel := models.NewLocationModel(db) eventModel := models.NewEventModel(db) userModel := models.NewUserModel(db) @@ -105,11 +75,6 @@ func main() { ctx := context.Background() go func() { - openAiClient, err := GetAiClient() - if err != nil { - panic(err) - } - locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel) if err != nil { panic(err) @@ -127,51 +92,21 @@ func main() { return } - userImage, err := imageModel.FinishProcessing(ctx, image.ID) + _, err = imageModel.FinishProcessing(ctx, image.ID) if err != nil { log.Println("Failed to FinishProcessing") log.Println(err) return } - // TODO: this can very much be parallel - - log.Println("Calling locationAgent!") - err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) - log.Println(err) - - log.Println("Calling noteAgent!") - err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) - log.Println(err) - - return - - imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image) + orchestrator, err := agents.NewOrchestratorAgent(locationAgent, noteAgent, image.Image.ImageName, image.Image.Image) if err != nil { - log.Println("Failed to GetImageInfo") - log.Println(err) - return + panic(err) } - err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags) + err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) if err != nil { - log.Println("Failed to save tags") - log.Println(err) - return - } - - err = linkModel.Save(ctx, userImage.ImageID, imageInfo.Links) - if err != nil { - log.Println("Failed to save links") - log.Println(err) - return - } - - err = textModel.Save(ctx, userImage.ImageID, imageInfo.Text) - if err != nil { - log.Println("Failed to save text") - log.Println(err) - return + fmt.Println(err) } }() }