From 5502fc6b19c3d8f6fa40398bd798efe7966a1863 Mon Sep 17 00:00:00 2001 From: John Costa Date: Wed, 9 Apr 2025 12:04:44 +0100 Subject: [PATCH] feat(chat): more simplified chat messages and tool handling --- backend/agents/client/chat.go | 184 +++++++++++++++++++++++ backend/agents/client/client.go | 212 ++++----------------------- backend/agents/client/client_test.go | 40 ----- backend/agents/client/tools.go | 52 +++---- 4 files changed, 234 insertions(+), 254 deletions(-) create mode 100644 backend/agents/client/chat.go delete mode 100644 backend/agents/client/client_test.go diff --git a/backend/agents/client/chat.go b/backend/agents/client/chat.go new file mode 100644 index 0000000..6ab6fcf --- /dev/null +++ b/backend/agents/client/chat.go @@ -0,0 +1,184 @@ +package client + +import ( + "encoding/base64" + "errors" + "fmt" + "path/filepath" +) + +type Chat struct { + Messages []ChatMessage `json:"messages"` +} + +type ChatMessage interface { + IsResponse() bool +} + +// TODO: the role could be inferred from the type. +// This would solve some bugs. + +/* + +Is there a world where this actually becomes the product? +Where we build such a resilient system of AI calls that we +can build some app builder, or even just an API system, +with a fancy UI? + +Manage all the complexity for the user? + +*/ + +// ============================================= +// Messages from us to the AI. +// ============================================= + +type UserRole = string + +const ( + User UserRole = "user" + System UserRole = "system" +) + +type ToolRole = string + +const ( + Tool ToolRole = "tool" +) + +type ChatUserMessage struct { + Role UserRole `json:"role"` + + MessageContent `json:"content"` +} + +func (r ChatUserMessage) IsResponse() bool { + return false +} + +type ChatUserToolResponse struct { + Role ToolRole `json:"role"` + + // The name of the function we are responding to. + Name string `json:"name"` + Content string `json:"content"` + ToolCallId string `json:"tool_call_id"` +} + +func (r ChatUserToolResponse) IsResponse() bool { + return false +} + +type ChatAiMessage struct { + Role string `json:"role"` + ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` + + MessageContent `json:"content"` +} + +func (m ChatAiMessage) IsResponse() bool { + return true +} + +// ============================================= +// Unique interface for message content. +// ============================================= + +type MessageContent interface { + IsSingleMessage() bool +} + +type SingleMessage struct { + Content string `json:"content"` +} + +func (m SingleMessage) IsSingleMessage() bool { + return true +} + +type ArrayMessage struct { + Content []ImageMessageContent `json:"content"` +} + +func (m ArrayMessage) IsSingleMessage() bool { + return false +} + +type ImageMessageContent struct { + ImageType string `json:"type"` + ImageUrl string `json:"image_url"` +} + +type ImageContentUrl struct { + Url string `json:"url"` +} + +// ============================================= +// Adjacent interfaces. +// ============================================= + +type ToolCall struct { + Index int `json:"index"` + Id string `json:"id"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ============================================= +// Chat methods +// ============================================= + +func (chat *Chat) AddSystem(prompt string) { + chat.Messages = append(chat.Messages, ChatUserMessage{ + Role: System, + MessageContent: SingleMessage{ + Content: prompt, + }, + }) +} + +func (chat *Chat) AddImage(imageName string, image []byte) error { + extension := filepath.Ext(imageName) + if len(extension) == 0 { + // TODO: could also validate for image types we support. + return errors.New("Image does not have extension") + } + + extension = extension[1:] + + encodedString := base64.StdEncoding.EncodeToString(image) + + messageContent := ArrayMessage{ + Content: make([]ImageMessageContent, 1), + } + + messageContent.Content[0] = ImageMessageContent{ + ImageType: "image_url", + ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), + } + + arrayMessage := ChatUserMessage{Role: User, MessageContent: messageContent} + chat.Messages = append(chat.Messages, arrayMessage) + + return nil +} + +func (chat *Chat) AddAiResponse(res ChatAiMessage) { + chat.Messages = append(chat.Messages, res) +} + +func (chat *Chat) AddToolResponse(res ChatUserToolResponse) { + chat.Messages = append(chat.Messages, res) +} + +func (chat Chat) GetLatest() (ChatMessage, error) { + if len(chat.Messages) == 0 { + return nil, errors.New("Not enough messages") + } + + return chat.Messages[len(chat.Messages)-1], nil +} diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 616d71a..be9590e 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -2,7 +2,6 @@ package client import ( "bytes" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -10,19 +9,8 @@ import ( "log" "net/http" "os" - "path/filepath" - "screenmark/screenmark/.gen/haystack/haystack/model" ) -type ImageInfo struct { - Tags []string `json:"tags"` - Text []string `json:"text"` - Links []string `json:"links"` - - Locations []model.Locations `json:"locations"` - Events []model.Events `json:"events"` -} - type ResponseFormat struct { Type string `json:"type"` JsonSchema any `json:"json_schema"` @@ -36,158 +24,13 @@ type AgentRequestBody struct { Tools *any `json:"tools,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"` - // ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` - - AgentMessages -} - -type AgentMessages struct { - Messages []AgentMessage `json:"messages"` -} - -type AgentMessage interface { - MessageToJson() ([]byte, error) -} - -type AgentResponseMessage struct { - Role string `json:"role"` - Content string `json:"content"` - // Not sure I need this field. - ToolCallId string `json:"tool_call_id,omitempty"` - ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` - Name string `json:"name,omitempty"` -} - -func (textContent AgentResponseMessage) MessageToJson() ([]byte, error) { - // TODO: Validate the `Role`. - return json.Marshal(textContent) -} - -type ToolCall struct { - Index int `json:"index"` - Id string `json:"id"` - Function FunctionCall `json:"function"` -} - -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type AgentAssistantToolCall struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -func (toolCall AgentAssistantToolCall) MessageToJson() ([]byte, error) { - return json.Marshal(toolCall) -} - -type AgentArrayMessage struct { - Role string `json:"role"` - Content []AgentContent `json:"content"` -} - -func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) { - return json.Marshal(arrayContent) -} - -func (content *AgentMessages) AddResponse(message ResponseChoiceMessage) { - content.Messages = append(content.Messages, message) -} - -func (content *AgentMessages) AddToolCall(toolCall AgentAssistantToolCall) { - content.Messages = append(content.Messages, toolCall) -} - -func (content *AgentMessages) AddImage(imageName string, image []byte) error { - extension := filepath.Ext(imageName) - if len(extension) == 0 { - // TODO: could also validate for image types we support. - return errors.New("Image does not have extension") - } - - extension = extension[1:] - - encodedString := base64.StdEncoding.EncodeToString(image) - - arrayMessage := AgentArrayMessage{Role: ROLE_USER, Content: make([]AgentContent, 1)} - arrayMessage.Content[0] = AgentImage{ - ImageType: IMAGE_TYPE, - ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), - } - - content.Messages = append(content.Messages, arrayMessage) - - return nil -} - -func (content *AgentMessages) AddSystem(prompt string) error { - if len(content.Messages) != 0 { - return errors.New("You can only add a system prompt at the beginning") - } - - content.Messages = append(content.Messages, AgentResponseMessage{ - Role: ROLE_SYSTEM, - Content: prompt, - }) - - return nil -} - -// TODO: `AgentMessages` is not really a good name. -// It's a step above that, like a real chat. AgentChat or something. -func (chat *AgentMessages) HandleResponse(response AgentResponse) error { - if len(chat.Messages) == 0 { - return errors.New("This chat doesnt contain any messages therefore cannot be handled.") - } - - for _, choice := range response.Choices { - // TOOD - // if len(choice.Message.ToolCalls) > 0 { - // for _, toolCall := choice.Message.ToolCalls { - // chat.AddToolCall() - // } - // } - chat.AddResponse(choice.Message) - } - - return nil -} - -type AgentContent interface { - ToJson() ([]byte, error) -} - -type ImageUrl struct { - Url string `json:"url"` -} - -type AgentImage struct { - ImageType string `json:"type"` - ImageUrl string `json:"image_url"` -} - -func (imageMessage AgentImage) ToJson() ([]byte, error) { - imageMessage.ImageType = IMAGE_TYPE - return json.Marshal(imageMessage) -} - -type ResponseChoiceMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` -} - -func (choice ResponseChoiceMessage) MessageToJson() ([]byte, error) { - return json.Marshal(choice) + Chat Chat `json:"messages"` } type ResponseChoice struct { - Index int `json:"index"` - Message ResponseChoiceMessage `json:"message"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message ChatAiMessage `json:"message"` + FinishReason string `json:"finish_reason"` } type AgentResponse struct { @@ -208,9 +51,6 @@ type AgentClient struct { } const OPENAI_API_KEY = "OPENAI_API_KEY" -const ROLE_USER = "user" -const ROLE_SYSTEM = "system" -const IMAGE_TYPE = "image_url" func CreateAgentClient() (AgentClient, error) { apiKey := os.Getenv(OPENAI_API_KEY) @@ -245,8 +85,8 @@ func (client AgentClient) getRequest(body []byte) (*http.Request, error) { return req, nil } -func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, error) { - jsonAiRequest, err := json.Marshal(request) +func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { + jsonAiRequest, err := json.Marshal(chat) if err != nil { return AgentResponse{}, err } @@ -273,34 +113,42 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err return AgentResponse{}, err } - toolCalls := agentResponse.Choices[0].Message.ToolCalls - if len(toolCalls) > 0 { - // Should for sure be more flexible. - request.AddToolCall(AgentAssistantToolCall{ - Role: "assistant", - Content: "", - ToolCalls: toolCalls, - }) + if len(agentResponse.Choices) != 1 { + return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.") } + chat.AddAiResponse(agentResponse.Choices[0].Message) + return agentResponse, nil } -func (client AgentClient) Process(info ToolHandlerInfo, request AgentRequestBody) error { +func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error { var err error - for { - toolCall, ok := request.Messages[len(request.Messages)-1].(AgentAssistantToolCall) - if !ok { - return errors.New("Latest message isnt a tool call. TODO") - } + message, err := chat.GetLatest() + if err != nil { + return err + } - _, err = client.ToolHandler.Handle(info, toolCall) + aiMessage, ok := message.(ChatAiMessage) + if !ok { + return errors.New("Latest message isnt an AI message") + } + + if aiMessage.ToolCalls == nil { + // Not an error, we just dont have any tool calls to process. + return nil + } + + for _, toolCall := range *aiMessage.ToolCalls { + toolResponse, err := client.ToolHandler.Handle(info, toolCall) if err != nil { break } - _, err = client.Request(&request) + chat.AddToolResponse(toolResponse) + + _, err = client.Request(chat) if err != nil { break } diff --git a/backend/agents/client/client_test.go b/backend/agents/client/client_test.go deleted file mode 100644 index b6b21bf..0000000 --- a/backend/agents/client/client_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package client - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSimpleResponse(t *testing.T) { - // assert := assert.New(t) - require := require.New(t) - - chat := AgentMessages{ - Messages: make([]AgentMessage, 0), - } - - chat.AddSystem("system message") - - err := chat.HandleResponse(AgentResponse{ - Id: "0", - Object: "chat.completion", - Created: 1, - Choices: []ResponseChoice{{ - Index: 0, - Message: ResponseChoiceMessage{ - Role: "assistant", - Content: "some basic content", - }, - FinishReason: "", - }}, - }) - - require.NoError(err) - require.Len(chat.Messages, 2) - - require.EqualValues(chat.Messages[1], ResponseChoiceMessage{ - Role: "assistant", - Content: "some basic content", - }) -} diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 4acaaef..4a54958 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -24,42 +24,30 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was const NonExistantTool = "This tool does not exist" -func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentResponseMessage, error) { - if len(toolCallMessage.ToolCalls) == 0 { - return []AgentResponseMessage{}, NoToolCallError +func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) (ChatUserToolResponse, error) { + fnName := toolCallMessage.Function.Name + arguments := toolCallMessage.Function.Arguments + + responseMessage := ChatUserToolResponse{ + Role: "tool", + Name: fnName, + ToolCallId: toolCallMessage.Id, } - responses := make([]AgentResponseMessage, len(toolCallMessage.ToolCalls)) - - for i, toolCall := range toolCallMessage.ToolCalls { - fnName := toolCall.Function.Name - arguments := toolCall.Function.Arguments - - responseMessage := AgentResponseMessage{ - Role: "tool", - Name: fnName, - ToolCallId: toolCall.Id, - } - - fnHandler, exists := handler.handlers[fnName] - if !exists { - responseMessage.Content = NonExistantTool - responses[i] = responseMessage - continue - } - - res, err := fnHandler.Fn(info, arguments, toolCallMessage.ToolCalls[0]) - - if err != nil { - responseMessage.Content = err.Error() - } else { - responseMessage.Content = res - } - - responses[i] = responseMessage + fnHandler, exists := handler.handlers[fnName] + if !exists { + return ChatUserToolResponse{}, errors.New(NonExistantTool) } - return responses, nil + res, err := fnHandler.Fn(info, arguments, toolCallMessage) + + if err != nil { + responseMessage.Content = err.Error() + } else { + responseMessage.Content = res + } + + return responseMessage, nil } func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {