fix(tools): testing and processing

fix
This commit is contained in:
2025-04-09 13:56:03 +01:00
parent d36dec8d60
commit f169fd2ba2
9 changed files with 128 additions and 90 deletions

View File

@ -2,6 +2,7 @@ package client
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"path/filepath" "path/filepath"
@ -49,7 +50,30 @@ const (
type ChatUserMessage struct { type ChatUserMessage struct {
Role UserRole `json:"role"` Role UserRole `json:"role"`
MessageContent `json:"content"` MessageContent `json:"MessageContent"`
}
func (m ChatUserMessage) MarshalJSON() ([]byte, error) {
switch t := m.MessageContent.(type) {
case SingleMessage:
return json.Marshal(&struct {
Role UserRole `json:"role"`
Content string `json:"content"`
}{
Role: User,
Content: t.Content,
})
case ArrayMessage:
return json.Marshal(&struct {
Role UserRole `json:"role"`
Content []ImageMessageContent `json:"content"`
}{
Role: User,
Content: t.Content,
})
}
return []byte{}, errors.New("Unreachable")
} }
func (r ChatUserMessage) IsResponse() bool { func (r ChatUserMessage) IsResponse() bool {

View File

@ -0,0 +1,24 @@
package client
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestFlatMarshallSingleMessage(t *testing.T) {
require := require.New(t)
message := ChatUserMessage{
Role: User,
MessageContent: SingleMessage{
Content: "Hello",
},
}
json, err := json.Marshal(message)
require.NoError(err)
require.Equal(string(json), "{\"role\":\"user\",\"content\":\"Hello\"}")
}

View File

@ -4,8 +4,8 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"log"
"net/http" "net/http"
"os" "os"
) )
@ -23,7 +23,7 @@ type AgentRequestBody struct {
Tools *any `json:"tools,omitempty"` Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"` ToolChoice *string `json:"tool_choice,omitempty"`
Chat Chat `json:"messages"` Chat *Chat `json:"messages"`
} }
type ResponseChoice struct { type ResponseChoice struct {
@ -84,8 +84,8 @@ func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
return req, nil return req, nil
} }
func (client AgentClient) Request(chat *Chat) (AgentResponse, error) { func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
jsonAiRequest, err := json.Marshal(chat) jsonAiRequest, err := json.Marshal(req)
if err != nil { if err != nil {
return AgentResponse{}, err return AgentResponse{}, err
} }
@ -105,6 +105,8 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
return AgentResponse{}, err return AgentResponse{}, err
} }
fmt.Println(string(response))
agentResponse := AgentResponse{} agentResponse := AgentResponse{}
err = json.Unmarshal(response, &agentResponse) err = json.Unmarshal(response, &agentResponse)
@ -116,15 +118,15 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.") return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
} }
chat.AddAiResponse(agentResponse.Choices[0].Message) req.Chat.AddAiResponse(agentResponse.Choices[0].Message)
return agentResponse, nil return agentResponse, nil
} }
func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error { func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
var err error var err error
message, err := chat.GetLatest() message, err := req.Chat.GetLatest()
if err != nil { if err != nil {
return err return err
} }
@ -140,21 +142,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error {
} }
for _, toolCall := range *aiMessage.ToolCalls { for _, toolCall := range *aiMessage.ToolCalls {
toolResponse, err := client.ToolHandler.Handle(info, toolCall) toolResponse := client.ToolHandler.Handle(info, toolCall)
if err != nil {
break
}
chat.AddToolResponse(toolResponse) req.Chat.AddToolResponse(toolResponse)
_, err = client.Request(chat)
if err != nil {
break
}
} }
if err != nil {
log.Println(err)
}
return err return err
} }

View File

