From 8a165c2042b7c8e82f62c39beef395555f65ce99 Mon Sep 17 00:00:00 2001 From: John Costa Date: Fri, 4 Apr 2025 20:40:31 +0100 Subject: [PATCH 01/17] wip(orchestrator): basic scaffolding for the agent --- backend/agents/agent.go | 2 + backend/agents/orchestrator.go | 124 +++++++++++++++++++++++++++++++++ backend/main.go | 9 +++ 3 files changed, 135 insertions(+) create mode 100644 backend/agents/orchestrator.go diff --git a/backend/agents/agent.go b/backend/agents/agent.go index 231037b..30aa9c1 100644 --- a/backend/agents/agent.go +++ b/backend/agents/agent.go @@ -36,6 +36,8 @@ type AgentRequestBody struct { Tools *any `json:"tools,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"` + // ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + AgentMessages } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go new file mode 100644 index 0000000..a3d4934 --- /dev/null +++ b/backend/agents/orchestrator.go @@ -0,0 +1,124 @@ +package agents + +import ( + "encoding/json" + "fmt" + + "github.com/google/uuid" +) + +const orchestratorPrompt = ` +You are an Orchestrator for various AI agents. + +The user will send you images and you have to determine which agents you have to call, in order to best help the user. + +You might decide no agent needs to be called. + +The agents are available as tool calls. + +Agents available: + +eventLocationAgent + +Use it when you think the image contains an event or a location of any sort. This can be an event page, a map, an address or a date. + +noteAgent + +Use it when there is text on the screen. Any text, always use this. Use me! + +defaultAgent + +When none of the above apply. + +Always call agents in parallel if you need to call more than 1. 
+` + +const MY_TOOLS = ` +[ + { + "type": "function", + "function": { + "name": "eventLocationAgent", + "description": "Uses the event location agent", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "noteAgent", + "description": "Uses the note agent", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "defaultAgent", + "description": "Used when you dont think its a good idea to call other agents", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + } +]` + +type OrchestratorAgent struct { + client AgentClient +} + +func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { + toolChoice := "any" + + var tools any + err := json.Unmarshal([]byte(MY_TOOLS), &tools) + if err != nil { + return err + } + + request := AgentRequestBody{ + Model: "pixtral-12b-2409", + Temperature: 0.3, + ResponseFormat: ResponseFormat{ + Type: "text", + }, + ToolChoice: &toolChoice, + Tools: &tools, + } + + err = request.AddSystem(orchestratorPrompt) + if err != nil { + return err + } + + request.AddImage(imageName, imageData) + resp, err := agent.client.Request(&request) + if err != nil { + return err + } + + fmt.Println(resp) + + return nil +} + +func NewOrchestratorAgent() (OrchestratorAgent, error) { + agent, err := CreateAgentClient(orchestratorPrompt) + if err != nil { + return OrchestratorAgent{}, err + } + + return OrchestratorAgent{ + client: agent, + }, nil +} diff --git a/backend/main.go b/backend/main.go index 60d8079..f1b0530 100644 --- a/backend/main.go +++ b/backend/main.go @@ -120,6 +120,11 @@ func main() { panic(err) } + orchestrator, err := agents.NewOrchestratorAgent() + if err != nil { + panic(err) + } + image, err := imageModel.GetToProcessWithData(ctx, imageId) if err != nil { 
log.Println("Failed to GetToProcessWithData") @@ -134,6 +139,10 @@ func main() { return } + orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) + + return + // TODO: this can very much be parallel log.Println("Calling locationAgent!") -- 2.47.2 From 71d4581110b89002847c3e6c9e56e4a027a7ec4b Mon Sep 17 00:00:00 2001 From: John Costa Date: Fri, 4 Apr 2025 22:03:46 +0100 Subject: [PATCH 02/17] refactor(ai-client): moving tool handling and client into seperate folders --- backend/agents/agent.go | 481 ------------------------- backend/agents/agent_test.go | 213 ----------- backend/agents/client/client.go | 265 ++++++++++++++ backend/agents/client/tools.go | 79 ++++ backend/agents/event_location_agent.go | 152 +++----- backend/agents/note_agent.go | 9 +- backend/agents/orchestrator.go | 9 +- backend/agents/tools_handler.go | 26 -- backend/main.go | 73 +--- 9 files changed, 401 insertions(+), 906 deletions(-) delete mode 100644 backend/agents/agent.go delete mode 100644 backend/agents/agent_test.go create mode 100644 backend/agents/client/client.go create mode 100644 backend/agents/client/tools.go delete mode 100644 backend/agents/tools_handler.go diff --git a/backend/agents/agent.go b/backend/agents/agent.go deleted file mode 100644 index 30aa9c1..0000000 --- a/backend/agents/agent.go +++ /dev/null @@ -1,481 +0,0 @@ -package agents - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "os" - "path/filepath" - "screenmark/screenmark/.gen/haystack/haystack/model" -) - -type ImageInfo struct { - Tags []string `json:"tags"` - Text []string `json:"text"` - Links []string `json:"links"` - - Locations []model.Locations `json:"locations"` - Events []model.Events `json:"events"` -} - -type ResponseFormat struct { - Type string `json:"type"` - JsonSchema any `json:"json_schema"` -} - -type AgentRequestBody struct { - Model string `json:"model"` - Temperature float64 
`json:"temperature"` - ResponseFormat ResponseFormat `json:"response_format"` - - Tools *any `json:"tools,omitempty"` - ToolChoice *string `json:"tool_choice,omitempty"` - - // ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` - - AgentMessages -} - -type AgentMessages struct { - Messages []AgentMessage `json:"messages"` -} - -type AgentMessage interface { - MessageToJson() ([]byte, error) -} - -type AgentTextMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCallId string `json:"tool_call_id,omitempty"` - Name string `json:"name,omitempty"` -} - -func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { - // TODO: Validate the `Role`. - return json.Marshal(textContent) -} - -type AgentAssistantToolCall struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) { - return json.Marshal(toolCall) -} - -type AgentArrayMessage struct { - Role string `json:"role"` - Content []AgentContent `json:"content"` -} - -func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { - return json.Marshal(arrayContent) -} - -func (content *AgentMessages) AddText(message AgentTextMessage) { - content.Messages = append(content.Messages, message) -} - -func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) { - content.Messages = append(content.Messages, toolCall) -} - -func (content *AgentMessages) AddImage(imageName string, image []byte) error { - extension := filepath.Ext(imageName) - if len(extension) == 0 { - // TODO: could also validate for image types we support. 
- return errors.New("Image does not have extension") - } - - extension = extension[1:] - - encodedString := base64.StdEncoding.EncodeToString(image) - - arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)} - arrayMessage.Content[0] = AgentImage{ - ImageType: IMAGE_TYPE, - ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), - } - - content.Messages = append(content.Messages, arrayMessage) - - return nil -} - -func (content *AgentMessages) AddSystem(prompt string) error { - if len(content.Messages) != 0 { - return errors.New("You can only add a system prompt at the beginning") - } - - content.Messages = append(content.Messages, AgentTextMessage{ - Role: ROLE_SYSTEM, - Content: prompt, - }) - - return nil -} - -type AgentContent interface { - ToJson() ([]byte, error) -} - -type ImageUrl struct { - Url string `json:"url"` -} - -type AgentImage struct { - ImageType string `json:"type"` - ImageUrl string `json:"image_url"` -} - -func (imageMessage AgentImage) ToJson() ([]byte, error) { - imageMessage.ImageType = IMAGE_TYPE - return json.Marshal(imageMessage) -} - -type AiClient interface { - GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) -} - -type AgentClient struct { - url string - apiKey string - systemPrompt string - responseFormat string - - Do func(req *http.Request) (*http.Response, error) -} - -// func (client AgentClient) Do(req *http.Request) () { -// httpClient := http.Client{} -// return httpClient.Do(req) -// } - -const OPENAI_API_KEY = "OPENAI_API_KEY" -const ROLE_USER = "user" -const ROLE_SYSTEM = "system" -const IMAGE_TYPE = "image_url" - -// TODO: extract to text file probably -const PROMPT = ` -You are an image information extractor. The user will provide you with screenshots and your job is to extract any relevant links and text -that the image might contain. You will also try your best to assign some tags to this image, avoid too many tags. 
-Be sure to extract every link (URL) that you find. -Use generic tags. -` - -const RESPONSE_FORMAT = ` -{ - "name": "image_info", - "strict": true, - "schema": { - "type": "object", - "title": "image", - "required": ["tags", "text", "links"], - "additionalProperties": false, - "properties": { - "tags": { - "type": "array", - "title": "tags", - "description": "A list of tags you think the image is relevant to.", - "items": { - "type": "string" - } - }, - "text": { - "type": "array", - "title": "text", - "description": "A list of sentences the image contains.", - "items": { - "type": "string" - } - }, - "links": { - "type": "array", - "title": "links", - "description": "A list of all the links you can find in the image.", - "items": { - "type": "string" - } - }, - "locations": { - "title": "locations", - "type": "array", - "description": "A list of locations you can find on the image, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "title": "name", - "type": "string" - }, - "coordinates": { - "title": "coordinates", - "type": "string" - }, - "address": { - "title": "address", - "type": "string" - }, - "description": { - "title": "description", - "type": "string" - } - } - } - }, - "events": { - "title": "events", - "type": "array", - "description": "A list of events you find on the image, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "type": "string", - "title": "name" - }, - "locations": { - "title": "locations", - "type": "array", - "description": "A list of locations on this event, if any", - "items": { - "type": "object", - "required": ["name"], - "additionalProperties": false, - "properties": { - "name": { - "title": "name", - "type": "string" - }, - "coordinates": { - "title": "coordinates", - "type": "string" - }, - "address": { - "title": "address", - "type": "string" - }, - "description": { - 
"title": "description", - "type": "string" - } - } - } - } - } - } - } - } - } -} -` - -func CreateAgentClient(prompt string) (AgentClient, error) { - apiKey := os.Getenv(OPENAI_API_KEY) - - if len(apiKey) == 0 { - return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") - } - - return AgentClient{ - apiKey: apiKey, - url: "https://api.mistral.ai/v1/chat/completions", - systemPrompt: prompt, - Do: func(req *http.Request) (*http.Response, error) { - client := &http.Client{} - return client.Do(req) - }, - }, nil -} - -func (client AgentClient) getRequest(body []byte) (*http.Request, error) { - req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body)) - if err != nil { - return req, err - } - - req.Header.Add("Authorization", "Bearer "+client.apiKey) - req.Header.Add("Content-Type", "application/json") - - return req, nil -} - -func getCompletionsForImage(model string, temperature float64, prompt string, imageName string, jsonSchema string, imageData []byte) (AgentRequestBody, error) { - request := AgentRequestBody{ - Model: model, - Temperature: temperature, - ResponseFormat: ResponseFormat{ - Type: "json_schema", - JsonSchema: jsonSchema, - }, - } - - // TODO: Add build pattern here that deals with errors in some internal state? - // I want a monad!!! 
- err := request.AddSystem(prompt) - if err != nil { - return request, err - } - - log.Println(request) - - err = request.AddImage(imageName, imageData) - if err != nil { - return request, err - } - - request.Tools = nil - - return request, nil -} - -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type ToolCall struct { - Index int `json:"index"` - Id string `json:"id"` - Function FunctionCall `json:"function"` -} - -type ResponseChoiceMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -type ResponseChoice struct { - Index int `json:"index"` - Message ResponseChoiceMessage `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type AgentResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Choices []ResponseChoice `json:"choices"` - Created int `json:"created"` -} - -// TODO: add usage parsing -func parseAgentResponse(jsonResponse []byte) (ImageInfo, error) { - response := AgentResponse{} - - err := json.Unmarshal(jsonResponse, &response) - if err != nil { - return ImageInfo{}, err - } - - if len(response.Choices) != 1 { - log.Println(string(jsonResponse)) - return ImageInfo{}, errors.New("Expected exactly one choice.") - } - - imageInfo := ImageInfo{} - err = json.Unmarshal([]byte(response.Choices[0].Message.Content), &imageInfo) - if err != nil { - return ImageInfo{}, errors.New("Could not parse content into image type.") - } - - return imageInfo, nil -} - -func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) { - jsonAiRequest, err := json.Marshal(request) - if err != nil { - return AgentResponse{}, err - } - - httpRequest, err := client.getRequest(jsonAiRequest) - if err != nil { - return AgentResponse{}, err - } - - resp, err := client.Do(httpRequest) - if err != nil { - return AgentResponse{}, err - } - - response, err := io.ReadAll(resp.Body) - if err != nil { - return 
AgentResponse{}, err - } - - agentResponse := AgentResponse{} - err = json.Unmarshal(response, &agentResponse) - - if err != nil { - return AgentResponse{}, err - } - - log.Println(string(response)) - - toolCalls := agentResponse.Choices[0].Message.ToolCalls - if len(toolCalls) > 0 { - // Should for sure be more flexible. - request.AddToolCall(AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: toolCalls, - }) - } - - return agentResponse, nil -} - -func (client AgentClient) GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) { - aiRequest, err := getCompletionsForImage("pixtral-12b-2409", 1.0, client.systemPrompt, imageName, RESPONSE_FORMAT, imageData) - if err != nil { - return ImageInfo{}, err - } - - var jsonSchema any - err = json.Unmarshal([]byte(RESPONSE_FORMAT), &jsonSchema) - if err != nil { - return ImageInfo{}, err - } - - aiRequest.ResponseFormat = ResponseFormat{ - Type: "json_schema", - JsonSchema: jsonSchema, - } - - jsonAiRequest, err := json.Marshal(aiRequest) - if err != nil { - return ImageInfo{}, err - } - - request, err := client.getRequest(jsonAiRequest) - if err != nil { - return ImageInfo{}, err - } - - resp, err := client.Do(request) - if err != nil { - return ImageInfo{}, err - } - - response, err := io.ReadAll(resp.Body) - if err != nil { - return ImageInfo{}, err - } - - log.Println(string(response)) - - return parseAgentResponse(response) -} diff --git a/backend/agents/agent_test.go b/backend/agents/agent_test.go deleted file mode 100644 index 770e0ab..0000000 --- a/backend/agents/agent_test.go +++ /dev/null @@ -1,213 +0,0 @@ -package agents - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "testing" -) - -func TestMessageBuilder(t *testing.T) { - content := AgentMessages{} - - err := content.AddSystem("Some prompt") - - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(content.Messages) != 1 { - t.Logf("Expected length 1, got %d.\n", 
len(content.Messages)) - t.FailNow() - } -} - -func TestMessageBuilderImage(t *testing.T) { - content := AgentMessages{} - - prompt := "some prompt" - imageTitle := "image.png" - data := []byte("some data") - - content.AddSystem(prompt) - content.AddImage(imageTitle, data) - - if len(content.Messages) != 2 { - t.Logf("Expected length 2, got %d.\n", len(content.Messages)) - t.FailNow() - } - - promptMessage, ok := content.Messages[0].(AgentTextMessage) - if !ok { - t.Logf("Expected text content message, got %T\n", content.Messages[0]) - t.FailNow() - } - - if promptMessage.Role != ROLE_SYSTEM { - t.Log("Prompt message role is incorrect.") - t.FailNow() - } - - if promptMessage.Content != prompt { - t.Log("Prompt message content is incorrect.") - t.FailNow() - } - - arrayContentMessage, ok := content.Messages[1].(AgentArrayMessage) - if !ok { - t.Logf("Expected text content message, got %T\n", content.Messages[1]) - t.FailNow() - } - - if arrayContentMessage.Role != ROLE_USER { - t.Log("Array content message role is incorrect.") - t.FailNow() - } - - if len(arrayContentMessage.Content) != 1 { - t.Logf("Expected length 1, got %d.\n", len(arrayContentMessage.Content)) - t.FailNow() - } - - imageContent, ok := arrayContentMessage.Content[0].(AgentImage) - if !ok { - t.Logf("Expected text content message, got %T\n", arrayContentMessage.Content[0]) - t.FailNow() - } - - base64data := base64.StdEncoding.EncodeToString(data) - url := fmt.Sprintf("data:image/%s;base64,%s", "png", base64data) - - if imageContent.ImageUrl != url { - t.Logf("Expected %s, but got %s.\n", url, imageContent.ImageUrl) - t.FailNow() - } -} - -func TestFullImageRequest(t *testing.T) { - request, err := getCompletionsForImage("model", 0.1, "You are an assistant", "image.png", "", []byte("some data")) - if err != nil { - t.Log(request) - t.FailNow() - } - - jsonData, err := json.Marshal(request) - if err != nil { - t.Log(err) - t.FailNow() - } - - expectedJson := 
`{"model":"model","temperature":0.1,"response_format":{"type":"json_schema","json_schema":""},"messages":[{"role":"system","content":"You are an assistant"},{"role":"user","content":[{"type":"image_url","image_url":{"url":""}}]}]}` - - if string(jsonData) != expectedJson { - t.Logf("Expected:\n%s\n Got:\n%s\n", expectedJson, string(jsonData)) - t.FailNow() - } -} - -func TestResponse(t *testing.T) { - testResponse := `{"tags": ["tag1", "tag2"], "text": ["text1"], "links": []}` - buffer := bytes.NewReader([]byte(testResponse)) - - body := io.NopCloser(buffer) - - client := AgentClient{ - url: "http://localhost:1234", - apiKey: "some-key", - Do: func(_req *http.Request) (*http.Response, error) { - return &http.Response{Body: body}, nil - }, - } - - info, err := client.GetImageInfo("image.png", []byte("some data")) - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(info.Tags) != 2 || len(info.Text) != 1 || len(info.Links) != 0 { - t.Logf("Some lengths are wrong.\nTags: %d\nText: %d\nLinks: %d\n", len(info.Tags), len(info.Text), len(info.Links)) - t.FailNow() - } - - if info.Tags[0] != "tag1" { - t.Log("0th tag is wrong.") - t.FailNow() - } - - if info.Tags[1] != "tag2" { - t.Log("1th tag is wrong.") - t.FailNow() - } - - if info.Text[0] != "text1" { - t.Log("0th text is wrong.") - t.FailNow() - } -} - -func TestResponseParsing(t *testing.T) { - response := `{ - "id": "chatcmpl-B4XgiHcd7A2nyK7eyARdggSvfFuWQ", - "object": "chat.completion", - "created": 1740422508, - "model": "gpt-4o-mini-2024-07-18", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "{\"links\":[\"link\"],\"tags\":[\"tag\"],\"text\":[\"text\"]}", - "refusal": null - }, - "logprobs": null, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 775, - "completion_tokens": 33, - "total_tokens": 808, - "prompt_tokens_details": { - "cached_tokens": 0, - "audio_tokens": 0 - }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "audio_tokens": 0, 
- "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "service_tier": "default", - "system_fingerprint": "fp_7fcd609668" -}` - - imageParsed, err := parseAgentResponse([]byte(response)) - if err != nil { - t.Log(err) - t.FailNow() - } - - if len(imageParsed.Links) != 1 || imageParsed.Links[0] != "link" { - t.Log(imageParsed) - t.Log("Should have one link called 'link'.") - t.FailNow() - } - - if len(imageParsed.Tags) != 1 || imageParsed.Tags[0] != "tag" { - t.Log(imageParsed) - t.Log("Should have one tag called 'tag'.") - t.FailNow() - } - - if len(imageParsed.Text) != 1 || imageParsed.Text[0] != "text" { - t.Log(imageParsed) - t.Log("Should have one text called 'text'.") - t.FailNow() - } -} diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go new file mode 100644 index 0000000..c137b8a --- /dev/null +++ b/backend/agents/client/client.go @@ -0,0 +1,265 @@ +package client + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "screenmark/screenmark/.gen/haystack/haystack/model" +) + +type ImageInfo struct { + Tags []string `json:"tags"` + Text []string `json:"text"` + Links []string `json:"links"` + + Locations []model.Locations `json:"locations"` + Events []model.Events `json:"events"` +} + +type ResponseFormat struct { + Type string `json:"type"` + JsonSchema any `json:"json_schema"` +} + +type AgentRequestBody struct { + Model string `json:"model"` + Temperature float64 `json:"temperature"` + ResponseFormat ResponseFormat `json:"response_format"` + + Tools *any `json:"tools,omitempty"` + ToolChoice *string `json:"tool_choice,omitempty"` + + // ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + + AgentMessages +} + +type AgentMessages struct { + Messages []AgentMessage `json:"messages"` +} + +type AgentMessage interface { + MessageToJson() ([]byte, error) +} + +type AgentTextMessage struct { + Role string `json:"role"` + 
Content string `json:"content"` + ToolCallId string `json:"tool_call_id,omitempty"` + Name string `json:"name,omitempty"` +} + +func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { + // TODO: Validate the `Role`. + return json.Marshal(textContent) +} + +type ToolCall struct { + Index int `json:"index"` + Id string `json:"id"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type AgentAssistantToolCall struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` +} + +func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) { + return json.Marshal(toolCall) +} + +type AgentArrayMessage struct { + Role string `json:"role"` + Content []AgentContent `json:"content"` +} + +func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { + return json.Marshal(arrayContent) +} + +func (content *AgentMessages) AddText(message AgentTextMessage) { + content.Messages = append(content.Messages, message) +} + +func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) { + content.Messages = append(content.Messages, toolCall) +} + +func (content *AgentMessages) AddImage(imageName string, image []byte) error { + extension := filepath.Ext(imageName) + if len(extension) == 0 { + // TODO: could also validate for image types we support. 
+ return errors.New("Image does not have extension") + } + + extension = extension[1:] + + encodedString := base64.StdEncoding.EncodeToString(image) + + arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)} + arrayMessage.Content[0] = AgentImage{ + ImageType: IMAGE_TYPE, + ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), + } + + content.Messages = append(content.Messages, arrayMessage) + + return nil +} + +func (content *AgentMessages) AddSystem(prompt string) error { + if len(content.Messages) != 0 { + return errors.New("You can only add a system prompt at the beginning") + } + + content.Messages = append(content.Messages, AgentTextMessage{ + Role: ROLE_SYSTEM, + Content: prompt, + }) + + return nil +} + +type AgentContent interface { + ToJson() ([]byte, error) +} + +type ImageUrl struct { + Url string `json:"url"` +} + +type AgentImage struct { + ImageType string `json:"type"` + ImageUrl string `json:"image_url"` +} + +func (imageMessage AgentImage) ToJson() ([]byte, error) { + imageMessage.ImageType = IMAGE_TYPE + return json.Marshal(imageMessage) +} + +type AiClient interface { + GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) +} + +type ResponseChoiceMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` +} + +type ResponseChoice struct { + Index int `json:"index"` + Message ResponseChoiceMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type AgentResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Choices []ResponseChoice `json:"choices"` + Created int `json:"created"` +} + +type AgentClient struct { + url string + apiKey string + systemPrompt string + responseFormat string + + ToolHandler ToolsHandlers + + Do func(req *http.Request) (*http.Response, error) +} + +const OPENAI_API_KEY = "OPENAI_API_KEY" +const ROLE_USER = "user" +const ROLE_SYSTEM = "system" +const 
IMAGE_TYPE = "image_url" + +func CreateAgentClient(prompt string) (AgentClient, error) { + apiKey := os.Getenv(OPENAI_API_KEY) + + if len(apiKey) == 0 { + return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") + } + + return AgentClient{ + apiKey: apiKey, + url: "https://api.mistral.ai/v1/chat/completions", + systemPrompt: prompt, + Do: func(req *http.Request) (*http.Response, error) { + client := &http.Client{} + return client.Do(req) + }, + }, nil +} + +func (client AgentClient) getRequest(body []byte) (*http.Request, error) { + req, err := http.NewRequest("POST", client.url, bytes.NewBuffer(body)) + if err != nil { + return req, err + } + + req.Header.Add("Authorization", "Bearer "+client.apiKey) + req.Header.Add("Content-Type", "application/json") + + return req, nil +} + +func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) { + jsonAiRequest, err := json.Marshal(request) + if err != nil { + return AgentResponse{}, err + } + + httpRequest, err := client.getRequest(jsonAiRequest) + if err != nil { + return AgentResponse{}, err + } + + resp, err := client.Do(httpRequest) + if err != nil { + return AgentResponse{}, err + } + + response, err := io.ReadAll(resp.Body) + if err != nil { + return AgentResponse{}, err + } + + agentResponse := AgentResponse{} + err = json.Unmarshal(response, &agentResponse) + + if err != nil { + return AgentResponse{}, err + } + + log.Println(string(response)) + + toolCalls := agentResponse.Choices[0].Message.ToolCalls + if len(toolCalls) > 0 { + // Should for sure be more flexible. 
+ request.AddToolCall(AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: toolCalls, + }) + } + + return agentResponse, nil +} diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go new file mode 100644 index 0000000..b9c8f42 --- /dev/null +++ b/backend/agents/client/tools.go @@ -0,0 +1,79 @@ +package client + +import ( + "encoding/json" + "errors" + "log" + + "github.com/google/uuid" +) + +type ToolHandlerInfo struct { + UserId uuid.UUID + ImageId uuid.UUID +} + +type ToolHandler struct { + Fn func(info ToolHandlerInfo, args string, call ToolCall) (string, error) +} + +type ToolsHandlers struct { + handlers *map[string]ToolHandler +} + +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) { + agentMessage := request.Messages[len(request.Messages)-1] + + toolCall, ok := agentMessage.(AgentAssistantToolCall) + if !ok { + return "", errors.New("Latest message was not a tool call.") + } + + fnName := toolCall.ToolCalls[0].Function.Name + arguments := toolCall.ToolCalls[0].Function.Arguments + + fnHandler, exists := (*handler.handlers)[fnName] + if !exists { + return "", errors.New("Could not find tool with this name.") + } + + log.Printf("Calling: %s\n", fnName) + res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0]) + if err != nil { + return "", err + } + + request.AddText(AgentTextMessage{ + Role: "tool", + Name: "createLocation", + Content: res, + ToolCallId: toolCall.ToolCalls[0].Id, + }) + + return res, nil +} + +func (handler ToolsHandlers) AddTool(name string, getArgs func() any, fn func(info ToolHandlerInfo, args any, call ToolCall) (any, error)) { + (*handler.handlers)["createLocation"] = ToolHandler{ + Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) { + argsStruct := getArgs() + + err := json.Unmarshal([]byte(args), &argsStruct) + if err != nil { + return "", err + } + + res, err := fn(info, argsStruct, call) + if err != nil { + 
return "", err + } + + marshalledRes, err := json.Marshal(res) + if err != nil { + return "", err + } + + return string(marshalledRes), nil + }, + } +} diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 72757ad..d131d19 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -5,8 +5,8 @@ import ( "encoding/json" "errors" "log" - "reflect" "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "time" @@ -116,13 +116,13 @@ const TOOLS = ` ` type EventLocationAgent struct { - client AgentClient + client client.AgentClient eventModel models.EventModel locationModel models.LocationModel contactModel models.ContactModel - toolHandler ToolsHandlers + toolHandler client.ToolsHandlers } type ListLocationArguments struct{} @@ -158,12 +158,12 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID toolChoice := "any" - request := AgentRequestBody{ + request := client.AgentRequestBody{ Tools: &tools, ToolChoice: &toolChoice, Model: "pixtral-12b-2409", Temperature: 0.3, - ResponseFormat: ResponseFormat{ + ResponseFormat: client.ResponseFormat{ Type: "text", }, } @@ -180,11 +180,12 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID return err } - toolHandlerInfo := ToolHandlerInfo{ - imageId: imageId, - userId: userId, + toolHandlerInfo := client.ToolHandlerInfo{ + ImageId: imageId, + UserId: userId, } + // TODO: this should go into a loop with toolHandler Handle function and not be here bruh. _, err = agent.toolHandler.Handle(toolHandlerInfo, &request) for err == nil { @@ -208,107 +209,41 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID return err } -// TODO: extract this into a more general tool handler package. 
-func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) { - agentMessage := request.Messages[len(request.Messages)-1] - - toolCall, ok := agentMessage.(AgentAssistantToolCall) - if !ok { - return "", errors.New("Latest message was not a tool call.") - } - - fnName := toolCall.ToolCalls[0].Function.Name - arguments := toolCall.ToolCalls[0].Function.Arguments - - if fnName == "finish" { - return "", errors.New("This is the end! Maybe we just return a boolean.") - } - - fn, exists := handler.Handlers[fnName] - if !exists { - return "", errors.New("Could not find tool with this name.") - } - - // holy jesus what the fuck. - parseMethod := reflect.ValueOf(fn).Field(1) - if !parseMethod.IsValid() { - return "", errors.New("Parse method not found") - } - - parsedArgs := parseMethod.Call([]reflect.Value{reflect.ValueOf(arguments)}) - if !parsedArgs[1].IsNil() { - return "", parsedArgs[1].Interface().(error) - } - - log.Printf("Calling: %s\n", fnName) - - fnMethod := reflect.ValueOf(fn).Field(2) - if !fnMethod.IsValid() { - return "", errors.New("Fn method not found") - } - - response := fnMethod.Call([]reflect.Value{reflect.ValueOf(info), parsedArgs[0], reflect.ValueOf(toolCall.ToolCalls[0])}) - if !response[1].IsNil() { - return "", response[1].Interface().(error) - } - - stringResponse, err := json.Marshal(response[0].Interface()) - if err != nil { - return "", err - } - - request.AddText(AgentTextMessage{ - Role: "tool", - Name: "createLocation", - Content: string(stringResponse), - ToolCallId: toolCall.ToolCalls[0].Id, - }) - - return string(stringResponse), nil -} - func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { - client, err := CreateAgentClient(eventLocationPrompt) + agentClient, err := client.CreateAgentClient(eventLocationPrompt) if err != nil { return EventLocationAgent{}, err } agent := EventLocationAgent{ - 
client: client, + client: agentClient, locationModel: locationModel, eventModel: eventModel, contactModel: contactModel, } - toolHandler := ToolsHandlers{ - Handlers: make(map[string]ToolHandlerInterface), - } - - toolHandler.Handlers["listLocations"] = ToolHandler[ListLocationArguments, []model.Locations]{ - FunctionName: "listLocations", - Parse: func(stringArgs string) (ListLocationArguments, error) { - args := ListLocationArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) - - return args, err + agentClient.ToolHandler.AddTool("listLocations", + func() any { + return ListLocationArguments{} }, - Fn: func(info ToolHandlerInfo, _args ListLocationArguments, call ToolCall) ([]model.Locations, error) { - return agent.locationModel.List(context.Background(), info.userId) + func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { + return agent.locationModel.List(context.Background(), info.UserId) }, - } + ) - toolHandler.Handlers["createLocation"] = ToolHandler[CreateLocationArguments, model.Locations]{ - FunctionName: "createLocation", - Parse: func(stringArgs string) (CreateLocationArguments, error) { - args := CreateLocationArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) - - return args, err + agentClient.ToolHandler.AddTool("createLocation", + func() any { + return CreateLocationArguments{} }, - Fn: func(info ToolHandlerInfo, args CreateLocationArguments, call ToolCall) (model.Locations, error) { + func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { + args, ok := _args.(CreateLocationArguments) + if !ok { + return _args, errors.New("Type error, arguments are not of the correct struct type") + } + ctx := context.Background() - location, err := agent.locationModel.Save(ctx, info.userId, model.Locations{ + location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{ Name: args.Name, Address: args.Address, }) @@ -317,21 +252,22 @@ func NewLocationEventAgent(locationModel 
models.LocationModel, eventModel models return location, err } - _, err = agent.locationModel.SaveToImage(ctx, info.imageId, location.ID) + _, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID) return location, err }, - } + ) - toolHandler.Handlers["createEvent"] = ToolHandler[CreateEventArguments, model.Events]{ - FunctionName: "createEvent", - Parse: func(stringArgs string) (CreateEventArguments, error) { - args := CreateEventArguments{} - err := json.Unmarshal([]byte(stringArgs), &args) - - return args, err + agentClient.ToolHandler.AddTool("createEvent", + func() any { + return CreateEventArguments{} }, - Fn: func(info ToolHandlerInfo, args CreateEventArguments, call ToolCall) (model.Events, error) { + func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { + args, ok := _args.(CreateEventArguments) + if !ok { + return _args, errors.New("Type error, arguments are not of the correct struct type") + } + ctx := context.Background() layout := "2006-01-02T15:04:05Z" @@ -346,7 +282,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return model.Events{}, err } - event, err := agent.eventModel.Save(ctx, info.userId, model.Events{ + event, err := agent.eventModel.Save(ctx, info.UserId, model.Events{ Name: args.Name, StartDateTime: &startTime, EndDateTime: &endTime, @@ -356,7 +292,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return event, err } - organizer, err := agent.contactModel.Save(ctx, info.userId, model.Contacts{ + organizer, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{ Name: args.Name, }) @@ -364,12 +300,12 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return event, err } - _, err = agent.eventModel.SaveToImage(ctx, info.imageId, event.ID) + _, err = agent.eventModel.SaveToImage(ctx, info.ImageId, event.ID) if err != nil { return event, err } - _, err = 
agent.contactModel.SaveToImage(ctx, info.imageId, organizer.ID) + _, err = agent.contactModel.SaveToImage(ctx, info.ImageId, organizer.ID) if err != nil { return event, err } @@ -386,9 +322,7 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models return agent.eventModel.UpdateOrganizer(ctx, event.ID, organizer.ID) }, - } - - agent.toolHandler = toolHandler + ) return agent, nil } diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 1616ea6..3d62c94 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -3,6 +3,7 @@ package agents import ( "context" "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "github.com/google/uuid" @@ -19,16 +20,16 @@ Do not return anything except markdown. ` type NoteAgent struct { - client AgentClient + client client.AgentClient noteModel models.NoteModel } func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { - request := AgentRequestBody{ + request := client.AgentRequestBody{ Model: "pixtral-12b-2409", Temperature: 0.3, - ResponseFormat: ResponseFormat{ + ResponseFormat: client.ResponseFormat{ Type: "text", }, } @@ -65,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s } func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) { - client, err := CreateAgentClient(noteAgentPrompt) + client, err := client.CreateAgentClient(noteAgentPrompt) if err != nil { return NoteAgent{}, err } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index a3d4934..cff08b3 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -3,6 +3,7 @@ package agents import ( "encoding/json" "fmt" + "screenmark/screenmark/agents/client" "github.com/google/uuid" ) @@ -74,7 +75,7 @@ const MY_TOOLS = ` ]` type OrchestratorAgent struct { - client AgentClient + 
client client.AgentClient } func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { @@ -86,10 +87,10 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, return err } - request := AgentRequestBody{ + request := client.AgentRequestBody{ Model: "pixtral-12b-2409", Temperature: 0.3, - ResponseFormat: ResponseFormat{ + ResponseFormat: client.ResponseFormat{ Type: "text", }, ToolChoice: &toolChoice, @@ -113,7 +114,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, } func NewOrchestratorAgent() (OrchestratorAgent, error) { - agent, err := CreateAgentClient(orchestratorPrompt) + agent, err := client.CreateAgentClient(orchestratorPrompt) if err != nil { return OrchestratorAgent{}, err } diff --git a/backend/agents/tools_handler.go b/backend/agents/tools_handler.go deleted file mode 100644 index 5981d50..0000000 --- a/backend/agents/tools_handler.go +++ /dev/null @@ -1,26 +0,0 @@ -package agents - -import "github.com/google/uuid" - -type ToolHandlerInfo struct { - userId uuid.UUID - imageId uuid.UUID -} - -type ToolHandler[TArgs any, TResp any] struct { - FunctionName string - Parse func(args string) (TArgs, error) - Fn func(info ToolHandlerInfo, args TArgs, call ToolCall) (TResp, error) -} - -type ToolHandlerInterface interface { - GetFunctionName() string -} - -func (handler ToolHandler[TArgs, TResp]) GetFunctionName() string { - return handler.FunctionName -} - -type ToolsHandlers struct { - Handlers map[string]ToolHandlerInterface -} diff --git a/backend/main.go b/backend/main.go index f1b0530..dbc80df 100644 --- a/backend/main.go +++ b/backend/main.go @@ -13,6 +13,7 @@ import ( "path/filepath" "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/agents" + "screenmark/screenmark/agents/client" "screenmark/screenmark/models" "time" @@ -24,41 +25,13 @@ import ( ) type TestAiClient struct { - ImageInfo agents.ImageInfo + 
ImageInfo client.ImageInfo } -func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (agents.ImageInfo, error) { +func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (client.ImageInfo, error) { return client.ImageInfo, nil } -func GetAiClient() (agents.AiClient, error) { - mode := os.Getenv("MODE") - if mode == "TESTING" { - address := "10 Downing Street" - description := "Cheese and Crackers" - - return TestAiClient{ - ImageInfo: agents.ImageInfo{ - Tags: []string{"tag"}, - Links: []string{"links"}, - Text: []string{"text"}, - Locations: []model.Locations{{ - ID: uuid.Nil, - Name: "London", - Address: &address, - }}, - Events: []model.Events{{ - ID: uuid.Nil, - Name: "Party", - Description: &description, - }}, - }, - }, nil - } - - return agents.CreateAgentClient(agents.PROMPT) -} - func main() { err := godotenv.Load() if err != nil { @@ -74,9 +47,6 @@ func main() { } imageModel := models.NewImageModel(db) - linkModel := models.NewLinkModel(db) - tagModel := models.NewTagModel(db) - textModel := models.NewTextModel(db) locationModel := models.NewLocationModel(db) eventModel := models.NewEventModel(db) userModel := models.NewUserModel(db) @@ -105,11 +75,6 @@ func main() { ctx := context.Background() go func() { - openAiClient, err := GetAiClient() - if err != nil { - panic(err) - } - locationAgent, err := agents.NewLocationEventAgent(locationModel, eventModel, contactModel) if err != nil { panic(err) @@ -132,7 +97,7 @@ func main() { return } - userImage, err := imageModel.FinishProcessing(ctx, image.ID) + _, err = imageModel.FinishProcessing(ctx, image.ID) if err != nil { log.Println("Failed to FinishProcessing") log.Println(err) @@ -152,36 +117,6 @@ func main() { log.Println("Calling noteAgent!") err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) log.Println(err) - - return - - imageInfo, err := openAiClient.GetImageInfo(image.Image.ImageName, image.Image.Image) - if err != 
nil { - log.Println("Failed to GetImageInfo") - log.Println(err) - return - } - - err = tagModel.SaveToImage(ctx, userImage.ImageID, imageInfo.Tags) - if err != nil { - log.Println("Failed to save tags") - log.Println(err) - return - } - - err = linkModel.Save(ctx, userImage.ImageID, imageInfo.Links) - if err != nil { - log.Println("Failed to save links") - log.Println(err) - return - } - - err = textModel.Save(ctx, userImage.ImageID, imageInfo.Text) - if err != nil { - log.Println("Failed to save text") - log.Println(err) - return - } }() } } -- 2.47.2 From cd27f1105a88525e30d9b42d8e96ec19e0af999b Mon Sep 17 00:00:00 2001 From: John Costa Date: Fri, 4 Apr 2025 22:17:58 +0100 Subject: [PATCH 03/17] refactor(tool-calls): to be handled more generally --- backend/agents/client/client.go | 24 ++++++++++++++++++++++++ backend/agents/event_location_agent.go | 23 +---------------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index c137b8a..59ea381 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -263,3 +263,27 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err return agentResponse, nil } + +func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error { + var err error + + for err == nil { + log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) + + response, requestError := client.Request(&request) + if requestError != nil { + return requestError + } + + log.Println(response) + + a, innerErr := client.ToolHandler.Handle(info, &request) + + err = innerErr + + log.Println(a) + log.Println("--------------------------") + } + + return nil +} diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index d131d19..eb2a41b 100644 --- a/backend/agents/event_location_agent.go +++ 
b/backend/agents/event_location_agent.go @@ -185,28 +185,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID UserId: userId, } - // TODO: this should go into a loop with toolHandler Handle function and not be here bruh. - _, err = agent.toolHandler.Handle(toolHandlerInfo, &request) - - for err == nil { - log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) - - response, requestError := agent.client.Request(&request) - if requestError != nil { - return requestError - } - - log.Println(response) - - a, innerErr := agent.toolHandler.Handle(toolHandlerInfo, &request) - - err = innerErr - - log.Println(a) - log.Println("--------------------------") - } - - return err + return agent.client.Process(toolHandlerInfo, request) } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { -- 2.47.2 From aa153de185e88838781925abfb88463e2d95dc40 Mon Sep 17 00:00:00 2001 From: John Costa Date: Fri, 4 Apr 2025 22:40:45 +0100 Subject: [PATCH 04/17] refactor(agents): working e2e now I guess some repeated code doesnt hurt anyone, if it keeps things simpler. Trying to be fancy with the interfaces didn't work so well. 
--- backend/agents/client/client.go | 29 ++++++++++++------------ backend/agents/client/tools.go | 28 ++++++++++------------- backend/agents/event_location_agent.go | 31 +++++++++----------------- backend/main.go | 2 -- 4 files changed, 37 insertions(+), 53 deletions(-) diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 59ea381..767cbed 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -206,6 +206,10 @@ func CreateAgentClient(prompt string) (AgentClient, error) { client := &http.Client{} return client.Do(req) }, + + ToolHandler: ToolsHandlers{ + handlers: &map[string]ToolHandler{}, + }, }, nil } @@ -267,23 +271,18 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error { var err error - for err == nil { - log.Printf("Latest message: %+v\n", request.AgentMessages.Messages[len(request.AgentMessages.Messages)-1]) - - response, requestError := client.Request(&request) - if requestError != nil { - return requestError + for { + err = client.ToolHandler.Handle(info, &request) + if err != nil { + break } - log.Println(response) - - a, innerErr := client.ToolHandler.Handle(info, &request) - - err = innerErr - - log.Println(a) - log.Println("--------------------------") + _, err = client.Request(&request) + if err != nil { + break + } } - return nil + log.Println(err) + return err } diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index b9c8f42..5f67a0c 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -21,49 +21,45 @@ type ToolsHandlers struct { handlers *map[string]ToolHandler } -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) (string, error) { +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) error { agentMessage := request.Messages[len(request.Messages)-1] 
toolCall, ok := agentMessage.(AgentAssistantToolCall) if !ok { - return "", errors.New("Latest message was not a tool call.") + return errors.New("Latest message was not a tool call.") } fnName := toolCall.ToolCalls[0].Function.Name arguments := toolCall.ToolCalls[0].Function.Arguments + log.Println(handler.handlers) + fnHandler, exists := (*handler.handlers)[fnName] if !exists { - return "", errors.New("Could not find tool with this name.") + return errors.New("Could not find tool with this name.") } log.Printf("Calling: %s\n", fnName) res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0]) if err != nil { - return "", err + return err } + log.Println(res) request.AddText(AgentTextMessage{ Role: "tool", - Name: "createLocation", + Name: fnName, Content: res, ToolCallId: toolCall.ToolCalls[0].Id, }) - return res, nil + return nil } -func (handler ToolsHandlers) AddTool(name string, getArgs func() any, fn func(info ToolHandlerInfo, args any, call ToolCall) (any, error)) { - (*handler.handlers)["createLocation"] = ToolHandler{ +func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { + (*handler.handlers)[name] = ToolHandler{ Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) { - argsStruct := getArgs() - - err := json.Unmarshal([]byte(args), &argsStruct) - if err != nil { - return "", err - } - - res, err := fn(info, argsStruct, call) + res, err := fn(info, args, call) if err != nil { return "", err } diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index eb2a41b..c489cdb 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -3,8 +3,6 @@ package agents import ( "context" "encoding/json" - "errors" - "log" "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/agents/client" "screenmark/screenmark/models" @@ -202,22 +200,17 @@ func 
NewLocationEventAgent(locationModel models.LocationModel, eventModel models } agentClient.ToolHandler.AddTool("listLocations", - func() any { - return ListLocationArguments{} - }, - func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { + func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { return agent.locationModel.List(context.Background(), info.UserId) }, ) agentClient.ToolHandler.AddTool("createLocation", - func() any { - return CreateLocationArguments{} - }, - func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { - args, ok := _args.(CreateLocationArguments) - if !ok { - return _args, errors.New("Type error, arguments are not of the correct struct type") + func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { + args := CreateLocationArguments{} + err := json.Unmarshal([]byte(_args), &args) + if err != nil { + return model.Locations{}, err } ctx := context.Background() @@ -238,13 +231,11 @@ func NewLocationEventAgent(locationModel models.LocationModel, eventModel models ) agentClient.ToolHandler.AddTool("createEvent", - func() any { - return CreateEventArguments{} - }, - func(info client.ToolHandlerInfo, _args any, call client.ToolCall) (any, error) { - args, ok := _args.(CreateEventArguments) - if !ok { - return _args, errors.New("Type error, arguments are not of the correct struct type") + func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { + args := CreateEventArguments{} + err := json.Unmarshal([]byte(_args), &args) + if err != nil { + return model.Locations{}, err } ctx := context.Background() diff --git a/backend/main.go b/backend/main.go index dbc80df..8e4230a 100644 --- a/backend/main.go +++ b/backend/main.go @@ -106,8 +106,6 @@ func main() { orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) - return - // TODO: this can very much be parallel 
log.Println("Calling locationAgent!") -- 2.47.2 From 286a9a8472a37d8a85c284956d1c4c5dbc8db56c Mon Sep 17 00:00:00 2001 From: John Costa Date: Fri, 4 Apr 2025 22:50:19 +0100 Subject: [PATCH 05/17] fix(tool): raw text not scaling so well ey? --- backend/agents/client/client.go | 4 +++- backend/agents/client/tools.go | 2 -- backend/agents/event_location_agent.go | 19 ++++++++----------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 767cbed..34c463b 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -283,6 +283,8 @@ func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody } } - log.Println(err) + if err != nil { + log.Println(err) + } return err } diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 5f67a0c..38dbb1f 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -32,8 +32,6 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestB fnName := toolCall.ToolCalls[0].Function.Name arguments := toolCall.ToolCalls[0].Function.Arguments - log.Println(handler.handlers) - fnHandler, exists := (*handler.handlers)[fnName] if !exists { return errors.New("Could not find tool with this name.") diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index c489cdb..a1b60cd 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -99,17 +99,14 @@ const TOOLS = ` } } }, - { - "type": "function", - "function": { - "name": "finish", - "description": "Nothing else to do, call this function.", - "parameters": { - "type": "object", - "properties": {} - } - } - } + { + "type": "function", + "function": { + "name": "finish", + "description": "Nothing else to do. 
call this function.", + "parameters": {} + } + } ] ` -- 2.47.2 From 03e78034675297941abd0498d8e27ca3c82b7374 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 5 Apr 2025 11:01:43 +0100 Subject: [PATCH 06/17] feat(orchestrator): calling needed agents when it needs to --- backend/agents/client/client.go | 6 ----- backend/agents/orchestrator.go | 46 +++++++++++++++++++++++++++++---- backend/main.go | 20 ++++---------- 3 files changed, 46 insertions(+), 26 deletions(-) diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 34c463b..464dba2 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -152,10 +152,6 @@ func (imageMessage AgentImage) ToJson() ([]byte, error) { return json.Marshal(imageMessage) } -type AiClient interface { - GetImageInfo(imageName string, imageData []byte) (ImageInfo, error) -} - type ResponseChoiceMessage struct { Role string `json:"role"` Content string `json:"content"` @@ -253,8 +249,6 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err return AgentResponse{}, err } - log.Println(string(response)) - toolCalls := agentResponse.Choices[0].Message.ToolCalls if len(toolCalls) > 0 { // Should for sure be more flexible. diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index cff08b3..4fe7d3a 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -2,7 +2,7 @@ package agents import ( "encoding/json" - "fmt" + "errors" "screenmark/screenmark/agents/client" "github.com/google/uuid" @@ -78,6 +78,12 @@ type OrchestratorAgent struct { client client.AgentClient } +type Status struct { + Ok bool `json:"ok"` +} + +// TODO: the primary function of the agent could be extracted outwards. 
+// This is basically the same function as we have in the `event_location_agent.go` func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { toolChoice := "any" @@ -103,22 +109,52 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, } request.AddImage(imageName, imageData) - resp, err := agent.client.Request(&request) + _, err = agent.client.Request(&request) if err != nil { return err } - fmt.Println(resp) + toolHandlerInfo := client.ToolHandlerInfo{ + ImageId: imageId, + UserId: userId, + } - return nil + return agent.client.Process(toolHandlerInfo, request) } -func NewOrchestratorAgent() (OrchestratorAgent, error) { +func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { agent, err := client.CreateAgentClient(orchestratorPrompt) if err != nil { return OrchestratorAgent{}, err } + agent.ToolHandler.AddTool("eventLocationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { + // We need a way to keep track of this async? + // Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish. + + eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData) + + return Status{ + Ok: true, + }, nil + }) + + agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { + noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData) + + return Status{ + Ok: true, + }, nil + }) + + agent.ToolHandler.AddTool("defaultAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { + // To nothing + + return Status{ + Ok: true, + }, errors.New("Finished! 
Kinda bad return type but...") + }) + return OrchestratorAgent{ client: agent, }, nil diff --git a/backend/main.go b/backend/main.go index 8e4230a..205815f 100644 --- a/backend/main.go +++ b/backend/main.go @@ -85,11 +85,6 @@ func main() { panic(err) } - orchestrator, err := agents.NewOrchestratorAgent() - if err != nil { - panic(err) - } - image, err := imageModel.GetToProcessWithData(ctx, imageId) if err != nil { log.Println("Failed to GetToProcessWithData") @@ -104,17 +99,12 @@ func main() { return } + orchestrator, err := agents.NewOrchestratorAgent(locationAgent, noteAgent, image.Image.ImageName, image.Image.Image) + if err != nil { + panic(err) + } + orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) - - // TODO: this can very much be parallel - - log.Println("Calling locationAgent!") - err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) - log.Println(err) - - log.Println("Calling noteAgent!") - err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) - log.Println(err) }() } } -- 2.47.2 From a1ce96d2e324a537a1373b5aa6f29d6bf2c43537 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 5 Apr 2025 14:35:54 +0100 Subject: [PATCH 07/17] test(tools): starting test suite for tools --- backend/agents/client/client.go | 7 ++++- backend/agents/client/tools.go | 30 ++++++------------ backend/agents/client/tools_test.go | 47 +++++++++++++++++++++++++++++ backend/go.mod | 2 +- backend/go.sum | 2 ++ 5 files changed, 65 insertions(+), 23 deletions(-) create mode 100644 backend/agents/client/tools_test.go diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 464dba2..a49c86d 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -266,7 +266,12 @@ func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody var err error for { - err = 
client.ToolHandler.Handle(info, &request) + toolCall, ok := request.Messages[len(request.Messages)-1].(AgentAssistantToolCall) + if !ok { + return errors.New("Latest message isnt a tool call. TODO") + } + + _, err = client.ToolHandler.Handle(info, toolCall) if err != nil { break } diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 38dbb1f..8c0d977 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -3,7 +3,6 @@ package client import ( "encoding/json" "errors" - "log" "github.com/google/uuid" ) @@ -21,37 +20,26 @@ type ToolsHandlers struct { handlers *map[string]ToolHandler } -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, request *AgentRequestBody) error { - agentMessage := request.Messages[len(request.Messages)-1] - - toolCall, ok := agentMessage.(AgentAssistantToolCall) - if !ok { - return errors.New("Latest message was not a tool call.") - } - - fnName := toolCall.ToolCalls[0].Function.Name - arguments := toolCall.ToolCalls[0].Function.Arguments +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) (AgentTextMessage, error) { + fnName := toolCallMessage.ToolCalls[0].Function.Name + arguments := toolCallMessage.ToolCalls[0].Function.Arguments fnHandler, exists := (*handler.handlers)[fnName] if !exists { - return errors.New("Could not find tool with this name.") + return AgentTextMessage{}, errors.New("Could not find tool with this name.") } - log.Printf("Calling: %s\n", fnName) - res, err := fnHandler.Fn(info, arguments, toolCall.ToolCalls[0]) + res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) if err != nil { - return err + return AgentTextMessage{}, err } - log.Println(res) - request.AddText(AgentTextMessage{ + return AgentTextMessage{ Role: "tool", Name: fnName, Content: res, - ToolCallId: toolCall.ToolCalls[0].Id, - }) - - return nil + ToolCallId: toolCallMessage.ToolCalls[0].Id, + }, nil } func (handler ToolsHandlers) 
AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go new file mode 100644 index 0000000..9b17ce0 --- /dev/null +++ b/backend/agents/client/tools_test.go @@ -0,0 +1,47 @@ +package client + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestSingleToolCall(t *testing.T) { + assert := assert.New(t) + + tools := ToolsHandlers{ + handlers: &map[string]ToolHandler{}, + } + + tools.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { + return true, nil + }) + + response, err := tools.Handle( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{{ + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "a", + Arguments: "", + }, + }}, + }) + + if assert.NoError(err, "Tool call shouldnt return an error") { + assert.EqualValues(response, AgentTextMessage{ + Role: "tool", + Content: "true", + ToolCallId: "1", + Name: "a", + }) + } +} diff --git a/backend/go.mod b/backend/go.mod index fcdc38a..df64294 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -10,6 +10,6 @@ require ( github.com/joho/godotenv v1.5.1 // indirect github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/testify v1.9.0 // indirect + github.com/stretchr/testify v1.10.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index ecd9db0..b982fe0 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -14,6 +14,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod 
h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -- 2.47.2 From 1cafc31e0a0b5d5310bf554bb5af791d5d9e6893 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 5 Apr 2025 14:52:31 +0100 Subject: [PATCH 08/17] test(tools): more robust multiple tool call handling --- backend/agents/client/tools.go | 42 +++++++----- backend/agents/client/tools_test.go | 100 +++++++++++++++++++++++++--- 2 files changed, 117 insertions(+), 25 deletions(-) diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 8c0d977..0e259ec 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -20,26 +20,38 @@ type ToolsHandlers struct { handlers *map[string]ToolHandler } -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) (AgentTextMessage, error) { - fnName := toolCallMessage.ToolCalls[0].Function.Name - arguments := toolCallMessage.ToolCalls[0].Function.Arguments +var NoToolCallError = errors.New("An assistant tool call with no tool calls was provided.") - fnHandler, exists := (*handler.handlers)[fnName] - if !exists { - return AgentTextMessage{}, errors.New("Could not find tool with this name.") +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) { + if len(toolCallMessage.ToolCalls) == 0 { + return []AgentTextMessage{}, NoToolCallError } - res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) - if err != nil { - return AgentTextMessage{}, err + responses := 
make([]AgentTextMessage, len(toolCallMessage.ToolCalls)) + + for i, toolCall := range toolCallMessage.ToolCalls { + fnName := toolCall.Function.Name + arguments := toolCall.Function.Arguments + + fnHandler, exists := (*handler.handlers)[fnName] + if !exists { + return []AgentTextMessage{}, errors.New("Could not find tool with this name.") + } + + res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) + if err != nil { + return []AgentTextMessage{}, err + } + + responses[i] = AgentTextMessage{ + Role: "tool", + Name: fnName, + Content: res, + ToolCallId: toolCall.Id, + } } - return AgentTextMessage{ - Role: "tool", - Name: fnName, - Content: res, - ToolCallId: toolCallMessage.ToolCalls[0].Id, - }, nil + return responses, nil } func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index 9b17ce0..35423ff 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -4,21 +4,30 @@ import ( "testing" "github.com/google/uuid" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestSingleToolCall(t *testing.T) { - assert := assert.New(t) +type ToolTestSuite struct { + suite.Suite - tools := ToolsHandlers{ + handler ToolsHandlers +} + +func (suite *ToolTestSuite) SetupTest() { + suite.handler = ToolsHandlers{ handlers: &map[string]ToolHandler{}, } - tools.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { + suite.handler.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { return true, nil }) +} - response, err := tools.Handle( +func (suite *ToolTestSuite) TestSingleToolCall() { + assert := suite.Assert() + require := suite.Require() + + response, err := suite.handler.Handle( ToolHandlerInfo{ UserId: uuid.Nil, ImageId: uuid.Nil, @@ -36,12 +45,83 @@ func TestSingleToolCall(t 
*testing.T) { }}, }) - if assert.NoError(err, "Tool call shouldnt return an error") { - assert.EqualValues(response, AgentTextMessage{ + require.NoError(err, "Tool call shouldnt return an error") + + assert.EqualValues(response, []AgentTextMessage{{ + Role: "tool", + Content: "true", + ToolCallId: "1", + Name: "a", + }}) +} + +func (suite *ToolTestSuite) TestEmptyCall() { + require := suite.Require() + + _, err := suite.handler.Handle( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{}, + }) + + require.ErrorIs(err, NoToolCallError) +} + +func (suite *ToolTestSuite) TestMultipleToolCalls() { + assert := suite.Assert() + require := suite.Require() + + response, err := suite.handler.Handle( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "a", + Arguments: "", + }, + }, + { + Index: 1, + Id: "2", + Function: FunctionCall{ + Name: "a", + Arguments: "", + }, + }, + }, + }) + + require.NoError(err, "Tool call shouldnt return an error") + + assert.EqualValues(response, []AgentTextMessage{ + { Role: "tool", Content: "true", ToolCallId: "1", Name: "a", - }) - } + }, + { + Role: "tool", + Content: "true", + ToolCallId: "2", + Name: "a", + }, + }) +} + +func TestToolSuite(t *testing.T) { + suite.Run(t, &ToolTestSuite{}) } -- 2.47.2 From d78f34a7aa3b05e24353fe62154f49cacddfd38d Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 5 Apr 2025 14:58:38 +0100 Subject: [PATCH 09/17] feat(tools): return error to agent if any happened --- backend/agents/client/tools.go | 14 +++--- backend/agents/client/tools_test.go | 69 ++++++++++++++++++++++++++--- 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 0e259ec..193fe92 100644 --- 
a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -39,16 +39,20 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentA } res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) - if err != nil { - return []AgentTextMessage{}, err - } - responses[i] = AgentTextMessage{ + responseMessage := AgentTextMessage{ Role: "tool", Name: fnName, - Content: res, ToolCallId: toolCall.Id, } + + if err != nil { + responseMessage.Content = err.Error() + } else { + responseMessage.Content = res + } + + responses[i] = responseMessage } return responses, nil diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index 35423ff..cb07ece 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -1,6 +1,7 @@ package client import ( + "errors" "testing" "github.com/google/uuid" @@ -19,7 +20,11 @@ func (suite *ToolTestSuite) SetupTest() { } suite.handler.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { - return true, nil + return args, nil + }) + + suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { + return false, errors.New("I will always error") }) } @@ -40,7 +45,7 @@ func (suite *ToolTestSuite) TestSingleToolCall() { Id: "1", Function: FunctionCall{ Name: "a", - Arguments: "", + Arguments: "return", }, }}, }) @@ -49,7 +54,7 @@ func (suite *ToolTestSuite) TestSingleToolCall() { assert.EqualValues(response, []AgentTextMessage{{ Role: "tool", - Content: "true", + Content: "\"return\"", ToolCallId: "1", Name: "a", }}) @@ -90,7 +95,7 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() { Id: "1", Function: FunctionCall{ Name: "a", - Arguments: "", + Arguments: "first-call", }, }, { @@ -98,7 +103,7 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() { Id: "2", Function: FunctionCall{ Name: "a", - Arguments: "", + Arguments: "second-call", }, }, }, @@ -109,13 +114,63 @@ 
func (suite *ToolTestSuite) TestMultipleToolCalls() { assert.EqualValues(response, []AgentTextMessage{ { Role: "tool", - Content: "true", + Content: "\"first-call\"", ToolCallId: "1", Name: "a", }, { Role: "tool", - Content: "true", + Content: "\"second-call\"", + ToolCallId: "2", + Name: "a", + }, + }) +} + +func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { + assert := suite.Assert() + require := suite.Require() + + response, err := suite.handler.Handle( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + AgentAssistantToolCall{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "error", + Arguments: "", + }, + }, + { + Index: 1, + Id: "2", + Function: FunctionCall{ + Name: "a", + Arguments: "no-error", + }, + }, + }, + }) + + require.NoError(err, "Tool call shouldnt return an error") + + assert.EqualValues(response, []AgentTextMessage{ + { + Role: "tool", + Content: "I will always error", + ToolCallId: "1", + Name: "error", + }, + { + Role: "tool", + Content: "\"no-error\"", ToolCallId: "2", Name: "a", }, -- 2.47.2 From d474b1700a1e6694880afaf9a39c4c01119088ae Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 5 Apr 2025 14:59:50 +0100 Subject: [PATCH 10/17] refactor(tools): removing pointer map This is not needed --- backend/agents/client/client.go | 2 +- backend/agents/client/tools.go | 8 ++++---- backend/agents/client/tools_test.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index a49c86d..3964aca 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -204,7 +204,7 @@ func CreateAgentClient(prompt string) (AgentClient, error) { }, ToolHandler: ToolsHandlers{ - handlers: &map[string]ToolHandler{}, + handlers: map[string]ToolHandler{}, }, }, nil } diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 
193fe92..a994363 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -17,7 +17,7 @@ type ToolHandler struct { } type ToolsHandlers struct { - handlers *map[string]ToolHandler + handlers map[string]ToolHandler } var NoToolCallError = errors.New("An assistant tool call with no tool calls was provided.") @@ -33,7 +33,7 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentA fnName := toolCall.Function.Name arguments := toolCall.Function.Arguments - fnHandler, exists := (*handler.handlers)[fnName] + fnHandler, exists := handler.handlers[fnName] if !exists { return []AgentTextMessage{}, errors.New("Could not find tool with this name.") } @@ -58,8 +58,8 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentA return responses, nil } -func (handler ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { - (*handler.handlers)[name] = ToolHandler{ +func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { + handler.handlers[name] = ToolHandler{ Fn: func(info ToolHandlerInfo, args string, call ToolCall) (string, error) { res, err := fn(info, args, call) if err != nil { diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index cb07ece..f20973b 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -16,7 +16,7 @@ type ToolTestSuite struct { func (suite *ToolTestSuite) SetupTest() { suite.handler = ToolsHandlers{ - handlers: &map[string]ToolHandler{}, + handlers: map[string]ToolHandler{}, } suite.handler.AddTool("a", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { -- 2.47.2 From f5f40080340f40584a86d9b752b4b5967f72e41a Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 5 Apr 2025 15:04:09 +0100 Subject: [PATCH 11/17] fix(tools): dont error if AI invented a tool --- 
backend/agents/client/tools.go | 18 +++++++++++------- backend/agents/client/tools_test.go | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index a994363..5bfd6d1 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -22,6 +22,8 @@ type ToolsHandlers struct { var NoToolCallError = errors.New("An assistant tool call with no tool calls was provided.") +const NonExistantTool = "This tool does not exist" + func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) { if len(toolCallMessage.ToolCalls) == 0 { return []AgentTextMessage{}, NoToolCallError @@ -33,19 +35,21 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentA fnName := toolCall.Function.Name arguments := toolCall.Function.Arguments - fnHandler, exists := handler.handlers[fnName] - if !exists { - return []AgentTextMessage{}, errors.New("Could not find tool with this name.") - } - - res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) - responseMessage := AgentTextMessage{ Role: "tool", Name: fnName, ToolCallId: toolCall.Id, } + fnHandler, exists := handler.handlers[fnName] + if !exists { + responseMessage.Content = NonExistantTool + responses[i] = responseMessage + continue + } + + res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) + if err != nil { responseMessage.Content = err.Error() } else { diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index f20973b..41d4072 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -151,6 +151,14 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { { Index: 1, Id: "2", + Function: FunctionCall{ + Name: "non-existant", + Arguments: "", + }, + }, + { + Index: 2, + Id: "3", Function: FunctionCall{ Name: "a", Arguments: 
"no-error", @@ -170,8 +178,14 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { }, { Role: "tool", - Content: "\"no-error\"", + Content: "This tool does not exist", ToolCallId: "2", + Name: "non-existant", + }, + { + Role: "tool", + Content: "\"no-error\"", + ToolCallId: "3", Name: "a", }, }) -- 2.47.2 From 28ee32e2ff6e5bc4520ff5eb2d6f7e64a9fa0c1f Mon Sep 17 00:00:00 2001 From: John Costa Date: Sun, 6 Apr 2025 20:24:40 +0100 Subject: [PATCH 12/17] fixup(chat): better way to organize agent messages and tool calls --- backend/agents/client/client.go | 50 +++++++++++++++++++------- backend/agents/client/client_test.go | 40 +++++++++++++++++++++ backend/agents/client/tools.go | 8 ++--- backend/agents/client/tools_test.go | 6 ++-- backend/agents/event_location_agent.go | 2 +- backend/agents/note_agent.go | 2 +- backend/agents/orchestrator.go | 2 +- 7 files changed, 87 insertions(+), 23 deletions(-) create mode 100644 backend/agents/client/client_test.go diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 3964aca..616d71a 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -49,14 +49,16 @@ type AgentMessage interface { MessageToJson() ([]byte, error) } -type AgentTextMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCallId string `json:"tool_call_id,omitempty"` - Name string `json:"name,omitempty"` +type AgentResponseMessage struct { + Role string `json:"role"` + Content string `json:"content"` + // Not sure I need this field. + ToolCallId string `json:"tool_call_id,omitempty"` + ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` + Name string `json:"name,omitempty"` } -func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { +func (textContent AgentResponseMessage) MessageToJson() ([]byte, error) { // TODO: Validate the `Role`. 
return json.Marshal(textContent) } @@ -91,7 +93,7 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { return json.Marshal(arrayContent) } -func (content *AgentMessages) AddText(message AgentTextMessage) { +func (content *AgentMessages) AddResponse(message ResponseChoiceMessage) { content.Messages = append(content.Messages, message) } @@ -126,7 +128,7 @@ func (content *AgentMessages) AddSystem(prompt string) error { return errors.New("You can only add a system prompt at the beginning") } - content.Messages = append(content.Messages, AgentTextMessage{ + content.Messages = append(content.Messages, AgentResponseMessage{ Role: ROLE_SYSTEM, Content: prompt, }) @@ -134,6 +136,26 @@ func (content *AgentMessages) AddSystem(prompt string) error { return nil } +// TODO: `AgentMessages` is not really a good name. +// It's a step above that, like a real chat. AgentChat or something. +func (chat *AgentMessages) HandleResponse(response AgentResponse) error { + if len(chat.Messages) == 0 { + return errors.New("This chat doesnt contain any messages therefore cannot be handled.") + } + + for _, choice := range response.Choices { + // TOOD + // if len(choice.Message.ToolCalls) > 0 { + // for _, toolCall := choice.Message.ToolCalls { + // chat.AddToolCall() + // } + // } + chat.AddResponse(choice.Message) + } + + return nil +} + type AgentContent interface { ToJson() ([]byte, error) } @@ -158,6 +180,10 @@ type ResponseChoiceMessage struct { ToolCalls []ToolCall `json:"tool_calls"` } +func (choice ResponseChoiceMessage) MessageToJson() ([]byte, error) { + return json.Marshal(choice) +} + type ResponseChoice struct { Index int `json:"index"` Message ResponseChoiceMessage `json:"message"` @@ -174,7 +200,6 @@ type AgentResponse struct { type AgentClient struct { url string apiKey string - systemPrompt string responseFormat string ToolHandler ToolsHandlers @@ -187,7 +212,7 @@ const ROLE_USER = "user" const ROLE_SYSTEM = "system" const IMAGE_TYPE = "image_url" -func 
CreateAgentClient(prompt string) (AgentClient, error) { +func CreateAgentClient() (AgentClient, error) { apiKey := os.Getenv(OPENAI_API_KEY) if len(apiKey) == 0 { @@ -195,9 +220,8 @@ func CreateAgentClient(prompt string) (AgentClient, error) { } return AgentClient{ - apiKey: apiKey, - url: "https://api.mistral.ai/v1/chat/completions", - systemPrompt: prompt, + apiKey: apiKey, + url: "https://api.mistral.ai/v1/chat/completions", Do: func(req *http.Request) (*http.Response, error) { client := &http.Client{} return client.Do(req) diff --git a/backend/agents/client/client_test.go b/backend/agents/client/client_test.go new file mode 100644 index 0000000..b6b21bf --- /dev/null +++ b/backend/agents/client/client_test.go @@ -0,0 +1,40 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSimpleResponse(t *testing.T) { + // assert := assert.New(t) + require := require.New(t) + + chat := AgentMessages{ + Messages: make([]AgentMessage, 0), + } + + chat.AddSystem("system message") + + err := chat.HandleResponse(AgentResponse{ + Id: "0", + Object: "chat.completion", + Created: 1, + Choices: []ResponseChoice{{ + Index: 0, + Message: ResponseChoiceMessage{ + Role: "assistant", + Content: "some basic content", + }, + FinishReason: "", + }}, + }) + + require.NoError(err) + require.Len(chat.Messages, 2) + + require.EqualValues(chat.Messages[1], ResponseChoiceMessage{ + Role: "assistant", + Content: "some basic content", + }) +} diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 5bfd6d1..4acaaef 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -24,18 +24,18 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was const NonExistantTool = "This tool does not exist" -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) { +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, 
toolCallMessage AgentAssistantToolCall) ([]AgentResponseMessage, error) { if len(toolCallMessage.ToolCalls) == 0 { - return []AgentTextMessage{}, NoToolCallError + return []AgentResponseMessage{}, NoToolCallError } - responses := make([]AgentTextMessage, len(toolCallMessage.ToolCalls)) + responses := make([]AgentResponseMessage, len(toolCallMessage.ToolCalls)) for i, toolCall := range toolCallMessage.ToolCalls { fnName := toolCall.Function.Name arguments := toolCall.Function.Arguments - responseMessage := AgentTextMessage{ + responseMessage := AgentResponseMessage{ Role: "tool", Name: fnName, ToolCallId: toolCall.Id, diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index 41d4072..f4d1d73 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -52,7 +52,7 @@ func (suite *ToolTestSuite) TestSingleToolCall() { require.NoError(err, "Tool call shouldnt return an error") - assert.EqualValues(response, []AgentTextMessage{{ + assert.EqualValues(response, []AgentResponseMessage{{ Role: "tool", Content: "\"return\"", ToolCallId: "1", @@ -111,7 +111,7 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() { require.NoError(err, "Tool call shouldnt return an error") - assert.EqualValues(response, []AgentTextMessage{ + assert.EqualValues(response, []AgentResponseMessage{ { Role: "tool", Content: "\"first-call\"", @@ -169,7 +169,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { require.NoError(err, "Tool call shouldnt return an error") - assert.EqualValues(response, []AgentTextMessage{ + assert.EqualValues(response, []AgentResponseMessage{ { Role: "tool", Content: "I will always error", diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index a1b60cd..dfef53a 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -184,7 +184,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId 
uuid.UUID } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { - agentClient, err := client.CreateAgentClient(eventLocationPrompt) + agentClient, err := client.CreateAgentClient() if err != nil { return EventLocationAgent{}, err } diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 3d62c94..29b4ce8 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -66,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s } func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) { - client, err := client.CreateAgentClient(noteAgentPrompt) + client, err := client.CreateAgentClient() if err != nil { return NoteAgent{}, err } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index 4fe7d3a..392d989 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -123,7 +123,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, } func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { - agent, err := client.CreateAgentClient(orchestratorPrompt) + agent, err := client.CreateAgentClient() if err != nil { return OrchestratorAgent{}, err } -- 2.47.2 From 5502fc6b19c3d8f6fa40398bd798efe7966a1863 Mon Sep 17 00:00:00 2001 From: John Costa Date: Wed, 9 Apr 2025 12:04:44 +0100 Subject: [PATCH 13/17] feat(chat): more simplified chat messages and tool handling --- backend/agents/client/chat.go | 184 +++++++++++++++++++++++ backend/agents/client/client.go | 212 ++++----------------------- backend/agents/client/client_test.go | 40 ----- backend/agents/client/tools.go | 52 +++---- 4 files changed, 234 insertions(+), 254 deletions(-) create mode 100644 backend/agents/client/chat.go delete mode 100644 
backend/agents/client/client_test.go diff --git a/backend/agents/client/chat.go b/backend/agents/client/chat.go new file mode 100644 index 0000000..6ab6fcf --- /dev/null +++ b/backend/agents/client/chat.go @@ -0,0 +1,184 @@ +package client + +import ( + "encoding/base64" + "errors" + "fmt" + "path/filepath" +) + +type Chat struct { + Messages []ChatMessage `json:"messages"` +} + +type ChatMessage interface { + IsResponse() bool +} + +// TODO: the role could be inferred from the type. +// This would solve some bugs. + +/* + +Is there a world where this actually becomes the product? +Where we build such a resilient system of AI calls that we +can build some app builder, or even just an API system, +with a fancy UI? + +Manage all the complexity for the user? + +*/ + +// ============================================= +// Messages from us to the AI. +// ============================================= + +type UserRole = string + +const ( + User UserRole = "user" + System UserRole = "system" +) + +type ToolRole = string + +const ( + Tool ToolRole = "tool" +) + +type ChatUserMessage struct { + Role UserRole `json:"role"` + + MessageContent `json:"content"` +} + +func (r ChatUserMessage) IsResponse() bool { + return false +} + +type ChatUserToolResponse struct { + Role ToolRole `json:"role"` + + // The name of the function we are responding to. + Name string `json:"name"` + Content string `json:"content"` + ToolCallId string `json:"tool_call_id"` +} + +func (r ChatUserToolResponse) IsResponse() bool { + return false +} + +type ChatAiMessage struct { + Role string `json:"role"` + ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` + + MessageContent `json:"content"` +} + +func (m ChatAiMessage) IsResponse() bool { + return true +} + +// ============================================= +// Unique interface for message content. 
+// ============================================= + +type MessageContent interface { + IsSingleMessage() bool +} + +type SingleMessage struct { + Content string `json:"content"` +} + +func (m SingleMessage) IsSingleMessage() bool { + return true +} + +type ArrayMessage struct { + Content []ImageMessageContent `json:"content"` +} + +func (m ArrayMessage) IsSingleMessage() bool { + return false +} + +type ImageMessageContent struct { + ImageType string `json:"type"` + ImageUrl string `json:"image_url"` +} + +type ImageContentUrl struct { + Url string `json:"url"` +} + +// ============================================= +// Adjacent interfaces. +// ============================================= + +type ToolCall struct { + Index int `json:"index"` + Id string `json:"id"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ============================================= +// Chat methods +// ============================================= + +func (chat *Chat) AddSystem(prompt string) { + chat.Messages = append(chat.Messages, ChatUserMessage{ + Role: System, + MessageContent: SingleMessage{ + Content: prompt, + }, + }) +} + +func (chat *Chat) AddImage(imageName string, image []byte) error { + extension := filepath.Ext(imageName) + if len(extension) == 0 { + // TODO: could also validate for image types we support. 
+ return errors.New("Image does not have extension") + } + + extension = extension[1:] + + encodedString := base64.StdEncoding.EncodeToString(image) + + messageContent := ArrayMessage{ + Content: make([]ImageMessageContent, 1), + } + + messageContent.Content[0] = ImageMessageContent{ + ImageType: "image_url", + ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), + } + + arrayMessage := ChatUserMessage{Role: User, MessageContent: messageContent} + chat.Messages = append(chat.Messages, arrayMessage) + + return nil +} + +func (chat *Chat) AddAiResponse(res ChatAiMessage) { + chat.Messages = append(chat.Messages, res) +} + +func (chat *Chat) AddToolResponse(res ChatUserToolResponse) { + chat.Messages = append(chat.Messages, res) +} + +func (chat Chat) GetLatest() (ChatMessage, error) { + if len(chat.Messages) == 0 { + return nil, errors.New("Not enough messages") + } + + return chat.Messages[len(chat.Messages)-1], nil +} diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 616d71a..be9590e 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -2,7 +2,6 @@ package client import ( "bytes" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -10,19 +9,8 @@ import ( "log" "net/http" "os" - "path/filepath" - "screenmark/screenmark/.gen/haystack/haystack/model" ) -type ImageInfo struct { - Tags []string `json:"tags"` - Text []string `json:"text"` - Links []string `json:"links"` - - Locations []model.Locations `json:"locations"` - Events []model.Events `json:"events"` -} - type ResponseFormat struct { Type string `json:"type"` JsonSchema any `json:"json_schema"` @@ -36,158 +24,13 @@ type AgentRequestBody struct { Tools *any `json:"tools,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"` - // ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` - - AgentMessages -} - -type AgentMessages struct { - Messages []AgentMessage `json:"messages"` -} - -type AgentMessage 
interface { - MessageToJson() ([]byte, error) -} - -type AgentResponseMessage struct { - Role string `json:"role"` - Content string `json:"content"` - // Not sure I need this field. - ToolCallId string `json:"tool_call_id,omitempty"` - ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` - Name string `json:"name,omitempty"` -} - -func (textContent AgentResponseMessage) MessageToJson() ([]byte, error) { - // TODO: Validate the `Role`. - return json.Marshal(textContent) -} - -type ToolCall struct { - Index int `json:"index"` - Id string `json:"id"` - Function FunctionCall `json:"function"` -} - -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type AgentAssistantToolCall struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) { - return json.Marshal(toolCall) -} - -type AgentArrayMessage struct { - Role string `json:"role"` - Content []AgentContent `json:"content"` -} - -func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { - return json.Marshal(arrayContent) -} - -func (content *AgentMessages) AddResponse(message ResponseChoiceMessage) { - content.Messages = append(content.Messages, message) -} - -func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) { - content.Messages = append(content.Messages, toolCall) -} - -func (content *AgentMessages) AddImage(imageName string, image []byte) error { - extension := filepath.Ext(imageName) - if len(extension) == 0 { - // TODO: could also validate for image types we support. 
- return errors.New("Image does not have extension") - } - - extension = extension[1:] - - encodedString := base64.StdEncoding.EncodeToString(image) - - arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)} - arrayMessage.Content[0] = AgentImage{ - ImageType: IMAGE_TYPE, - ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), - } - - content.Messages = append(content.Messages, arrayMessage) - - return nil -} - -func (content *AgentMessages) AddSystem(prompt string) error { - if len(content.Messages) != 0 { - return errors.New("You can only add a system prompt at the beginning") - } - - content.Messages = append(content.Messages, AgentResponseMessage{ - Role: ROLE_SYSTEM, - Content: prompt, - }) - - return nil -} - -// TODO: `AgentMessages` is not really a good name. -// It's a step above that, like a real chat. AgentChat or something. -func (chat *AgentMessages) HandleResponse(response AgentResponse) error { - if len(chat.Messages) == 0 { - return errors.New("This chat doesnt contain any messages therefore cannot be handled.") - } - - for _, choice := range response.Choices { - // TOOD - // if len(choice.Message.ToolCalls) > 0 { - // for _, toolCall := choice.Message.ToolCalls { - // chat.AddToolCall() - // } - // } - chat.AddResponse(choice.Message) - } - - return nil -} - -type AgentContent interface { - ToJson() ([]byte, error) -} - -type ImageUrl struct { - Url string `json:"url"` -} - -type AgentImage struct { - ImageType string `json:"type"` - ImageUrl string `json:"image_url"` -} - -func (imageMessage AgentImage) ToJson() ([]byte, error) { - imageMessage.ImageType = IMAGE_TYPE - return json.Marshal(imageMessage) -} - -type ResponseChoiceMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -func (choice ResponseChoiceMessage) MessageToJson() ([]byte, error) { - return json.Marshal(choice) + Chat Chat `json:"messages"` } type 
ResponseChoice struct { - Index int `json:"index"` - Message ResponseChoiceMessage `json:"message"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message ChatAiMessage `json:"message"` + FinishReason string `json:"finish_reason"` } type AgentResponse struct { @@ -208,9 +51,6 @@ type AgentClient struct { } const OPENAI_API_KEY = "OPENAI_API_KEY" -const ROLE_USER = "user" -const ROLE_SYSTEM = "system" -const IMAGE_TYPE = "image_url" func CreateAgentClient() (AgentClient, error) { apiKey := os.Getenv(OPENAI_API_KEY) @@ -245,8 +85,8 @@ func (client AgentClient) getRequest(body []byte) (*http.Request, error) { return req, nil } -func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) { - jsonAiRequest, err := json.Marshal(request) +func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { + jsonAiRequest, err := json.Marshal(chat) if err != nil { return AgentResponse{}, err } @@ -273,34 +113,42 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err return AgentResponse{}, err } - toolCalls := agentResponse.Choices[0].Message.ToolCalls - if len(toolCalls) > 0 { - // Should for sure be more flexible. - request.AddToolCall(AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: toolCalls, - }) + if len(agentResponse.Choices) != 1 { + return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.") } + chat.AddAiResponse(agentResponse.Choices[0].Message) + return agentResponse, nil } -func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error { +func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error { var err error - for { - toolCall, ok := request.Messages[len(request.Messages)-1].(AgentAssistantToolCall) - if !ok { - return errors.New("Latest message isnt a tool call. 
TODO") - } + message, err := chat.GetLatest() + if err != nil { + return err + } - _, err = client.ToolHandler.Handle(info, toolCall) + aiMessage, ok := message.(ChatAiMessage) + if !ok { + return errors.New("Latest message isnt an AI message") + } + + if aiMessage.ToolCalls == nil { + // Not an error, we just dont have any tool calls to process. + return nil + } + + for _, toolCall := range *aiMessage.ToolCalls { + toolResponse, err := client.ToolHandler.Handle(info, toolCall) if err != nil { break } - _, err = client.Request(&request) + chat.AddToolResponse(toolResponse) + + _, err = client.Request(chat) if err != nil { break } diff --git a/backend/agents/client/client_test.go b/backend/agents/client/client_test.go deleted file mode 100644 index b6b21bf..0000000 --- a/backend/agents/client/client_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package client - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSimpleResponse(t *testing.T) { - // assert := assert.New(t) - require := require.New(t) - - chat := AgentMessages{ - Messages: make([]AgentMessage, 0), - } - - chat.AddSystem("system message") - - err := chat.HandleResponse(AgentResponse{ - Id: "0", - Object: "chat.completion", - Created: 1, - Choices: []ResponseChoice{{ - Index: 0, - Message: ResponseChoiceMessage{ - Role: "assistant", - Content: "some basic content", - }, - FinishReason: "", - }}, - }) - - require.NoError(err) - require.Len(chat.Messages, 2) - - require.EqualValues(chat.Messages[1], ResponseChoiceMessage{ - Role: "assistant", - Content: "some basic content", - }) -} diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 4acaaef..4a54958 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -24,42 +24,30 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was const NonExistantTool = "This tool does not exist" -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage 
AgentAssistantToolCall) ([]AgentResponseMessage, error) { - if len(toolCallMessage.ToolCalls) == 0 { - return []AgentResponseMessage{}, NoToolCallError +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) (ChatUserToolResponse, error) { + fnName := toolCallMessage.Function.Name + arguments := toolCallMessage.Function.Arguments + + responseMessage := ChatUserToolResponse{ + Role: "tool", + Name: fnName, + ToolCallId: toolCallMessage.Id, } - responses := make([]AgentResponseMessage, len(toolCallMessage.ToolCalls)) - - for i, toolCall := range toolCallMessage.ToolCalls { - fnName := toolCall.Function.Name - arguments := toolCall.Function.Arguments - - responseMessage := AgentResponseMessage{ - Role: "tool", - Name: fnName, - ToolCallId: toolCall.Id, - } - - fnHandler, exists := handler.handlers[fnName] - if !exists { - responseMessage.Content = NonExistantTool - responses[i] = responseMessage - continue - } - - res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) - - if err != nil { - responseMessage.Content = err.Error() - } else { - responseMessage.Content = res - } - - responses[i] = responseMessage + fnHandler, exists := handler.handlers[fnName] + if !exists { + return ChatUserToolResponse{}, errors.New(NonExistantTool) } - return responses, nil + res, err := fnHandler.Fn(info, arguments, toolCallMessage) + + if err != nil { + responseMessage.Content = err.Error() + } else { + responseMessage.Content = res + } + + return responseMessage, nil } func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { -- 2.47.2 From 88fda321258a88bd39b51b3524897e165ca3dd1d Mon Sep 17 00:00:00 2001 From: John Costa Date: Wed, 9 Apr 2025 12:12:09 +0100 Subject: [PATCH 14/17] fix(types): agent processing stuff --- backend/agents/client/chat.go | 2 +- backend/agents/client/client.go | 1 - backend/agents/event_location_agent.go | 15 +++++++-------- 
backend/agents/note_agent.go | 12 ++++++------ backend/agents/orchestrator.go | 15 ++++++++------- backend/main.go | 4 ++-- 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/backend/agents/client/chat.go b/backend/agents/client/chat.go index 6ab6fcf..2adf08f 100644 --- a/backend/agents/client/chat.go +++ b/backend/agents/client/chat.go @@ -73,7 +73,7 @@ type ChatAiMessage struct { Role string `json:"role"` ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` - MessageContent `json:"content"` + Content string `json:"content"` } func (m ChatAiMessage) IsResponse() bool { diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index be9590e..db2d643 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "errors" - "fmt" "io" "log" "net/http" diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index dfef53a..88bbe94 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -161,16 +161,15 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID ResponseFormat: client.ResponseFormat{ Type: "text", }, + Chat: client.Chat{ + Messages: make([]client.ChatMessage, 0), + }, } - err = request.AddSystem(eventLocationPrompt) - if err != nil { - return err - } + request.Chat.AddSystem(eventLocationPrompt) + request.Chat.AddImage(imageName, imageData) - request.AddImage(imageName, imageData) - - _, err = agent.client.Request(&request) + _, err = agent.client.Request(&request.Chat) if err != nil { return err } @@ -180,7 +179,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID UserId: userId, } - return agent.client.Process(toolHandlerInfo, request) + return agent.client.Process(toolHandlerInfo, &request.Chat) } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel 
models.ContactModel) (EventLocationAgent, error) { diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 29b4ce8..02436d0 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -32,15 +32,15 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s ResponseFormat: client.ResponseFormat{ Type: "text", }, + Chat: client.Chat{ + Messages: make([]client.ChatMessage, 0), + }, } - err := request.AddSystem(noteAgentPrompt) - if err != nil { - return err - } + request.Chat.AddSystem(noteAgentPrompt) + request.Chat.AddImage(imageName, imageData) - request.AddImage(imageName, imageData) - resp, err := agent.client.Request(&request) + resp, err := agent.client.Request(&request.Chat) if err != nil { return err } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index 392d989..67ea5ff 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -101,15 +101,16 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, }, ToolChoice: &toolChoice, Tools: &tools, + + Chat: client.Chat{ + Messages: make([]client.ChatMessage, 0), + }, } - err = request.AddSystem(orchestratorPrompt) - if err != nil { - return err - } + request.Chat.AddSystem(orchestratorPrompt) + request.Chat.AddImage(imageName, imageData) - request.AddImage(imageName, imageData) - _, err = agent.client.Request(&request) + _, err = agent.client.Request(&request.Chat) if err != nil { return err } @@ -119,7 +120,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, UserId: userId, } - return agent.client.Process(toolHandlerInfo, request) + return agent.client.Process(toolHandlerInfo, &request.Chat) } func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { diff --git a/backend/main.go b/backend/main.go index 205815f..aa180d9 100644 --- 
a/backend/main.go +++ b/backend/main.go @@ -25,10 +25,10 @@ import ( ) type TestAiClient struct { - ImageInfo client.ImageInfo + ImageInfo client.ImageMessageContent } -func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (client.ImageInfo, error) { +func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (client.ImageMessageContent, error) { return client.ImageInfo, nil } -- 2.47.2 From f294f9cdc07680a80843898133dc376b2e21bd59 Mon Sep 17 00:00:00 2001 From: John Costa Date: Wed, 9 Apr 2025 13:56:03 +0100 Subject: [PATCH 15/17] fix(tools): testing and processing fix --- backend/agents/client/chat.go | 26 +++++- backend/agents/client/chat_test.go | 24 ++++++ backend/agents/client/client.go | 31 +++----- backend/agents/client/tools.go | 7 +- backend/agents/client/tools_test.go | 106 ++++++++++++------------- backend/agents/event_location_agent.go | 6 +- backend/agents/note_agent.go | 4 +- backend/agents/orchestrator.go | 9 ++- backend/main.go | 5 +- 9 files changed, 128 insertions(+), 90 deletions(-) create mode 100644 backend/agents/client/chat_test.go diff --git a/backend/agents/client/chat.go b/backend/agents/client/chat.go index 2adf08f..6628e56 100644 --- a/backend/agents/client/chat.go +++ b/backend/agents/client/chat.go @@ -2,6 +2,7 @@ package client import ( "encoding/base64" + "encoding/json" "errors" "fmt" "path/filepath" @@ -49,7 +50,30 @@ const ( type ChatUserMessage struct { Role UserRole `json:"role"` - MessageContent `json:"content"` + MessageContent `json:"MessageContent"` +} + +func (m ChatUserMessage) MarshalJSON() ([]byte, error) { + switch t := m.MessageContent.(type) { + case SingleMessage: + return json.Marshal(&struct { + Role UserRole `json:"role"` + Content string `json:"content"` + }{ + Role: User, + Content: t.Content, + }) + case ArrayMessage: + return json.Marshal(&struct { + Role UserRole `json:"role"` + Content []ImageMessageContent `json:"content"` + }{ + Role: User, + Content: t.Content, + }) 
+ } + + return []byte{}, errors.New("Unreachable") } func (r ChatUserMessage) IsResponse() bool { diff --git a/backend/agents/client/chat_test.go b/backend/agents/client/chat_test.go new file mode 100644 index 0000000..09ca5ba --- /dev/null +++ b/backend/agents/client/chat_test.go @@ -0,0 +1,24 @@ +package client + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFlatMarshallSingleMessage(t *testing.T) { + require := require.New(t) + + message := ChatUserMessage{ + Role: User, + MessageContent: SingleMessage{ + Content: "Hello", + }, + } + + json, err := json.Marshal(message) + require.NoError(err) + + require.Equal(string(json), "{\"role\":\"user\",\"content\":\"Hello\"}") +} diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index db2d643..d1eff3b 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -4,8 +4,8 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" - "log" "net/http" "os" ) @@ -23,7 +23,7 @@ type AgentRequestBody struct { Tools *any `json:"tools,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"` - Chat Chat `json:"messages"` + Chat *Chat `json:"messages"` } type ResponseChoice struct { @@ -84,8 +84,8 @@ func (client AgentClient) getRequest(body []byte) (*http.Request, error) { return req, nil } -func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { - jsonAiRequest, err := json.Marshal(chat) +func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) { + jsonAiRequest, err := json.Marshal(req) if err != nil { return AgentResponse{}, err } @@ -105,6 +105,8 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { return AgentResponse{}, err } + fmt.Println(string(response)) + agentResponse := AgentResponse{} err = json.Unmarshal(response, &agentResponse) @@ -116,15 +118,15 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { return AgentResponse{}, 
errors.New("Unsupported. We currently only accept 1 choice from AI.") } - chat.AddAiResponse(agentResponse.Choices[0].Message) + req.Chat.AddAiResponse(agentResponse.Choices[0].Message) return agentResponse, nil } -func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error { +func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error { var err error - message, err := chat.GetLatest() + message, err := req.Chat.GetLatest() if err != nil { return err } @@ -140,21 +142,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error { } for _, toolCall := range *aiMessage.ToolCalls { - toolResponse, err := client.ToolHandler.Handle(info, toolCall) - if err != nil { - break - } + toolResponse := client.ToolHandler.Handle(info, toolCall) - chat.AddToolResponse(toolResponse) - - _, err = client.Request(chat) - if err != nil { - break - } + req.Chat.AddToolResponse(toolResponse) } - if err != nil { - log.Println(err) - } return err } diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 4a54958..8dc9244 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -24,7 +24,7 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was const NonExistantTool = "This tool does not exist" -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) (ChatUserToolResponse, error) { +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) ChatUserToolResponse { fnName := toolCallMessage.Function.Name arguments := toolCallMessage.Function.Arguments @@ -36,7 +36,8 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa fnHandler, exists := handler.handlers[fnName] if !exists { - return ChatUserToolResponse{}, errors.New(NonExistantTool) + responseMessage.Content = NonExistantTool + return responseMessage } res, err := fnHandler.Fn(info, arguments, toolCallMessage) @@ -47,7 
+48,7 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa responseMessage.Content = res } - return responseMessage, nil + return responseMessage } func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index f4d1d73..912d94b 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -12,6 +12,7 @@ type ToolTestSuite struct { suite.Suite handler ToolsHandlers + client AgentClient } func (suite *ToolTestSuite) SetupTest() { @@ -26,70 +27,44 @@ func (suite *ToolTestSuite) SetupTest() { suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { return false, errors.New("I will always error") }) + + suite.client.ToolHandler = suite.handler } func (suite *ToolTestSuite) TestSingleToolCall() { - assert := suite.Assert() require := suite.Require() - response, err := suite.handler.Handle( + response := suite.handler.Handle( ToolHandlerInfo{ UserId: uuid.Nil, ImageId: uuid.Nil, }, - AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{{ - Index: 0, - Id: "1", - Function: FunctionCall{ - Name: "a", - Arguments: "return", - }, - }}, + ToolCall{ + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "a", + Arguments: "return", + }, }) - require.NoError(err, "Tool call shouldnt return an error") - - assert.EqualValues(response, []AgentResponseMessage{{ + require.EqualValues(response, ChatUserToolResponse{ Role: "tool", Content: "\"return\"", ToolCallId: "1", Name: "a", - }}) -} - -func (suite *ToolTestSuite) TestEmptyCall() { - require := suite.Require() - - _, err := suite.handler.Handle( - ToolHandlerInfo{ - UserId: uuid.Nil, - ImageId: uuid.Nil, - }, - AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{}, - }) - - require.ErrorIs(err, 
NoToolCallError) + }) } func (suite *ToolTestSuite) TestMultipleToolCalls() { assert := suite.Assert() require := suite.Require() - response, err := suite.handler.Handle( - ToolHandlerInfo{ - UserId: uuid.Nil, - ImageId: uuid.Nil, - }, - AgentAssistantToolCall{ + chat := Chat{ + Messages: []ChatMessage{ChatAiMessage{ Role: "assistant", Content: "", - ToolCalls: []ToolCall{ + ToolCalls: &[]ToolCall{ { Index: 0, Id: "1", @@ -107,18 +82,27 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() { }, }, }, + }}, + } + + err := suite.client.Process( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + &AgentRequestBody{ + Chat: &chat, }) require.NoError(err, "Tool call shouldnt return an error") - - assert.EqualValues(response, []AgentResponseMessage{ - { + assert.EqualValues(chat.Messages[1:], []ChatMessage{ + ChatUserToolResponse{ Role: "tool", Content: "\"first-call\"", ToolCallId: "1", Name: "a", }, - { + ChatUserToolResponse{ Role: "tool", Content: "\"second-call\"", ToolCallId: "2", @@ -131,15 +115,11 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { assert := suite.Assert() require := suite.Require() - response, err := suite.handler.Handle( - ToolHandlerInfo{ - UserId: uuid.Nil, - ImageId: uuid.Nil, - }, - AgentAssistantToolCall{ + chat := Chat{ + Messages: []ChatMessage{ChatAiMessage{ Role: "assistant", Content: "", - ToolCalls: []ToolCall{ + ToolCalls: &[]ToolCall{ { Index: 0, Id: "1", @@ -165,24 +145,34 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { }, }, }, + }}, + } + + err := suite.client.Process( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + &AgentRequestBody{ + Chat: &chat, }) require.NoError(err, "Tool call shouldnt return an error") - assert.EqualValues(response, []AgentResponseMessage{ - { + assert.EqualValues(chat.Messages[1:], []ChatMessage{ + ChatUserToolResponse{ Role: "tool", Content: "I will always error", ToolCallId: "1", Name: "error", }, - { + ChatUserToolResponse{ Role: 
"tool", Content: "This tool does not exist", ToolCallId: "2", Name: "non-existant", }, - { + ChatUserToolResponse{ Role: "tool", Content: "\"no-error\"", ToolCallId: "3", @@ -192,5 +182,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { } func TestToolSuite(t *testing.T) { - suite.Run(t, &ToolTestSuite{}) + suite.Run(t, &ToolTestSuite{ + client: AgentClient{}, + }) } diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 88bbe94..53bc017 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -161,7 +161,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID ResponseFormat: client.ResponseFormat{ Type: "text", }, - Chat: client.Chat{ + Chat: &client.Chat{ Messages: make([]client.ChatMessage, 0), }, } @@ -169,7 +169,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID request.Chat.AddSystem(eventLocationPrompt) request.Chat.AddImage(imageName, imageData) - _, err = agent.client.Request(&request.Chat) + _, err = agent.client.Request(&request) if err != nil { return err } @@ -179,7 +179,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID UserId: userId, } - return agent.client.Process(toolHandlerInfo, &request.Chat) + return agent.client.Process(toolHandlerInfo, &request) } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 02436d0..59b2fb0 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -32,7 +32,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s ResponseFormat: client.ResponseFormat{ Type: "text", }, - Chat: client.Chat{ + Chat: &client.Chat{ Messages: make([]client.ChatMessage, 0), }, } @@ -40,7 +40,7 @@ func (agent 
NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s request.Chat.AddSystem(noteAgentPrompt) request.Chat.AddImage(imageName, imageData) - resp, err := agent.client.Request(&request.Chat) + resp, err := agent.client.Request(&request) if err != nil { return err } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index 67ea5ff..a7577a7 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -3,6 +3,7 @@ package agents import ( "encoding/json" "errors" + "fmt" "screenmark/screenmark/agents/client" "github.com/google/uuid" @@ -102,7 +103,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, ToolChoice: &toolChoice, Tools: &tools, - Chat: client.Chat{ + Chat: &client.Chat{ Messages: make([]client.ChatMessage, 0), }, } @@ -110,17 +111,19 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, request.Chat.AddSystem(orchestratorPrompt) request.Chat.AddImage(imageName, imageData) - _, err = agent.client.Request(&request.Chat) + res, err := agent.client.Request(&request) if err != nil { return err } + fmt.Println(res) + toolHandlerInfo := client.ToolHandlerInfo{ ImageId: imageId, UserId: userId, } - return agent.client.Process(toolHandlerInfo, &request.Chat) + return agent.client.Process(toolHandlerInfo, &request) } func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { diff --git a/backend/main.go b/backend/main.go index aa180d9..1f395e0 100644 --- a/backend/main.go +++ b/backend/main.go @@ -104,7 +104,10 @@ func main() { panic(err) } - orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) + err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) + if err != nil { + fmt.Println(err) + } }() } } -- 2.47.2 From c35951063a55bef63b2c65448ca29e4e75a1928e Mon Sep 17 
00:00:00 2001 From: John Costa Date: Wed, 9 Apr 2025 15:15:31 +0100 Subject: [PATCH 16/17] fix(tool-calls): ToolLoop --- backend/agents/client/client.go | 43 ++++++++++++++++++++++++++ backend/agents/event_location_agent.go | 3 +- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index d1eff3b..18af351 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -23,9 +23,32 @@ type AgentRequestBody struct { Tools *any `json:"tools,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"` + EndToolCall string `json:"-"` + Chat *Chat `json:"messages"` } +func (req AgentRequestBody) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + Model string `json:"model"` + Temperature float64 `json:"temperature"` + ResponseFormat ResponseFormat `json:"response_format"` + + Tools *any `json:"tools,omitempty"` + ToolChoice *string `json:"tool_choice,omitempty"` + Messages []ChatMessage `json:"messages"` + }{ + Model: req.Model, + Temperature: req.Temperature, + ResponseFormat: req.ResponseFormat, + + Tools: req.Tools, + ToolChoice: req.ToolChoice, + + Messages: req.Chat.Messages, + }) +} + type ResponseChoice struct { Index int `json:"index"` Message ChatAiMessage `json:"message"` @@ -123,6 +146,22 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) return agentResponse, nil } +func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error { + for { + err := client.Process(info, req) + if err != nil { + return err + } + + _, err = client.Request(req) + if err != nil { + return err + } + } +} + +var FinishedCall = errors.New("Last tool tool was called") + func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error { var err error @@ -142,6 +181,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e } for _, toolCall := range *aiMessage.ToolCalls { + if 
toolCall.Function.Name == req.EndToolCall { + return FinishedCall + } + toolResponse := client.ToolHandler.Handle(info, toolCall) req.Chat.AddToolResponse(toolResponse) diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 53bc017..22672f6 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -158,6 +158,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID ToolChoice: &toolChoice, Model: "pixtral-12b-2409", Temperature: 0.3, + EndToolCall: "finish", ResponseFormat: client.ResponseFormat{ Type: "text", }, @@ -179,7 +180,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID UserId: userId, } - return agent.client.Process(toolHandlerInfo, &request) + return agent.client.ToolLoop(toolHandlerInfo, &request) } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { -- 2.47.2 From 1a9b707533fb4f79b867688fc54901cb50b6878f Mon Sep 17 00:00:00 2001 From: John Costa Date: Wed, 9 Apr 2025 15:23:51 +0100 Subject: [PATCH 17/17] feat(orchestrator): async processing and ending the loop3 --- backend/agents/orchestrator.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index a7577a7..430699c 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -103,6 +103,8 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, ToolChoice: &toolChoice, Tools: &tools, + EndToolCall: "defaultAgent", + Chat: &client.Chat{ Messages: make([]client.ChatMessage, 0), }, @@ -123,7 +125,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, UserId: userId, } - return agent.client.Process(toolHandlerInfo, &request) + return agent.client.ToolLoop(toolHandlerInfo, &request) } func 
NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { @@ -136,7 +138,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA // We need a way to keep track of this async? // Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish. - eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData) + go eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData) return Status{ Ok: true, @@ -144,7 +146,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA }) agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { - noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData) + go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData) return Status{ Ok: true, -- 2.47.2