diff --git a/backend/agents/client/chat.go b/backend/agents/client/chat.go index 6628e56..c6dff19 100644 --- a/backend/agents/client/chat.go +++ b/backend/agents/client/chat.go @@ -65,8 +65,8 @@ func (m ChatUserMessage) MarshalJSON() ([]byte, error) { }) case ArrayMessage: return json.Marshal(&struct { - Role UserRole `json:"role"` - Content []ImageMessageContent `json:"content"` + Role UserRole `json:"role"` + Content []MessageContentMessage `json:"content"` }{ Role: User, Content: t.Content, @@ -121,18 +121,35 @@ func (m SingleMessage) IsSingleMessage() bool { } type ArrayMessage struct { - Content []ImageMessageContent `json:"content"` + Content []MessageContentMessage `json:"content"` } func (m ArrayMessage) IsSingleMessage() bool { return false } +type MessageContentMessage interface { + IsImageMessage() bool +} + +type TextMessageContent struct { + TextType string `json:"type"` + Text string `json:"text"` +} + +func (m TextMessageContent) IsImageMessage() bool { + return false +} + type ImageMessageContent struct { ImageType string `json:"type"` ImageUrl string `json:"image_url"` } +func (m ImageMessageContent) IsImageMessage() bool { + return true +} + type ImageContentUrl struct { Url string `json:"url"` } @@ -165,7 +182,7 @@ func (chat *Chat) AddSystem(prompt string) { }) } -func (chat *Chat) AddImage(imageName string, image []byte) error { +func (chat *Chat) AddImage(imageName string, image []byte, query *string) error { extension := filepath.Ext(imageName) if len(extension) == 0 { // TODO: could also validate for image types we support. @@ -173,14 +190,28 @@ func (chat *Chat) AddImage(imageName string, image []byte) error { } extension = extension[1:] - encodedString := base64.StdEncoding.EncodeToString(image) - messageContent := ArrayMessage{ - Content: make([]ImageMessageContent, 1), + contentLength := 1 + if query != nil { + contentLength = 2 } - messageContent.Content[0] = ImageMessageContent{ + messageContent := ArrayMessage{ + Content: make([]MessageContentMessage, contentLength), + } + + index := 0 + + if query != nil { + messageContent.Content[index] = TextMessageContent{ + TextType: "text", + Text: *query, + } + index += 1 + } + + messageContent.Content[index] = ImageMessageContent{ ImageType: "image_url", ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), } diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 244d539..850d358 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -220,7 +220,7 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e return err } -func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { +func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, query *string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { var tools any err := json.Unmarshal([]byte(jsonTools), &tools) @@ -241,7 +241,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo } request.Chat.AddSystem(systemPrompt) - request.Chat.AddImage(imageName, imageData) + request.Chat.AddImage(imageName, imageData, query) _, err = client.Request(&request) if err != nil { @@ -251,6 +251,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo toolHandlerInfo := ToolHandlerInfo{ ImageId: imageId, UserId: userId, + Image: &imageData, } return client.ToolLoop(toolHandlerInfo, &request) diff --git a/backend/agents/client/tools.go b/backend/agents/client/tools.go index 8dc9244..8061311 100644 --- a/backend/agents/client/tools.go +++ b/backend/agents/client/tools.go @@ -10,6 +10,9 @@ import ( type ToolHandlerInfo struct { UserId uuid.UUID ImageId uuid.UUID + + // Pointer because we don't want to copy this around too much. + Image *[]byte } type ToolHandler struct {