fix(tools): testing and processing

fix
This commit is contained in:
2025-04-09 13:56:03 +01:00
parent d36dec8d60
commit f169fd2ba2
9 changed files with 128 additions and 90 deletions

View File

@ -2,6 +2,7 @@ package client
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"path/filepath"
@ -49,7 +50,30 @@ const (
type ChatUserMessage struct {
Role UserRole `json:"role"`
MessageContent `json:"content"`
MessageContent `json:"MessageContent"`
}
func (m ChatUserMessage) MarshalJSON() ([]byte, error) {
switch t := m.MessageContent.(type) {
case SingleMessage:
return json.Marshal(&struct {
Role UserRole `json:"role"`
Content string `json:"content"`
}{
Role: User,
Content: t.Content,
})
case ArrayMessage:
return json.Marshal(&struct {
Role UserRole `json:"role"`
Content []ImageMessageContent `json:"content"`
}{
Role: User,
Content: t.Content,
})
}
return []byte{}, errors.New("Unreachable")
}
func (r ChatUserMessage) IsResponse() bool {

View File

@ -0,0 +1,24 @@
package client
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestFlatMarshallSingleMessage(t *testing.T) {
require := require.New(t)
message := ChatUserMessage{
Role: User,
MessageContent: SingleMessage{
Content: "Hello",
},
}
json, err := json.Marshal(message)
require.NoError(err)
require.Equal(string(json), "{\"role\":\"user\",\"content\":\"Hello\"}")
}

View File

@ -4,8 +4,8 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
)
@ -23,7 +23,7 @@ type AgentRequestBody struct {
Tools *any `json:"tools,omitempty"`
ToolChoice *string `json:"tool_choice,omitempty"`
Chat Chat `json:"messages"`
Chat *Chat `json:"messages"`
}
type ResponseChoice struct {
@ -84,8 +84,8 @@ func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
return req, nil
}
func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
jsonAiRequest, err := json.Marshal(chat)
func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
jsonAiRequest, err := json.Marshal(req)
if err != nil {
return AgentResponse{}, err
}
@ -105,6 +105,8 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
return AgentResponse{}, err
}
fmt.Println(string(response))
agentResponse := AgentResponse{}
err = json.Unmarshal(response, &agentResponse)
@ -116,15 +118,15 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
}
chat.AddAiResponse(agentResponse.Choices[0].Message)
req.Chat.AddAiResponse(agentResponse.Choices[0].Message)
return agentResponse, nil
}
func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error {
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
var err error
message, err := chat.GetLatest()
message, err := req.Chat.GetLatest()
if err != nil {
return err
}
@ -140,21 +142,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error {
}
for _, toolCall := range *aiMessage.ToolCalls {
toolResponse, err := client.ToolHandler.Handle(info, toolCall)
if err != nil {
break
}
toolResponse := client.ToolHandler.Handle(info, toolCall)
chat.AddToolResponse(toolResponse)
_, err = client.Request(chat)
if err != nil {
break
}
req.Chat.AddToolResponse(toolResponse)
}
if err != nil {
log.Println(err)
}
return err
}

View File

