fix(tools): testing and processing
fix
This commit is contained in:
@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
@ -49,7 +50,30 @@ const (
|
||||
type ChatUserMessage struct {
|
||||
Role UserRole `json:"role"`
|
||||
|
||||
MessageContent `json:"content"`
|
||||
MessageContent `json:"MessageContent"`
|
||||
}
|
||||
|
||||
func (m ChatUserMessage) MarshalJSON() ([]byte, error) {
|
||||
switch t := m.MessageContent.(type) {
|
||||
case SingleMessage:
|
||||
return json.Marshal(&struct {
|
||||
Role UserRole `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}{
|
||||
Role: User,
|
||||
Content: t.Content,
|
||||
})
|
||||
case ArrayMessage:
|
||||
return json.Marshal(&struct {
|
||||
Role UserRole `json:"role"`
|
||||
Content []ImageMessageContent `json:"content"`
|
||||
}{
|
||||
Role: User,
|
||||
Content: t.Content,
|
||||
})
|
||||
}
|
||||
|
||||
return []byte{}, errors.New("Unreachable")
|
||||
}
|
||||
|
||||
func (r ChatUserMessage) IsResponse() bool {
|
||||
|
24
backend/agents/client/chat_test.go
Normal file
24
backend/agents/client/chat_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFlatMarshallSingleMessage(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
message := ChatUserMessage{
|
||||
Role: User,
|
||||
MessageContent: SingleMessage{
|
||||
Content: "Hello",
|
||||
},
|
||||
}
|
||||
|
||||
json, err := json.Marshal(message)
|
||||
require.NoError(err)
|
||||
|
||||
require.Equal(string(json), "{\"role\":\"user\",\"content\":\"Hello\"}")
|
||||
}
|
@ -4,8 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
@ -23,7 +23,7 @@ type AgentRequestBody struct {
|
||||
Tools *any `json:"tools,omitempty"`
|
||||
ToolChoice *string `json:"tool_choice,omitempty"`
|
||||
|
||||
Chat Chat `json:"messages"`
|
||||
Chat *Chat `json:"messages"`
|
||||
}
|
||||
|
||||
type ResponseChoice struct {
|
||||
@ -84,8 +84,8 @@ func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
|
||||
jsonAiRequest, err := json.Marshal(chat)
|
||||
func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error) {
|
||||
jsonAiRequest, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return AgentResponse{}, err
|
||||
}
|
||||
@ -105,6 +105,8 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
|
||||
return AgentResponse{}, err
|
||||
}
|
||||
|
||||
fmt.Println(string(response))
|
||||
|
||||
agentResponse := AgentResponse{}
|
||||
err = json.Unmarshal(response, &agentResponse)
|
||||
|
||||
@ -116,15 +118,15 @@ func (client AgentClient) Request(chat *Chat) (AgentResponse, error) {
|
||||
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
|
||||
}
|
||||
|
||||
chat.AddAiResponse(agentResponse.Choices[0].Message)
|
||||
req.Chat.AddAiResponse(agentResponse.Choices[0].Message)
|
||||
|
||||
return agentResponse, nil
|
||||
}
|
||||
|
||||
func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error {
|
||||
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
|
||||
var err error
|
||||
|
||||
message, err := chat.GetLatest()
|
||||
message, err := req.Chat.GetLatest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -140,21 +142,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, chat *Chat) error {
|
||||
}
|
||||
|
||||
for _, toolCall := range *aiMessage.ToolCalls {
|
||||
toolResponse, err := client.ToolHandler.Handle(info, toolCall)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
toolResponse := client.ToolHandler.Handle(info, toolCall)
|
||||
|
||||
chat.AddToolResponse(toolResponse)
|
||||
|
||||
_, err = client.Request(chat)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
req.Chat.AddToolResponse(toolResponse)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was
|
||||
|
||||
const NonExistantTool = "This tool does not exist"
|
||||
|
||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) (ChatUserToolResponse, error) {
|
||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCall) ChatUserToolResponse {
|
||||
fnName := toolCallMessage.Function.Name
|
||||
arguments := toolCallMessage.Function.Arguments
|
||||
|
||||
@ -36,7 +36,8 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa
|
||||
|
||||
fnHandler, exists := handler.handlers[fnName]
|
||||
if !exists {
|
||||
return ChatUserToolResponse{}, errors.New(NonExistantTool)
|
||||
responseMessage.Content = NonExistantTool
|
||||
return responseMessage
|
||||
}
|
||||
|
||||
res, err := fnHandler.Fn(info, arguments, toolCallMessage)
|
||||
@ -47,7 +48,7 @@ func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage ToolCa
|
||||
responseMessage.Content = res
|
||||
}
|
||||
|
||||
return responseMessage, nil
|
||||
return responseMessage
|
||||
}
|
||||
|
||||
func (handler *ToolsHandlers) AddTool(name string, fn func(info ToolHandlerInfo, args string, call ToolCall) (any, error)) {
|
||||
|
@ -12,6 +12,7 @@ type ToolTestSuite struct {
|
||||
suite.Suite
|
||||
|
||||
handler ToolsHandlers
|
||||
client AgentClient
|
||||
}
|
||||
|
||||
func (suite *ToolTestSuite) SetupTest() {
|
||||
@ -26,70 +27,44 @@ func (suite *ToolTestSuite) SetupTest() {
|
||||
suite.handler.AddTool("error", func(info ToolHandlerInfo, args string, call ToolCall) (any, error) {
|
||||
return false, errors.New("I will always error")
|
||||
})
|
||||
|
||||
suite.client.ToolHandler = suite.handler
|
||||
}
|
||||
|
||||
func (suite *ToolTestSuite) TestSingleToolCall() {
|
||||
assert := suite.Assert()
|
||||
require := suite.Require()
|
||||
|
||||
response, err := suite.handler.Handle(
|
||||
response := suite.handler.Handle(
|
||||
ToolHandlerInfo{
|
||||
UserId: uuid.Nil,
|
||||
ImageId: uuid.Nil,
|
||||
},
|
||||
AgentAssistantToolCall{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{{
|
||||
Index: 0,
|
||||
Id: "1",
|
||||
Function: FunctionCall{
|
||||
Name: "a",
|
||||
Arguments: "return",
|
||||
},
|
||||
}},
|
||||
ToolCall{
|
||||
Index: 0,
|
||||
Id: "1",
|
||||
Function: FunctionCall{
|
||||
Name: "a",
|
||||
Arguments: "return",
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(err, "Tool call shouldnt return an error")
|
||||
|
||||
assert.EqualValues(response, []AgentResponseMessage{{
|
||||
require.EqualValues(response, ChatUserToolResponse{
|
||||
Role: "tool",
|
||||
Content: "\"return\"",
|
||||
ToolCallId: "1",
|
||||
Name: "a",
|
||||
}})
|
||||
}
|
||||
|
||||
func (suite *ToolTestSuite) TestEmptyCall() {
|
||||
require := suite.Require()
|
||||
|
||||
_, err := suite.handler.Handle(
|
||||
ToolHandlerInfo{
|
||||
UserId: uuid.Nil,
|
||||
ImageId: uuid.Nil,
|
||||
},
|
||||
AgentAssistantToolCall{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{},
|
||||
})
|
||||
|
||||
require.ErrorIs(err, NoToolCallError)
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *ToolTestSuite) TestMultipleToolCalls() {
|
||||
assert := suite.Assert()
|
||||
require := suite.Require()
|
||||
|
||||
response, err := suite.handler.Handle(
|
||||
ToolHandlerInfo{
|
||||
UserId: uuid.Nil,
|
||||
ImageId: uuid.Nil,
|
||||
},
|
||||
AgentAssistantToolCall{
|
||||
chat := Chat{
|
||||
Messages: []ChatMessage{ChatAiMessage{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{
|
||||
ToolCalls: &[]ToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
Id: "1",
|
||||
@ -107,18 +82,27 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() {
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
err := suite.client.Process(
|
||||
ToolHandlerInfo{
|
||||
UserId: uuid.Nil,
|
||||
ImageId: uuid.Nil,
|
||||
},
|
||||
&AgentRequestBody{
|
||||
Chat: &chat,
|
||||
})
|
||||
|
||||
require.NoError(err, "Tool call shouldnt return an error")
|
||||
|
||||
assert.EqualValues(response, []AgentResponseMessage{
|
||||
{
|
||||
assert.EqualValues(chat.Messages[1:], []ChatMessage{
|
||||
ChatUserToolResponse{
|
||||
Role: "tool",
|
||||
Content: "\"first-call\"",
|
||||
ToolCallId: "1",
|
||||
Name: "a",
|
||||
},
|
||||
{
|
||||
ChatUserToolResponse{
|
||||
Role: "tool",
|
||||
Content: "\"second-call\"",
|
||||
ToolCallId: "2",
|
||||
@ -131,15 +115,11 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
|
||||
assert := suite.Assert()
|
||||
require := suite.Require()
|
||||
|
||||
response, err := suite.handler.Handle(
|
||||
ToolHandlerInfo{
|
||||
UserId: uuid.Nil,
|
||||
ImageId: uuid.Nil,
|
||||
},
|
||||
AgentAssistantToolCall{
|
||||
chat := Chat{
|
||||
Messages: []ChatMessage{ChatAiMessage{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{
|
||||
ToolCalls: &[]ToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
Id: "1",
|
||||
@ -165,24 +145,34 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
err := suite.client.Process(
|
||||
ToolHandlerInfo{
|
||||
UserId: uuid.Nil,
|
||||
ImageId: uuid.Nil,
|
||||
},
|
||||
&AgentRequestBody{
|
||||
Chat: &chat,
|
||||
})
|
||||
|
||||
require.NoError(err, "Tool call shouldnt return an error")
|
||||
|
||||
assert.EqualValues(response, []AgentResponseMessage{
|
||||
{
|
||||
assert.EqualValues(chat.Messages[1:], []ChatMessage{
|
||||
ChatUserToolResponse{
|
||||
Role: "tool",
|
||||
Content: "I will always error",
|
||||
ToolCallId: "1",
|
||||
Name: "error",
|
||||
},
|
||||
{
|
||||
ChatUserToolResponse{
|
||||
Role: "tool",
|
||||
Content: "This tool does not exist",
|
||||
ToolCallId: "2",
|
||||
Name: "non-existant",
|
||||
},
|
||||
{
|
||||
ChatUserToolResponse{
|
||||
Role: "tool",
|
||||
Content: "\"no-error\"",
|
||||
ToolCallId: "3",
|
||||
@ -192,5 +182,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
|
||||
}
|
||||
|
||||
func TestToolSuite(t *testing.T) {
|
||||
suite.Run(t, &ToolTestSuite{})
|
||||
suite.Run(t, &ToolTestSuite{
|
||||
client: AgentClient{},
|
||||
})
|
||||
}
|
||||
|
@ -161,7 +161,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
ResponseFormat: client.ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
Chat: client.Chat{
|
||||
Chat: &client.Chat{
|
||||
Messages: make([]client.ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
@ -169,7 +169,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
request.Chat.AddSystem(eventLocationPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
|
||||
_, err = agent.client.Request(&request.Chat)
|
||||
_, err = agent.client.Request(&request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -179,7 +179,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
return agent.client.Process(toolHandlerInfo, &request.Chat)
|
||||
return agent.client.Process(toolHandlerInfo, &request)
|
||||
}
|
||||
|
||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
||||
|
@ -32,7 +32,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
||||
ResponseFormat: client.ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
Chat: client.Chat{
|
||||
Chat: &client.Chat{
|
||||
Messages: make([]client.ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
@ -40,7 +40,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
||||
request.Chat.AddSystem(noteAgentPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
|
||||
resp, err := agent.client.Request(&request.Chat)
|
||||
resp, err := agent.client.Request(&request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package agents
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"screenmark/screenmark/agents/client"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@ -102,7 +103,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
|
||||
ToolChoice: &toolChoice,
|
||||
Tools: &tools,
|
||||
|
||||
Chat: client.Chat{
|
||||
Chat: &client.Chat{
|
||||
Messages: make([]client.ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
@ -110,17 +111,19 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
|
||||
request.Chat.AddSystem(orchestratorPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
|
||||
_, err = agent.client.Request(&request.Chat)
|
||||
res, err := agent.client.Request(&request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(res)
|
||||
|
||||
toolHandlerInfo := client.ToolHandlerInfo{
|
||||
ImageId: imageId,
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
return agent.client.Process(toolHandlerInfo, &request.Chat)
|
||||
return agent.client.Process(toolHandlerInfo, &request)
|
||||
}
|
||||
|
||||
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
|
||||
|
@ -104,7 +104,10 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user