From 28ee32e2ff6e5bc4520ff5eb2d6f7e64a9fa0c1f Mon Sep 17 00:00:00 2001 From: John Costa Date: Sun, 6 Apr 2025 20:24:40 +0100 Subject: [PATCH] fixup(chat): better way to organize agent messages and tool calls --- backend/agents/client/client.go | 50 +++++++++++++++++++------- backend/agents/client/client_test.go | 40 +++++++++++++++++++++ backend/agents/client/tools.go | 8 ++--- backend/agents/client/tools_test.go | 6 ++-- backend/agents/event_location_agent.go | 2 +- backend/agents/note_agent.go | 2 +- backend/agents/orchestrator.go | 2 +- 7 files changed, 87 insertions(+), 23 deletions(-) create mode 100644 backend/agents/client/client_test.go diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 3964aca..616d71a 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -49,14 +49,16 @@ type AgentMessage interface { MessageToJson() ([]byte, error) } -type AgentTextMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCallId string `json:"tool_call_id,omitempty"` - Name string `json:"name,omitempty"` +type AgentResponseMessage struct { + Role string `json:"role"` + Content string `json:"content"` + // Not sure I need this field. + ToolCallId string `json:"tool_call_id,omitempty"` + ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` + Name string `json:"name,omitempty"` } -func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { +func (textContent AgentResponseMessage) MessageToJson() ([]byte, error) { // TODO: Validate the `Role`. return json.Marshal(textContent) } @@ -91,7 +93,7 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { return json.Marshal(arrayContent) } -func (content *AgentMessages) AddText(message AgentTextMessage) { +func (content *AgentMessages) AddResponse(message ResponseChoiceMessage) { content.Messages = append(content.Messages, message) } @@ -126,7 +128,7 @@ func (content *AgentMessages) AddSystem(prompt string) error { return errors.New("You can only add a system prompt at the beginning") } - content.Messages = append(content.Messages, AgentTextMessage{ + content.Messages = append(content.Messages, AgentResponseMessage{ Role: ROLE_SYSTEM, Content: prompt, }) @@ -134,6 +136,26 @@ func (content *AgentMessages) AddSystem(prompt string) error { return nil } +// TODO: `AgentMessages` is not really a good name. +// It's a step above that, like a real chat. AgentChat or something. +func (chat *AgentMessages) HandleResponse(response AgentResponse) error { + if len(chat.Messages) == 0 { + return errors.New("This chat doesnt contain any messages therefore cannot be handled.") + } + + for _, choice := range response.Choices { + // TOOD + // if len(choice.Message.ToolCalls) > 0 { + // for _, toolCall := choice.Message.ToolCalls { + // chat.AddToolCall() + // } + // } + chat.AddResponse(choice.Message) + } + + return nil +} + type AgentContent interface { ToJson() ([]byte, error) } @@ -158,6 +180,10 @@ type ResponseChoiceMessage struct { ToolCalls []ToolCall `json:"tool_calls"` } +func (choice ResponseChoiceMessage) MessageToJson() ([]byte, error) { + return json.Marshal(choice) +} + type ResponseChoice struct { Index int `json:"index"` Message ResponseChoiceMessage `json:"message"` @@ -174,7 +200,6 @@ type AgentResponse struct { type AgentClient struct { url string apiKey string - systemPrompt string responseFormat string ToolHandler ToolsHandlers @@ -187,7 +212,7 @@ const ROLE_USER = "user" const ROLE_SYSTEM = "system" const IMAGE_TYPE = "image_url" -func CreateAgentClient(prompt string) (AgentClient, error) { +func CreateAgentClient() (AgentClient, error) { apiKey := os.Getenv(OPENAI_API_KEY) if len(apiKey) == 0 { @@ -195,9 +220,8 @@ func CreateAgentClient(prompt string) (AgentClient, error) { } return AgentClient{ - apiKey: apiKey, - url: "https://api.mistral.ai/v1/chat/completions", - systemPrompt: prompt, + apiKey: apiKey, + url: "https://api.mistral.ai/v1/chat/completions", Do: func(req *http.Request) (*http.Response, error) { client := &http.Client{} return client.Do(req) diff --git a/backend/agents/client/client_test.go b/backend/agents/client/client_test.go new file mode 100644 index 0000000..b6b21bf --- /dev/null +++ b/backend/agents/client/client_test.go @@ -0,0 +1,40 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSimpleResponse(t *testing.T) { + // assert := assert.New(t) + require := require.New(t) + + chat := AgentMessages{ + Messages: make([]AgentMessage, 0), + } + + chat.AddSystem("system message") + + err := chat.HandleResponse(AgentResponse{ + Id: "0", + Object: "chat.completion", + Created: 1, + Choices: []ResponseChoice{{ + Index: 0, + Message: ResponseChoiceMessage{ + Role: "assistant", + Content: "some basic content", + }, + FinishReason: "", + }}, + }) + + require.NoError(err) + require.Len(chat.Messages, 2) + + require.EqualValues(chat.Messages[1], ResponseChoiceMessage{ + Role: "assistant", + Content: "some basic content", + }) +} diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 5bfd6d1..4acaaef 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -24,18 +24,18 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was const NonExistantTool = "This tool does not exist" -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) { +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentResponseMessage, error) { if len(toolCallMessage.ToolCalls) == 0 { - return []AgentTextMessage{}, NoToolCallError + return []AgentResponseMessage{}, NoToolCallError } - responses := make([]AgentTextMessage, len(toolCallMessage.ToolCalls)) + responses := make([]AgentResponseMessage, len(toolCallMessage.ToolCalls)) for i, toolCall := range toolCallMessage.ToolCalls { fnName := toolCall.Function.Name arguments := toolCall.Function.Arguments - responseMessage := AgentTextMessage{ + responseMessage := AgentResponseMessage{ Role: "tool", Name: fnName, ToolCallId: toolCall.Id, diff --git a/backend/agents/client/tools_test.go b/backend/agents/client/tools_test.go index 41d4072..f4d1d73 100644 --- a/backend/agents/client/tools_test.go +++ b/backend/agents/client/tools_test.go @@ -52,7 +52,7 @@ func (suite *ToolTestSuite) TestSingleToolCall() { require.NoError(err, "Tool call shouldnt return an error") - assert.EqualValues(response, []AgentTextMessage{{ + assert.EqualValues(response, []AgentResponseMessage{{ Role: "tool", Content: "\"return\"", ToolCallId: "1", @@ -111,7 +111,7 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() { require.NoError(err, "Tool call shouldnt return an error") - assert.EqualValues(response, []AgentTextMessage{ + assert.EqualValues(response, []AgentResponseMessage{ { Role: "tool", Content: "\"first-call\"", @@ -169,7 +169,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() { require.NoError(err, "Tool call shouldnt return an error") - assert.EqualValues(response, []AgentTextMessage{ + assert.EqualValues(response, []AgentResponseMessage{ { Role: "tool", Content: "I will always error", diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index a1b60cd..dfef53a 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -184,7 +184,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID } func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { - agentClient, err := client.CreateAgentClient(eventLocationPrompt) + agentClient, err := client.CreateAgentClient() if err != nil { return EventLocationAgent{}, err } diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 3d62c94..29b4ce8 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -66,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s } func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) { - client, err := client.CreateAgentClient(noteAgentPrompt) + client, err := client.CreateAgentClient() if err != nil { return NoteAgent{}, err } diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index 4fe7d3a..392d989 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -123,7 +123,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, } func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { - agent, err := client.CreateAgentClient(orchestratorPrompt) + agent, err := client.CreateAgentClient() if err != nil { return OrchestratorAgent{}, err }