From f169fd2ba2d04703fcb5c877724689fcc8812274 Mon Sep 17 00:00:00 2001 From: John Costa Date: Wed, 9 Apr 2025 13:56:03 +0100 Subject: [PATCH] fix(tools): testing and processing fix --- backend/agents/client/chat.go | 26 +++++- backend/agents/client/chat_test.go | 24 ++++++ backend/agents/client/client.go | 31 +++----- backend/agents/client/tools.go | 7 +- backend/agents/client/tools_test.go | 106 ++++++++++++------------- backend/agents/event_location_agent.go | 6 +- backend/agents/note_agent.go | 4 +- backend/agents/orchestrator.go | 9 ++- backend/main.go | 5 +- 9 files changed, 128 insertions(+), 90 deletions(-) create mode 100644 backend/agents/client/chat_test.go diff --git a/backend/agents/client/chat.go b/backend/agents/client/chat.go index 2adf08f..6628e56 100644 --- a/backend/agents/client/chat.go +++ b/backend/agents/client/chat.go @@ -2,6 +2,7 @@ package client import ( "encoding/base64" + "encoding/json" "errors" "fmt" "path/filepath" @@ -49,7 +50,30 @@ const ( type ChatUserMessage struct { Role UserRole `json:"role"` - MessageContent `json:"content"` + MessageContent `json:"MessageContent"` +} + +func (m ChatUserMessage) MarshalJSON() ([]byte, error) { + switch t := m.MessageContent.(type) { + case SingleMessage: + return json.Marshal(&struct { + Role UserRole `json:"role"` + Content string `json:"content"` + }{ + Role: User, + Content: t.Content, + }) + case ArrayMessage: + return json.Marshal(&struct { + Role UserRole `json:"role"` + Content []ImageMessageContent `json:"content"` + }{ + Role: User, + Content: t.Content, + }) + } + + return []byte{}, errors.New("Unreachable") } func (r ChatUserMessage) IsResponse() bool { diff --git a/backend/agents/client/chat_test.go b/backend/agents/client/chat_test.go new file mode 100644 index 0000000..09ca5ba --- /dev/null +++ b/backend/agents/client/chat_test.go @@ -0,0 +1,24 @@ +package client + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFlatMarshallSingleMessage(t *testing.T) { + require := require.New(t) + + message := ChatUserMessage{ + Role: User, + MessageContent: SingleMessage{ + Content: "Hello", + }, + } + + json, err := json.Marshal(message) + require.NoError(err) + + require.Equal(string(json), "{\"role\":\"user\",\"content\":\"Hello\"}") +} diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index db2d643..d1eff3b 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -4,8 +4,8 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" - "log" "net/http" "os" ) @@ -23,7 +23,7 @@ type AgentRequestBody struct { Tools *any `json:"tools,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"` - Chat Chat `json:"messages"` + Chat *Chat `json:"messages"` } type ResponseChoice struct { @@ -84,8 +84,8 @@ func (client AgentClient) getRequest(body []byte) (*http.Request, error) { return req, nil } -func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { - jsonAiRequest, err := json.Marshal(chat) +func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) { + jsonAiRequest, err := json.Marshal(req) if err != nil { return AgentResponse{}, err } @@ -105,6 +105,8 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { return AgentResponse{}, err } + fmt.Println(string(response)) + agentResponse := AgentResponse{} err = json.Unmarshal(response, &agentResponse) @@ -116,15 +118,15 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.") } - chat.AddAiResponse(agentResponse.Choices[0].Message) + req.Chat.AddAiResponse(agentResponse.Choices[0].Message) return agentResponse, nil } -func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error { +func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error { var err error - message, err := chat.GetLatest() + message, err := req.Chat.GetLatest() if err != nil { return err } @@ -140,21 +142,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error { } for _, toolCall := range *aiMessage.ToolCalls { - toolResponse, err := client.ToolHandler.Handle(info, toolCall) - if err != nil { - break - } + toolResponse := client.ToolHandler.Handle(info, toolCall) - chat.AddToolResponse(toolResponse) - - _, err = client.Request(chat) - if err != nil { - break - } + req.Chat.AddToolResponse(toolResponse) } - if err != nil { - log.Println(err) - } return err } diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 4a54958..8dc9244 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -24,7 +24,7 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was const NonExistantTool = "This tool does not exist" -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) (ChatUserToolResponse, error) { +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) ChatUserToolResponse { fnName := toolCallMessage.Function.Name arguments := toolCallMessage.Function.Arguments @@ -36,7 +36,8 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa fnHandler, exists := handler.handlers[fnName] if !exists { - return ChatUserToolResponse{}, errors.New(NonExistantTool) + responseMessage.Content = NonExistantTool + return responseMessage } res, err := fnHandler.Fn(info, arguments, toolCallMessage) @@ -47,7 +48,7 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa responseMessage.Content = res } - return responseMessage, nil + return responseMessage } func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index f4d1d73..912d94b 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -12,6 +12,7 @@ type ToolTestSuite struct { suite.Suite handler ToolsHandlers + client AgentClient } func (suite *ToolTestSuite) SetupTest() { @@ -26,70 +27,44 @@ func (suite *ToolTestSuite) SetupTest() { suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { return false, errors.New("I will always error") }) + + suite.client.ToolHandler = suite.handler } func (suite *ToolTestSuite) TestSingleToolCall() { - assert := suite.Assert() require := suite.Require() - response, err := suite.handler.Handle( + response := suite.handler.Handle( ToolHandlerInfo{ UserId: uuid.Nil, ImageId: uuid.Nil, }, - AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{{ - Index: 0, - Id: "1", - Function: FunctionCall{ - Name: "a", - Arguments: "return", - }, - }}, + ToolCall{ + Index: 0, + Id: "1", + Function: FunctionCall{ + Name: "a", + Arguments: "return", + }, }) - require.NoError(err, "Tool call shouldnt return an error") - - assert.EqualValues(response, []AgentResponseMessage{{ + require.EqualValues(response, ChatUserToolResponse{ Role: "tool", Content: "\"return\"", ToolCallId: "1", Name: "a", - }}) -} - -func (suite *ToolTestSuite) TestEmptyCall() { - require := suite.Require() - - _, err := suite.handler.Handle( - ToolHandlerInfo{ - UserId: uuid.Nil, - ImageId: uuid.Nil, - }, - AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{}, - }) - - require.ErrorIs(err, NoToolCallError) + }) } func (suite *ToolTestSuite) TestMultipleToolCalls() { assert := suite.Assert() require := suite.Require() - response, err := suite.handler.Handle( - ToolHandlerInfo{ - UserId: uuid.Nil, - ImageId: uuid.Nil, - }, - AgentAssistantToolCall{ + chat := Chat{ + Messages: []ChatMessage{ChatAiMessage{ Role: "assistant", Content: "", - ToolCalls: []ToolCall{ + ToolCalls: &[]ToolCall{ { Index: 0, Id: "1", @@ -107,18 +82,27 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() { }, }, }, + }}, + } + + err := suite.client.Process( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + &AgentRequestBody{ + Chat: &chat, }) require.NoError(err, "Tool call shouldnt return an error") - - assert.EqualValues(response, []AgentResponseMessage{ - { + assert.EqualValues(chat.Messages[1:], []ChatMessage{ + ChatUserToolResponse{ Role: "tool", Content: "\"first-call\"", ToolCallId: "1", Name: "a", }, - { + ChatUserToolResponse{ Role: "tool", Content: "\"second-call\"", ToolCallId: "2", @@ -131,15 +115,11 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { assert := suite.Assert() require := suite.Require() - response, err := suite.handler.Handle( - ToolHandlerInfo{ - UserId: uuid.Nil, - ImageId: uuid.Nil, - }, - AgentAssistantToolCall{ + chat := Chat{ + Messages: []ChatMessage{ChatAiMessage{ Role: "assistant", Content: "", - ToolCalls: []ToolCall{ + ToolCalls: &[]ToolCall{ { Index: 0, Id: "1", @@ -165,24 +145,34 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { }, }, }, + }}, + } + + err := suite.client.Process( + ToolHandlerInfo{ + UserId: uuid.Nil, + ImageId: uuid.Nil, + }, + &AgentRequestBody{ + Chat: &chat, }) require.NoError(err, "Tool call shouldnt return an error") - assert.EqualValues(response, []AgentResponseMessage{ - { + assert.EqualValues(chat.Messages[1:], []ChatMessage{ + ChatUserToolResponse{ Role: "tool", Content: "I will always error", ToolCallId: "1", Name: "error", }, - { + ChatUserToolResponse{ Role: "tool", Content: "This tool does not exist", ToolCallId: "2", Name: "non-existant", }, - { + ChatUserToolResponse{ Role: "tool", Content: "\"no-error\"", ToolCallId: "3", @@ -192,5 +182,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { } func TestToolSuite(t *testing.T) { - suite.Run(t, &ToolTestSuite{}) + suite.Run(t, &ToolTestSuite{ + client: AgentClient{}, + }) } diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 88bbe94..53bc017 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -161,7 +161,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID ResponseFormat: client.ResponseFormat{ Type: "text", }, - Chat: client.Chat{ + Chat: &client.Chat{ Messages: make([]client.ChatMessage, 0), }, } @@ -169,7 +169,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID request.Chat.AddSystem(eventLocationPrompt) request.Chat.AddImage(imageName, imageData) - _, err = agent.client.Request(&request.Chat) + _, err = agent.client.Request(&request) if err != nil { return err } @@ -179,7 +179,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID UserId: userId, } - return agent.client.Process(toolHandlerInfo, &request.Chat) + return agent.client.Process(toolHandlerInfo, &request) } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 02436d0..59b2fb0 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -32,7 +32,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s ResponseFormat: client.ResponseFormat{ Type: "text", }, - Chat: client.Chat{ + Chat: &client.Chat{ Messages: make([]client.ChatMessage, 0), }, } @@ -40,7 +40,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s request.Chat.AddSystem(noteAgentPrompt) request.Chat.AddImage(imageName, imageData) - resp, err := agent.client.Request(&request.Chat) + resp, err := agent.client.Request(&request) if err != nil { return err } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index 67ea5ff..a7577a7 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -3,6 +3,7 @@ package agents import ( "encoding/json" "errors" + "fmt" "screenmark/screenmark/agents/client" "github.com/google/uuid" @@ -102,7 +103,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, ToolChoice: &toolChoice, Tools: &tools, - Chat: client.Chat{ + Chat: &client.Chat{ Messages: make([]client.ChatMessage, 0), }, } @@ -110,17 +111,19 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, request.Chat.AddSystem(orchestratorPrompt) request.Chat.AddImage(imageName, imageData) - _, err = agent.client.Request(&request.Chat) + res, err := agent.client.Request(&request) if err != nil { return err } + fmt.Println(res) + toolHandlerInfo := client.ToolHandlerInfo{ ImageId: imageId, UserId: userId, } - return agent.client.Process(toolHandlerInfo, &request.Chat) + return agent.client.Process(toolHandlerInfo, &request) } func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { diff --git a/backend/main.go b/backend/main.go index aa180d9..1f395e0 100644 --- a/backend/main.go +++ b/backend/main.go @@ -104,7 +104,10 @@ func main() { panic(err) } - orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) + err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) + if err != nil { + fmt.Println(err) + } }() } }