@ -24,7 +24,7 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was
const NonExistantTool = "This tool does not exist"
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) (ChatUserToolResponse, error) {
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) ChatUserToolResponse {
fnName := toolCallMessage.Function.Name
arguments := toolCallMessage.Function.Arguments
@ -36,7 +36,8 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa
fnHandler, exists := handler.handlers[fnName]
if !exists {
return ChatUserToolResponse{}, errors.New(NonExistantTool)
responseMessage.Content = NonExistantTool
return responseMessage
}
res, err := fnHandler.Fn(info, arguments, toolCallMessage)
@ -47,7 +48,7 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa
responseMessage.Content = res
}
return responseMessage, nil
return responseMessage
}
func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {

View File

@ -12,6 +12,7 @@ type ToolTestSuite struct {
suite.Suite
handler ToolsHandlers
client AgentClient
}
func (suite *ToolTestSuite) SetupTest() {
@ -26,70 +27,44 @@ func (suite *ToolTestSuite) SetupTest() {
suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
return false, errors.New("I will always error")
})
suite.client.ToolHandler = suite.handler
}
func (suite *ToolTestSuite) TestSingleToolCall() {
assert := suite.Assert()
require := suite.Require()
response, err := suite.handler.Handle(
response := suite.handler.Handle(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{{
Index: 0,
Id: "1",
Function: FunctionCall{
Name: "a",
Arguments: "return",
},
}},
ToolCall{
Index: 0,
Id: "1",
Function: FunctionCall{
Name: "a",
Arguments: "return",
},
})
require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentResponseMessage{{
require.EqualValues(response, ChatUserToolResponse{
Role: "tool",
Content: "\"return\"",
ToolCallId: "1",
Name: "a",
}})
}
func (suite *ToolTestSuite) TestEmptyCall() {
require := suite.Require()
_, err := suite.handler.Handle(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
AgentAssistantToolCall{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{},
})
require.ErrorIs(err, NoToolCallError)
})
}
func (suite *ToolTestSuite) TestMultipleToolCalls() {
assert := suite.Assert()
require := suite.Require()
response, err := suite.handler.Handle(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
AgentAssistantToolCall{
chat := Chat{
Messages: []ChatMessage{ChatAiMessage{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
ToolCalls: &[]ToolCall{
{
Index: 0,
Id: "1",
@ -107,18 +82,27 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() {
},
},
},
}},
}
err := suite.client.Process(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
&AgentRequestBody{
Chat: &chat,
})
require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentResponseMessage{
{
assert.EqualValues(chat.Messages[1:], []ChatMessage{
ChatUserToolResponse{
Role: "tool",
Content: "\"first-call\"",
ToolCallId: "1",
Name: "a",
},
{
ChatUserToolResponse{
Role: "tool",
Content: "\"second-call\"",
ToolCallId: "2",
@ -131,15 +115,11 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
assert := suite.Assert()
require := suite.Require()
response, err := suite.handler.Handle(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
AgentAssistantToolCall{
chat := Chat{
Messages: []ChatMessage{ChatAiMessage{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
ToolCalls: &[]ToolCall{
{
Index: 0,
Id: "1",
@ -165,24 +145,34 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
},
},
},
}},
}
err := suite.client.Process(
ToolHandlerInfo{
UserId: uuid.Nil,
ImageId: uuid.Nil,
},
&AgentRequestBody{
Chat: &chat,
})
require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentResponseMessage{
{
assert.EqualValues(chat.Messages[1:], []ChatMessage{
ChatUserToolResponse{
Role: "tool",
Content: "I will always error",
ToolCallId: "1",
Name: "error",
},
{
ChatUserToolResponse{
Role: "tool",
Content: "This tool does not exist",
ToolCallId: "2",
Name: "non-existant",
},
{
ChatUserToolResponse{
Role: "tool",
Content: "\"no-error\"",
ToolCallId: "3",
@ -192,5 +182,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
}
func TestToolSuite(t *testing.T) {
suite.Run(t, &ToolTestSuite{})
suite.Run(t, &ToolTestSuite{
client: AgentClient{},
})
}

View File

@ -161,7 +161,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
ResponseFormat: client.ResponseFormat{
Type: "text",
},
Chat: client.Chat{
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
@ -169,7 +169,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
request.Chat.AddSystem(eventLocationPrompt)
request.Chat.AddImage(imageName, imageData)
_, err = agent.client.Request(&request.Chat)
_, err = agent.client.Request(&request)
if err != nil {
return err
}
@ -179,7 +179,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
UserId: userId,
}
return agent.client.Process(toolHandlerInfo, &request.Chat)
return agent.client.Process(toolHandlerInfo, &request)
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {

View File

@ -32,7 +32,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
ResponseFormat: client.ResponseFormat{
Type: "text",
},
Chat: client.Chat{
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
@ -40,7 +40,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
request.Chat.AddSystem(noteAgentPrompt)
request.Chat.AddImage(imageName, imageData)
resp, err := agent.client.Request(&request.Chat)
resp, err := agent.client.Request(&request)
if err != nil {
return err
}

View File

@ -3,6 +3,7 @@ package agents
import (
"encoding/json"
"errors"
"fmt"
"screenmark/screenmark/agents/client"
"github.com/google/uuid"
@ -102,7 +103,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
ToolChoice: &toolChoice,
Tools: &tools,
Chat: client.Chat{
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
@ -110,17 +111,19 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
request.Chat.AddSystem(orchestratorPrompt)
request.Chat.AddImage(imageName, imageData)
_, err = agent.client.Request(&request.Chat)
res, err := agent.client.Request(&request)
if err != nil {
return err
}
fmt.Println(res)
toolHandlerInfo := client.ToolHandlerInfo{
ImageId: imageId,
UserId: userId,
}
return agent.client.Process(toolHandlerInfo, &request.Chat)
return agent.client.Process(toolHandlerInfo, &request)
}
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {

View File

@ -104,7 +104,10 @@ func main() {
panic(err)
}
orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil {
fmt.Println(err)
}
}()
}
}