@ -24,7 +24,7 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was
const NonExistantTool = "This tool does not exist" const NonExistantTool = "This tool does not exist"
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) (ChatUserToolResponse, error) { func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) ChatUserToolResponse {
fnName := toolCallMessage.Function.Name fnName := toolCallMessage.Function.Name
arguments := toolCallMessage.Function.Arguments arguments := toolCallMessage.Function.Arguments
@ -36,7 +36,8 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa
fnHandler, exists := handler.handlers[fnName] fnHandler, exists := handler.handlers[fnName]
if !exists { if !exists {
return ChatUserToolResponse{}, errors.New(NonExistantTool) responseMessage.Content = NonExistantTool
return responseMessage
} }
res, err := fnHandler.Fn(info, arguments, toolCallMessage) res, err := fnHandler.Fn(info, arguments, toolCallMessage)
@ -47,7 +48,7 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa
responseMessage.Content = res responseMessage.Content = res
} }
return responseMessage, nil return responseMessage
} }
func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) { func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {

View File

@ -12,6 +12,7 @@ type ToolTestSuite struct {
suite.Suite suite.Suite
handler ToolsHandlers handler ToolsHandlers
client AgentClient
} }
func (suite *ToolTestSuite) SetupTest() { func (suite *ToolTestSuite) SetupTest() {
@ -26,70 +27,44 @@ func (suite *ToolTestSuite) SetupTest() {
suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) { suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
return false, errors.New("I will always error") return false, errors.New("I will always error")
}) })
suite.client.ToolHandler = suite.handler
} }
func (suite *ToolTestSuite) TestSingleToolCall() { func (suite *ToolTestSuite) TestSingleToolCall() {
assert := suite.Assert()
require := suite.Require() require := suite.Require()
response, err := suite.handler.Handle( response := suite.handler.Handle(
ToolHandlerInfo{ ToolHandlerInfo{
UserId: uuid.Nil, UserId: uuid.Nil,
ImageId: uuid.Nil, ImageId: uuid.Nil,
}, },
AgentAssistantToolCall{ ToolCall{
Role: "assistant", Index: 0,
Content: "", Id: "1",
ToolCalls: []ToolCall{{ Function: FunctionCall{
Index: 0, Name: "a",
Id: "1", Arguments: "return",
Function: FunctionCall{ },
Name: "a",
Arguments: "return",
},
}},
}) })
require.NoError(err, "Tool call shouldnt return an error") require.EqualValues(response, ChatUserToolResponse{
assert.EqualValues(response, []AgentResponseMessage{{
Role: "tool", Role: "tool",
Content: "\"return\"", Content: "\"return\"",
ToolCallId: "1", ToolCallId: "1",
Name: "a", Name: "a",
}}) })
}
func (suite *ToolTestSuite) TestEmptyCall() {
require := suite.Require()
_, err := suite.handler.Handle(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{},
})
require.ErrorIs(err, NoToolCallError)
} }
func (suite *ToolTestSuite) TestMultipleToolCalls() { func (suite *ToolTestSuite) TestMultipleToolCalls() {
assert := suite.Assert() assert := suite.Assert()
require := suite.Require() require := suite.Require()
response, err := suite.handler.Handle( chat := Chat{
ToolHandlerInfo{ Messages: []ChatMessage{ChatAiMessage{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
AgentAssistantToolCall{
Role: "assistant", Role: "assistant",
Content: "", Content: "",
ToolCalls: []ToolCall{ ToolCalls: &[]ToolCall{
{ {
Index: 0, Index: 0,
Id: "1", Id: "1",
@ -107,18 +82,27 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() {
}, },
}, },
}, },
}},
}
err := suite.client.Process(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
&AgentRequestBody{
Chat: &chat,
}) })
require.NoError(err, "Tool call shouldnt return an error") require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(chat.Messages[1:], []ChatMessage{
assert.EqualValues(response, []AgentResponseMessage{ ChatUserToolResponse{
{
Role: "tool", Role: "tool",
Content: "\"first-call\"", Content: "\"first-call\"",
ToolCallId: "1", ToolCallId: "1",
Name: "a", Name: "a",
}, },
{ ChatUserToolResponse{
Role: "tool", Role: "tool",
Content: "\"second-call\"", Content: "\"second-call\"",
ToolCallId: "2", ToolCallId: "2",
@ -131,15 +115,11 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
assert := suite.Assert() assert := suite.Assert()
require := suite.Require() require := suite.Require()
response, err := suite.handler.Handle( chat := Chat{
ToolHandlerInfo{ Messages: []ChatMessage{ChatAiMessage{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
AgentAssistantToolCall{
Role: "assistant", Role: "assistant",
Content: "", Content: "",
ToolCalls: []ToolCall{ ToolCalls: &[]ToolCall{
{ {
Index: 0, Index: 0,
Id: "1", Id: "1",
@ -165,24 +145,34 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
}, },
}, },
}, },
}},
}
err := suite.client.Process(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
&AgentRequestBody{
Chat: &chat,
}) })
require.NoError(err, "Tool call shouldnt return an error") require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentResponseMessage{ assert.EqualValues(chat.Messages[1:], []ChatMessage{
{ ChatUserToolResponse{
Role: "tool", Role: "tool",
Content: "I will always error", Content: "I will always error",
ToolCallId: "1", ToolCallId: "1",
Name: "error", Name: "error",
}, },
{ ChatUserToolResponse{
Role: "tool", Role: "tool",
Content: "This tool does not exist", Content: "This tool does not exist",
ToolCallId: "2", ToolCallId: "2",
Name: "non-existant", Name: "non-existant",
}, },
{ ChatUserToolResponse{
Role: "tool", Role: "tool",
Content: "\"no-error\"", Content: "\"no-error\"",
ToolCallId: "3", ToolCallId: "3",
@ -192,5 +182,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
} }
func TestToolSuite(t *testing.T) { func TestToolSuite(t *testing.T) {
suite.Run(t, &ToolTestSuite{}) suite.Run(t, &ToolTestSuite{
client: AgentClient{},
})
} }

View File

@ -161,7 +161,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
ResponseFormat: client.ResponseFormat{ ResponseFormat: client.ResponseFormat{
Type: "text", Type: "text",
}, },
Chat: client.Chat{ Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0), Messages: make([]client.ChatMessage, 0),
}, },
} }
@ -169,7 +169,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
request.Chat.AddSystem(eventLocationPrompt) request.Chat.AddSystem(eventLocationPrompt)
request.Chat.AddImage(imageName, imageData) request.Chat.AddImage(imageName, imageData)
_, err = agent.client.Request(&request.Chat) _, err = agent.client.Request(&request)
if err != nil { if err != nil {
return err return err
} }
@ -179,7 +179,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
UserId: userId, UserId: userId,
} }
return agent.client.Process(toolHandlerInfo, &request.Chat) return agent.client.Process(toolHandlerInfo, &request)
} }
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {

View File

@ -32,7 +32,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
ResponseFormat: client.ResponseFormat{ ResponseFormat: client.ResponseFormat{
Type: "text", Type: "text",
}, },
Chat: client.Chat{ Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0), Messages: make([]client.ChatMessage, 0),
}, },
} }
@ -40,7 +40,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
request.Chat.AddSystem(noteAgentPrompt) request.Chat.AddSystem(noteAgentPrompt)
request.Chat.AddImage(imageName, imageData) request.Chat.AddImage(imageName, imageData)
resp, err := agent.client.Request(&request.Chat) resp, err := agent.client.Request(&request)
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,6 +3,7 @@ package agents
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"github.com/google/uuid" "github.com/google/uuid"
@ -102,7 +103,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
ToolChoice: &toolChoice, ToolChoice: &toolChoice,
Tools: &tools, Tools: &tools,
Chat: client.Chat{ Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0), Messages: make([]client.ChatMessage, 0),
}, },
} }
@ -110,17 +111,19 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
request.Chat.AddSystem(orchestratorPrompt) request.Chat.AddSystem(orchestratorPrompt)
request.Chat.AddImage(imageName, imageData) request.Chat.AddImage(imageName, imageData)
_, err = agent.client.Request(&request.Chat) res, err := agent.client.Request(&request)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(res)
toolHandlerInfo := client.ToolHandlerInfo{ toolHandlerInfo := client.ToolHandlerInfo{
ImageId: imageId, ImageId: imageId,
UserId: userId, UserId: userId,
} }
return agent.client.Process(toolHandlerInfo, &request.Chat) return agent.client.Process(toolHandlerInfo, &request)
} }
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {

View File

@ -104,7 +104,10 @@ func main() {
panic(err) panic(err)
} }
orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil {
fmt.Println(err)
}
}() }()
} }
} }