diff --git a/backend/agents/agent.go b/backend/agents/agent.go deleted file mode 100644 index 30aa9c1..0000000 --- a/backend/agents/agent.go +++ /dev/null @@ -1,481 +0,0 @@ -package agents - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "os" - "path/filepath" - "screenmark/screenmark/.gen/haystack/haystack/model" -) - -type ImageInfo struct { - Tags []string `json:"tags"` - Text []string `json:"text"` - Links []string `json:"links"` - - Locations []model.Locations `json:"locations"` - Events []model.Events `json:"events"` -} - -type ResponseFormat struct { - Type string `json:"type"` - JsonSchema any `json:"json_schema"` -} - -type AgentRequestBody struct { - Model string `json:"model"` - Temperature float64 `json:"temperature"` - ResponseFormat ResponseFormat `json:"response_format"` - - Tools *any `json:"tools,omitempty"` - ToolChoice *string `json:"tool_choice,omitempty"` - - // ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` - - AgentMessages -} - -type AgentMessages struct { - Messages []AgentMessage `json:"messages"` -} - -type AgentMessage interface { - MessageToJson() ([]byte, error) -} - -type AgentTextMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCallId string `json:"tool_call_id,omitempty"` - Name string `json:"name,omitempty"` -} - -func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { - // TODO: Validate the `Role`. - return json.Marshal(textContent) -} - -type AgentAssistantToolCall struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) { - return json.Marshal(toolCall) -} - -type AgentArrayMessage struct { - Role string `json:"role"` - Content []AgentContent `json:"content"` -} - -func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { - return json.Marshal(arrayContent) -} - -func (content *AgentMessages) AddText(message AgentTextMessage) { - content.Messages = append(content.Messages, message) -} - -func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) { - content.Messages = append(content.Messages, toolCall) -} - -func (content *AgentMessages) AddImage(imageName string, image []byte) error { - extension := filepath.Ext(imageName) - if len(extension) == 0 { - // TODO: could also validate for image types we support. - return errors.New("Image does not have extension") - } - - extension = extension[1:] - - encodedString := base64.StdEncoding.EncodeToString(image) - - arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)} - arrayMessage.Content[0] = AgentImage{ - ImageType: IMAGE_TYPE, - ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), - } - - content.Messages = append(content.Messages, arrayMessage) - - return nil -} - -func (content *AgentMessages) AddSystem(prompt string) error { - if len(content.Messages) != 0 { - return errors.New("You can only add a system prompt at the beginning") - } - - content.Messages = append(content.Messages, AgentTextMessage{ - Role: ROLE_SYSTEM, - Content: prompt, - }) - - return nil -} - -type AgentContent interface { - ToJson() ([]byte, error) -} - -type ImageUrl struct { - Url string `json:"url"` -} - -type AgentImage struct { - ImageType string `json:"type"` - ImageUrl string `json:"image_url"` -} - -func (imageMessage AgentImage) ToJson() ([]byte, error) { - imageMessage.ImageType = IMAGE_TYPE - return json.Marshal(imageMessage) -} - -type AiClient interface { - GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) -} - -type AgentClient struct { - url string - apiKey string - systemPrompt string - responseFormat string - - Do func(req *http.Request) (*http.Response, error) -} - -// func (client AgentClient) Do(req *http.Request) () { -// httpClient := http.Client{} -// return httpClient.Do(req) -// } - -const OPENAI_API_KEY = "OPENAI_API_KEY" -const ROLE_USER = "user" -const ROLE_SYSTEM = "system" -const IMAGE_TYPE = "image_url" - -// TODO: extract to text file probably -const PROMPT = ` -You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text -that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags. -Be sure to extract every link (URL) that you find. -Use generic tags. -` - -const RESPONSE_FORMAT = ` -{ - "name": "image_info", - "strict": true, - "schema": { - "type": "object", - "title": "image", - "required": ["tags", "text", "links"], - "additionalProperties": false, - "properties": { - "tags": { - "type": "array", - "title": "tags", - "description": "A list of tags you think the image is relevant to.", - "items": { - "type": "string" - } - }, - "text": { - "type": "array", - "title": "text", - "description": "A list of sentences the image contains.", - "items": { - "type": "string" - } - }, - "links": { - "type": "array", - "title": "links", - "description": "A list of all the links you can find in the image.", - "items": { - "type": "string" - } - }, - "locations": { - "title": "locations", - "type": "array", - "description": "A list of locations you can find on the image, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "title": "name", - "type": "string" - }, - "coordinates": { - "title": "coordinates", - "type": "string" - }, - "address": { - "title": "address", - "type": "string" - }, - "description": { - "title": "description", - "type": "string" - } - } - } - }, - "events": { - "title": "events", - "type": "array", - "description": "A list of events you find on the image, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "type": "string", - "title": "name" - }, - "locations": { - "title": "locations", - "type": "array", - "description": "A list of locations on this event, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "title": "name", - "type": "string" - }, - "coordinates": { - "title": "coordinates", - "type": "string" - }, - "address": { - "title": "address", - "type": "string" - }, - "description": { - "title": "description", - "type": "string" - } - } - } - } - } - } - } - } - } -} -` - -func CreateAgentClient(prompt string) (AgentClient, error) { - apiKey := os.Getenv(OPENAI_API_KEY) - - if len(apiKey) == 0 { - return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") - } - - return AgentClient{ - apiKey: apiKey, - url: "https://api.mistral.ai/v1/chat/completions", - systemPrompt: prompt, - Do: func(req *http.Request) (*http.Response, error) { - client := &http.Client{} - return client.Do(req) - }, - }, nil -} - -func (client AgentClient) getRequest(body []byte) (*http.Request, error) { - req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body)) - if err != nil { - return req, err - } - - req.Header.Add("Authorization", "Bearer "+client.apiKey) - req.Header.Add("Content-Type", "application/json") - - return req, nil -} - -func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) { - request := AgentRequestBody{ - Model: model, - Temperature: temperature, - ResponseFormat: ResponseFormat{ - Type: "json_schema", - JsonSchema: jsonSchema, - }, - } - - // TODO: Add build pattern here that deals with errors in some internal state? - // I want a monad!!! - err := request.AddSystem(prompt) - if err != nil { - return request, err - } - - log.Println(request) - - err = request.AddImage(imageName, imageData) - if err != nil { - return request, err - } - - request.Tools = nil - - return request, nil -} - -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type ToolCall struct { - Index int `json:"index"` - Id string `json:"id"` - Function FunctionCall `json:"function"` -} - -type ResponseChoiceMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -type ResponseChoice struct { - Index int `json:"index"` - Message ResponseChoiceMessage `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type AgentResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Choices []ResponseChoice `json:"choices"` - Created int `json:"created"` -} - -// TODO: add usage parsing -func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) { - response := AgentResponse{} - - err := json.Unmarshal(jsonResponse, &response) - if err != nil { - return ImageInfo{}, err - } - - if len(response.Choices) != 1 { - log.Println(string(jsonResponse)) - return ImageInfo{}, errors.New("Expected exactly one choice.") - } - - imageInfo := ImageInfo{} - err = json.Unmarshal([]byte(response.Choices[0].Message.Content), &imageInfo) - if err != nil { - return ImageInfo{}, errors.New("Could not parse content into image type.") - } - - return imageInfo, nil -} - -func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) { - jsonAiRequest, err := json.Marshal(request) - if err != nil { - return AgentResponse{}, err - } - - httpRequest, err := client.getRequest(jsonAiRequest) - if err != nil { - return AgentResponse{}, err - } - - resp, err := client.Do(httpRequest) - if err != nil { - return AgentResponse{}, err - } - - response, err := io.ReadAll(resp.Body) - if err != nil { - return AgentResponse{}, err - } - - agentResponse := AgentResponse{} - err = json.Unmarshal(response, &agentResponse) - - if err != nil { - return AgentResponse{}, err - } - - log.Println(string(response)) - - toolCalls := agentResponse.Choices[0].Message.ToolCalls - if len(toolCalls) > 0 { - // Should for sure be more flexible. - request.AddToolCall(AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: toolCalls, - }) - } - - return agentResponse, nil -} - -func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { - aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData) - if err != nil { - return ImageInfo{}, err - } - - var jsonSchema any - err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema) - if err != nil { - return ImageInfo{}, err - } - - aiRequest.ResponseFormat = ResponseFormat{ - Type: "json_schema", - JsonSchema: jsonSchema, - } - - jsonAiRequest, err := json.Marshal(aiRequest) - if err != nil { - return ImageInfo{}, err - } - - request, err := client.getRequest(jsonAiRequest) - if err != nil { - return ImageInfo{}, err - } - - resp, err := client.Do(request) - if err != nil { - return ImageInfo{}, err - } - - response, err := io.ReadAll(resp.Body) - if err != nil { - return ImageInfo{}, err - } - - log.Println(string(response)) - - return parseAgentResponse(response) -} diff --git a/backend/agents/agent_test.go b/backend/agents/agent_test.go deleted file mode 100644 index 770e0ab..0000000 --- a/backend/agents/agent_test.go +++ /dev/null @@ -1,213 +0,0 @@ -package agents - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "testing" -) - -func TestMessageBuilder(t *testing.T) { - content := AgentMessages{} - - err := content.AddSystem("Some prompt") - - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(content.Messages) != 1 { - t.Logf("Expected length 1, got %d.\n", len(content.Messages)) - t.FailNow() - } -} - -func TestMessageBuilderImage(t *testing.T) { - content := AgentMessages{} - - prompt := "some prompt" - imageTitle := "image.png" - data := []byte("some data") - - content.AddSystem(prompt) - content.AddImage(imageTitle, data) - - if len(content.Messages) != 2 { - t.Logf("Expected length 2, got %d.\n", len(content.Messages)) - t.FailNow() - } - - promptMessage, ok := content.Messages[0].(AgentTextMessage) - if !ok { - t.Logf("Expected text content message, got %T\n", content.Messages[0]) - t.FailNow() - } - - if promptMessage.Role != ROLE_SYSTEM { - t.Log("Prompt message role is incorrect.") - t.FailNow() - } - - if promptMessage.Content != prompt { - t.Log("Prompt message content is incorrect.") - t.FailNow() - } - - arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage) - if !ok { - t.Logf("Expected text content message, got %T\n", content.Messages[1]) - t.FailNow() - } - - if arrayContentMessage.Role != ROLE_USER { - t.Log("Array content message role is incorrect.") - t.FailNow() - } - - if len(arrayContentMessage.Content) != 1 { - t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content)) - t.FailNow() - } - - imageContent, ok := arrayContentMessage.Content[0].(AgentImage) - if !ok { - t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0]) - t.FailNow() - } - - base64data := base64.StdEncoding.EncodeToString(data) - url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data) - - if imageContent.ImageUrl != url { - t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl) - t.FailNow() - } -} - -func TestFullImageRequest(t *testing.T) { - request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data")) - if err != nil { - t.Log(request) - t.FailNow() - } - - jsonData, err := json.Marshal(request) - if err != nil { - t.Log(err) - t.FailNow() - } - - expectedJson := `{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":"data:image/png;base64,c29tZSBkYXRh"}}]}]}` - - if string(jsonData) != expectedJson { - t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData)) - t.FailNow() - } -} - -func TestResponse(t *testing.T) { - testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}` - buffer := bytes.NewReader([]byte(testResponse)) - - body := io.NopCloser(buffer) - - client := AgentClient{ - url: "http://localhost:1234", - apiKey: "some-key", - Do: func(_req *http.Request) (*http.Response, error) { - return &http.Response{Body: body}, nil - }, - } - - info, err := client.GetImageInfo("image.png", []byte("some data")) - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 { - t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links)) - t.FailNow() - } - - if info.Tags[0] != "tag1" { - t.Log("0th tag is wrong.") - t.FailNow() - } - - if info.Tags[1] != "tag2" { - t.Log("1th tag is wrong.") - t.FailNow() - } - - if info.Text[0] != "text1" { - t.Log("0th text is wrong.") - t.FailNow() - } -} - -func TestResponseParsing(t *testing.T) { - response := `{ - "id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ", - "object": "chat.completion", - "created": 1740422508, - "model": "gpt-4o-mini-2024-07-18", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}", - "refusal": null - }, - "logprobs": null, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 775, - "completion_tokens": 33, - "total_tokens": 808, - "prompt_tokens_details": { - "cached_tokens": 0, - "audio_tokens": 0 - }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "audio_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "service_tier": "default", - "system_fingerprint": "fp_7fcd609668" -}` - - imageParsed, err := parseAgentResponse([]byte(response)) - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" { - t.Log(imageParsed) - t.Log("Should have one link called 'link'.") - t.FailNow() - } - - if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" { - t.Log(imageParsed) - t.Log("Should have one tag called 'tag'.") - t.FailNow() - } - - if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" { - t.Log(imageParsed) - t.Log("Should have one text called 'text'.") - t.FailNow() - } -} diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go new file mode 100644 index 0000000..c137b8a --- /dev/null +++ b/backend/agents/client/client.go @@ -0,0 +1,265 @@ +package client + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "screenmark/screenmark/.gen/haystack/haystack/model" +) + +type ImageInfo struct { + Tags []string `json:"tags"` + Text []string `json:"text"` + Links []string `json:"links"` + + Locations []model.Locations `json:"locations"` + Events []model.Events `json:"events"` +} + +type ResponseFormat struct { + Type string `json:"type"` + JsonSchema any `json:"json_schema"` +} + +type AgentRequestBody struct { + Model string `json:"model"` + Temperature float64 `json:"temperature"` + ResponseFormat ResponseFormat `json:"response_format"` + + Tools *any `json:"tools,omitempty"` + ToolChoice *string `json:"tool_choice,omitempty"` + + // ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + + AgentMessages +} + +type AgentMessages struct { + Messages []AgentMessage `json:"messages"` +} + +type AgentMessage interface { + MessageToJson() ([]byte, error) +} + +type AgentTextMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCallId string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` +} + +func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { + // TODO: Validate the `Role`. + return json.Marshal(textContent) +} + +type ToolCall struct { + Index int `json:"index"` + Id string `json:"id"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type AgentAssistantToolCall struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` +} + +func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) { + return json.Marshal(toolCall) +} + +type AgentArrayMessage struct { + Role string `json:"role"` + Content []AgentContent `json:"content"` +} + +func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { + return json.Marshal(arrayContent) +} + +func (content *AgentMessages) AddText(message AgentTextMessage) { + content.Messages = append(content.Messages, message) +} + +func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) { + content.Messages = append(content.Messages, toolCall) +} + +func (content *AgentMessages) AddImage(imageName string, image []byte) error { + extension := filepath.Ext(imageName) + if len(extension) == 0 { + // TODO: could also validate for image types we support. + return errors.New("Image does not have extension") + } + + extension = extension[1:] + + encodedString := base64.StdEncoding.EncodeToString(image) + + arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)} + arrayMessage.Content[0] = AgentImage{ + ImageType: IMAGE_TYPE, + ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), + } + + content.Messages = append(content.Messages, arrayMessage) + + return nil +} + +func (content *AgentMessages) AddSystem(prompt string) error { + if len(content.Messages) != 0 { + return errors.New("You can only add a system prompt at the beginning") + } + + content.Messages = append(content.Messages, AgentTextMessage{ + Role: ROLE_SYSTEM, + Content: prompt, + }) + + return nil +} + +type AgentContent interface { + ToJson() ([]byte, error) +} + +type ImageUrl struct { + Url string `json:"url"` +} + +type AgentImage struct { + ImageType string `json:"type"` + ImageUrl string `json:"image_url"` +} + +func (imageMessage AgentImage) ToJson() ([]byte, error) { + imageMessage.ImageType = IMAGE_TYPE + return json.Marshal(imageMessage) +} + +type AiClient interface { + GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) +} + +type ResponseChoiceMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` +} + +type ResponseChoice struct { + Index int `json:"index"` + Message ResponseChoiceMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type AgentResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Choices []ResponseChoice `json:"choices"` + Created int `json:"created"` +} + +type AgentClient struct { + url string + apiKey string + systemPrompt string + responseFormat string + + ToolHandler ToolsHandlers + + Do func(req *http.Request) (*http.Response, error) +} + +const OPENAI_API_KEY = "OPENAI_API_KEY" +const ROLE_USER = "user" +const ROLE_SYSTEM = "system" +const IMAGE_TYPE = "image_url" + +func CreateAgentClient(prompt string) (AgentClient, error) { + apiKey := os.Getenv(OPENAI_API_KEY) + + if len(apiKey) == 0 { + return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") + } + + return AgentClient{ + apiKey: apiKey, + url: "https://api.mistral.ai/v1/chat/completions", + systemPrompt: prompt, + Do: func(req *http.Request) (*http.Response, error) { + client := &http.Client{} + return client.Do(req) + }, + }, nil +} + +func (client AgentClient) getRequest(body []byte) (*http.Request, error) { + req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body)) + if err != nil { + return req, err + } + + req.Header.Add("Authorization", "Bearer "+client.apiKey) + req.Header.Add("Content-Type", "application/json") + + return req, nil +} + +func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) { + jsonAiRequest, err := json.Marshal(request) + if err != nil { + return AgentResponse{}, err + } + + httpRequest, err := client.getRequest(jsonAiRequest) + if err != nil { + return AgentResponse{}, err + } + + resp, err := client.Do(httpRequest) + if err != nil { + return AgentResponse{}, err + } + + response, err := io.ReadAll(resp.Body) + if err != nil { + return AgentResponse{}, err + } + + agentResponse := AgentResponse{} + err = json.Unmarshal(response, &agentResponse) + + if err != nil { + return AgentResponse{}, err + } + + log.Println(string(response)) + + toolCalls := agentResponse.Choices[0].Message.ToolCalls + if len(toolCalls) > 0 { + // Should for sure be more flexible. + request.AddToolCall(AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: toolCalls, + }) + } + + return agentResponse, nil +} diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go new file mode 100644 index 0000000..b9c8f42 --- /dev/null +++ b/backend/agents/client/tools.go @@ -0,0 +1,79 @@ +package client + +import ( + "encoding/json" + "errors" + "log" + + "github.com/google/uuid" +) + +type ToolHandlerInfo struct { + UserId uuid.UUID + ImageId uuid.UUID +} + +type ToolHandler struct { + Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error) +} + +type ToolsHandlers struct { + handlers *map[string]ToolHandler +} + +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) { + agentMessage := request.Messages[len(request.Messages)-1] + + toolCall, ok := agentMessage.(AgentAssistantToolCall) + if !ok { + return "", errors.New("Latest message was not a tool call.") + } + + fnName := toolCall.ToolCalls[0].Function.Name + arguments := toolCall.ToolCalls[0].Function.Arguments + + fnHandler, exists := (*handler.handlers)[fnName] + if !exists { + return "", errors.New("Could not find tool with this name.") + } + + log.Printf("Calling: %s\n", fnName) + res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0]) + if err != nil { + return "", err + } + + request.AddText(AgentTextMessage{ + Role: "tool", + Name: "createLocation", + Content: res, + ToolCallId: toolCall.ToolCalls[0].Id, + }) + + return res, nil +} + +func (handler ToolsHandlers) AddTool(name string, getArgs func() any, fn func(info ToolHandlerInfo, args any, call ToolCall) (any, error)) { + (*handler.handlers)["createLocation"] = ToolHandler{ + Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) { + argsStruct := getArgs() + + err := json.Unmarshal([]byte(args), &argsStruct) + if err != nil { + return "", err + } + + res, err := fn(info, argsStruct, call) + if err != nil { + return "", err + } + + marshalledRes, err := json.Marshal(res) + if err != nil { + return "", err + } + + return string(marshalledRes), nil + }, + } +} diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 72757ad..d131d19 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -5,8 +5,8 @@ import ( "encoding/json" "errors" "log" - "reflect" "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "time" @@ -116,13 +116,13 @@ const TOOLS = ` ` type EventLocationAgent struct { - client AgentClient + client client.AgentClient eventModel models.EventModel locationModel models.LocationModel contactModel models.ContactModel - toolHandler ToolsHandlers + toolHandler client.ToolsHandlers } type ListLocationArguments struct{} @@ -158,12 +158,12 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID toolChoice := "any" - request := AgentRequestBody{ + request := client.AgentRequestBody{ Tools: &tools, ToolChoice: &toolChoice, Model: "pixtral-12b-2409", Temperature: 0.3, - ResponseFormat: ResponseFormat{ + ResponseFormat: client.ResponseFormat{ Type: "text", }, } @@ -180,11 +180,12 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID return err } - toolHandlerInfo := ToolHandlerInfo{ - imageId: imageId, - userId: userId, + toolHandlerInfo := client.ToolHandlerInfo{ + ImageId: imageId, + UserId: userId, } + // TODO: this should go into a loop with toolHandler Handle function and not be here bruh. _, err = agent.toolHandler.Handle(toolHandlerInfo, &request) for err == nil { @@ -208,107 +209,41 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID return err } -// TODO: extract this into a more general tool handler package. -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) { - agentMessage := request.Messages[len(request.Messages)-1] - - toolCall, ok := agentMessage.(AgentAssistantToolCall) - if !ok { - return "", errors.New("Latest message was not a tool call.") - } - - fnName := toolCall.ToolCalls[0].Function.Name - arguments := toolCall.ToolCalls[0].Function.Arguments - - if fnName == "finish" { - return "", errors.New("This is the end! Maybe we just return a boolean.") - } - - fn, exists := handler.Handlers[fnName] - if !exists { - return "", errors.New("Could not find tool with this name.") - } - - // holy jesus what the fuck. - parseMethod := reflect.ValueOf(fn).Field(1) - if !parseMethod.IsValid() { - return "", errors.New("Parse method not found") - } - - parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)}) - if !parsedArgs[1].IsNil() { - return "", parsedArgs[1].Interface().(error) - } - - log.Printf("Calling: %s\n", fnName) - - fnMethod := reflect.ValueOf(fn).Field(2) - if !fnMethod.IsValid() { - return "", errors.New("Fn method not found") - } - - response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])}) - if !response[1].IsNil() { - return "", response[1].Interface().(error) - } - - stringResponse, err := json.Marshal(response[0].Interface()) - if err != nil { - return "", err - } - - request.AddText(AgentTextMessage{ - Role: "tool", - Name: "createLocation", - Content: string(stringResponse), - ToolCallId: toolCall.ToolCalls[0].Id, - }) - - return string(stringResponse), nil -} - func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { - client, err := CreateAgentClient(eventLocationPrompt) + agentClient, err := client.CreateAgentClient(eventLocationPrompt) if err != nil { return EventLocationAgent{}, err } agent := EventLocationAgent{ - client: client, + client: agentClient, locationModel: locationModel, eventModel: eventModel, contactModel: contactModel, } - toolHandler := ToolsHandlers{ - Handlers: make(map[string]ToolHandlerInterface), - } - - toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{ - FunctionName: "listLocations", - Parse: func(stringArgs string) (ListLocationArguments, error) { - args := ListLocationArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) - - return args, err + agentClient.ToolHandler.AddTool("listLocations", + func() any { + return ListLocationArguments{} }, - Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) { - return agent.locationModel.List(context.Background(), info.userId) + func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { + return agent.locationModel.List(context.Background(), info.UserId) }, - } + ) - toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{ - FunctionName: "createLocation", - Parse: func(stringArgs string) (CreateLocationArguments, error) { - args := CreateLocationArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) - - return args, err + agentClient.ToolHandler.AddTool("createLocation", + func() any { + return CreateLocationArguments{} }, - Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) { + func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { + args, ok := _args.(CreateLocationArguments) + if !ok { + return _args, errors.New("Type error, arguments are not of the correct struct type") + } + ctx := context.Background() - location, err := agent.locationModel.Save(ctx, info.userId, model.Locations{ + location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{ Name: args.Name, Address: args.Address, }) @@ -317,21 +252,22 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return location, err } - _, err = agent.locationModel.SaveToImage(ctx, info.imageId, location.ID) + _, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID) return location, err }, - } + ) - toolHandler.Handlers["createEvent"] = ToolHandler[CreateEventArguments, model.Events]{ - FunctionName: "createEvent", - Parse: func(stringArgs string) (CreateEventArguments, error) { - args := CreateEventArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) - - return args, err + agentClient.ToolHandler.AddTool("createEvent", + func() any { + return CreateEventArguments{} }, - Fn: func(info ToolHandlerInfo, args CreateEventArguments, call ToolCall) (model.Events, error) { + func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { + args, ok := _args.(CreateEventArguments) + if !ok { + return _args, errors.New("Type error, arguments are not of the correct struct type") + } + ctx := context.Background() layout := "2006-01-02T15:04:05Z" @@ -346,7 +282,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return model.Events{}, err } - event, err := agent.eventModel.Save(ctx, info.userId, model.Events{ + event, err := agent.eventModel.Save(ctx, info.UserId, model.Events{ Name: args.Name, StartDateTime: &startTime, EndDateTime: &endTime, @@ -356,7 +292,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return event, err } - organizer, err := agent.contactModel.Save(ctx, info.userId, model.Contacts{ + organizer, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{ Name: args.Name, }) @@ -364,12 +300,12 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return event, err } - _, err = agent.eventModel.SaveToImage(ctx, info.imageId, event.ID) + _, err = agent.eventModel.SaveToImage(ctx, info.ImageId, event.ID) if err != nil { return event, err } - _, err = agent.contactModel.SaveToImage(ctx, info.imageId, organizer.ID) + _, err = agent.contactModel.SaveToImage(ctx, info.ImageId, organizer.ID) if err != nil { return event, err } @@ -386,9 +322,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return agent.eventModel.UpdateOrganizer(ctx, event.ID, organizer.ID) }, - } - - agent.toolHandler = toolHandler + ) return agent, nil } diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 1616ea6..3d62c94 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -3,6 +3,7 @@ package agents import ( "context" "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "github.com/google/uuid" @@ -19,16 +20,16 @@ Do not return anything except markdown. ` type NoteAgent struct { - client AgentClient + client client.AgentClient noteModel models.NoteModel } func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { - request := AgentRequestBody{ + request := client.AgentRequestBody{ Model: "pixtral-12b-2409", Temperature: 0.3, - ResponseFormat: ResponseFormat{ + ResponseFormat: client.ResponseFormat{ Type: "text", }, } @@ -65,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s } func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) { - client, err := CreateAgentClient(noteAgentPrompt) + client, err := client.CreateAgentClient(noteAgentPrompt) if err != nil { return NoteAgent{}, err } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index a3d4934..cff08b3 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -3,6 +3,7 @@ package agents import ( "encoding/json" "fmt" + "screenmark/screenmark/agents/client" "github.com/google/uuid" ) @@ -74,7 +75,7 @@ const MY_TOOLS = ` ]` type OrchestratorAgent struct { - client AgentClient + client client.AgentClient } func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { @@ -86,10 +87,10 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, return err } - request := AgentRequestBody{ + request := client.AgentRequestBody{ Model: "pixtral-12b-2409", Temperature: 0.3, - ResponseFormat: ResponseFormat{ + ResponseFormat: client.ResponseFormat{ Type: "text", }, ToolChoice: &toolChoice, @@ -113,7 +114,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, } func NewOrchestratorAgent() (OrchestratorAgent, error) { - agent, err := CreateAgentClient(orchestratorPrompt) + agent, err := client.CreateAgentClient(orchestratorPrompt) if err != nil { return OrchestratorAgent{}, err } diff --git a/backend/agents/tools_handler.go b/backend/agents/tools_handler.go deleted file mode 100644 index 5981d50..0000000 --- a/backend/agents/tools_handler.go +++ /dev/null @@ -1,26 +0,0 @@ -package agents - -import "github.com/google/uuid" - -type ToolHandlerInfo struct { - userId uuid.UUID - imageId uuid.UUID -} - -type ToolHandler[TArgs any, TResp any] struct { - FunctionName string - Parse func(args string) (TArgs, error) - Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error) -} - -type ToolHandlerInterface interface { - GetFunctionName() string -} - -func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string { - return handler.FunctionName -} - -type ToolsHandlers struct { - Handlers map[string]ToolHandlerInterface -} diff --git a/backend/main.go b/backend/main.go index f1b0530..dbc80df 100644 --- a/backend/main.go +++ b/backend/main.go @@ -13,6 +13,7 @@ import ( "path/filepath" "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/agents" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "time" @@ -24,41 +25,13 @@ import ( ) type TestAiClient struct { - ImageInfo agents.ImageInfo + ImageInfo client.ImageInfo } -func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) { +func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (client.ImageInfo, error) { return client.ImageInfo, nil } -func GetAiClient() (agents.AiClient, error) { - mode := os.Getenv("MODE") - if mode == "TESTING" { - address := "10 Downing Street" - description := "Cheese and Crackers" - - return TestAiClient{ - ImageInfo: agents.ImageInfo{ - Tags: []string{"tag"}, - Links: []string{"links"}, - Text: []string{"text"}, - Locations: []model.Locations{{ - ID: uuid.Nil, - Name: "London", - Address: &address, - }}, - Events: []model.Events{{ - ID: uuid.Nil, - Name: "Party", - Description: &description, - }}, - }, - }, nil - } - - return agents.CreateAgentClient(agents.PROMPT) -} - func main() { err := godotenv.Load() if err != nil { @@ -74,9 +47,6 @@ func main() { } imageModel := models.NewImageModel(db) - linkModel := models.NewLinkModel(db) - tagModel := models.NewTagModel(db) - textModel := models.NewTextModel(db) locationModel := models.NewLocationModel(db) eventModel := models.NewEventModel(db) userModel := models.NewUserModel(db) @@ -105,11 +75,6 @@ func main() { ctx := context.Background() go func() { - openAiClient, err := GetAiClient() - if err != nil { - panic(err) - } - locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel) if err != nil { panic(err) @@ -132,7 +97,7 @@ func main() { return } - userImage, err := imageModel.FinishProcessing(ctx, image.ID) + _, err = imageModel.FinishProcessing(ctx, image.ID) if err != nil { log.Println("Failed to FinishProcessing") log.Println(err) @@ -152,36 +117,6 @@ func main() { log.Println("Calling noteAgent!") err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) log.Println(err) - - return - - imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image) - if err != nil { - log.Println("Failed to GetImageInfo") - log.Println(err) - return - } - - err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags) - if err != nil { - log.Println("Failed to save tags") - log.Println(err) - return - } - - err = linkModel.Save(ctx, userImage.ImageID, imageInfo.Links) - if err != nil { - log.Println("Failed to save links") - log.Println(err) - return - } - - err = textModel.Save(ctx, userImage.ImageID, imageInfo.Text) - if err != nil { - log.Println("Failed to save text") - log.Println(err) - return - } }() } }