fixup(chat): better way to organize agent messages and tool calls
This commit is contained in:
@ -49,14 +49,16 @@ type AgentMessage interface {
|
||||
MessageToJson() ([]byte, error)
|
||||
}
|
||||
|
||||
type AgentTextMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
type AgentResponseMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
// Not sure I need this field.
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls *[]ToolCall `json:"tool_calls,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
|
||||
func (textContent AgentResponseMessage) MessageToJson() ([]byte, error) {
|
||||
// TODO: Validate the `Role`.
|
||||
return json.Marshal(textContent)
|
||||
}
|
||||
@ -91,7 +93,7 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
|
||||
return json.Marshal(arrayContent)
|
||||
}
|
||||
|
||||
func (content *AgentMessages) AddText(message AgentTextMessage) {
|
||||
func (content *AgentMessages) AddResponse(message ResponseChoiceMessage) {
|
||||
content.Messages = append(content.Messages, message)
|
||||
}
|
||||
|
||||
@ -126,7 +128,7 @@ func (content *AgentMessages) AddSystem(prompt string) error {
|
||||
return errors.New("You can only add a system prompt at the beginning")
|
||||
}
|
||||
|
||||
content.Messages = append(content.Messages, AgentTextMessage{
|
||||
content.Messages = append(content.Messages, AgentResponseMessage{
|
||||
Role: ROLE_SYSTEM,
|
||||
Content: prompt,
|
||||
})
|
||||
@ -134,6 +136,26 @@ func (content *AgentMessages) AddSystem(prompt string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: `AgentMessages` is not really a good name.
|
||||
// It's a step above that, like a real chat. AgentChat or something.
|
||||
func (chat *AgentMessages) HandleResponse(response AgentResponse) error {
|
||||
if len(chat.Messages) == 0 {
|
||||
return errors.New("This chat doesnt contain any messages therefore cannot be handled.")
|
||||
}
|
||||
|
||||
for _, choice := range response.Choices {
|
||||
// TOOD
|
||||
// if len(choice.Message.ToolCalls) > 0 {
|
||||
// for _, toolCall := choice.Message.ToolCalls {
|
||||
// chat.AddToolCall()
|
||||
// }
|
||||
// }
|
||||
chat.AddResponse(choice.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AgentContent interface {
|
||||
ToJson() ([]byte, error)
|
||||
}
|
||||
@ -158,6 +180,10 @@ type ResponseChoiceMessage struct {
|
||||
ToolCalls []ToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
func (choice ResponseChoiceMessage) MessageToJson() ([]byte, error) {
|
||||
return json.Marshal(choice)
|
||||
}
|
||||
|
||||
type ResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message ResponseChoiceMessage `json:"message"`
|
||||
@ -174,7 +200,6 @@ type AgentResponse struct {
|
||||
type AgentClient struct {
|
||||
url string
|
||||
apiKey string
|
||||
systemPrompt string
|
||||
responseFormat string
|
||||
|
||||
ToolHandler ToolsHandlers
|
||||
@ -187,7 +212,7 @@ const ROLE_USER = "user"
|
||||
const ROLE_SYSTEM = "system"
|
||||
const IMAGE_TYPE = "image_url"
|
||||
|
||||
func CreateAgentClient(prompt string) (AgentClient, error) {
|
||||
func CreateAgentClient() (AgentClient, error) {
|
||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||
|
||||
if len(apiKey) == 0 {
|
||||
@ -195,9 +220,8 @@ func CreateAgentClient(prompt string) (AgentClient, error) {
|
||||
}
|
||||
|
||||
return AgentClient{
|
||||
apiKey: apiKey,
|
||||
url: "https://api.mistral.ai/v1/chat/completions",
|
||||
systemPrompt: prompt,
|
||||
apiKey: apiKey,
|
||||
url: "https://api.mistral.ai/v1/chat/completions",
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
client := &http.Client{}
|
||||
return client.Do(req)
|
||||
|
40
backend/agents/client/client_test.go
Normal file
40
backend/agents/client/client_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSimpleResponse(t *testing.T) {
|
||||
// assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
chat := AgentMessages{
|
||||
Messages: make([]AgentMessage, 0),
|
||||
}
|
||||
|
||||
chat.AddSystem("system message")
|
||||
|
||||
err := chat.HandleResponse(AgentResponse{
|
||||
Id: "0",
|
||||
Object: "chat.completion",
|
||||
Created: 1,
|
||||
Choices: []ResponseChoice{{
|
||||
Index: 0,
|
||||
Message: ResponseChoiceMessage{
|
||||
Role: "assistant",
|
||||
Content: "some basic content",
|
||||
},
|
||||
FinishReason: "",
|
||||
}},
|
||||
})
|
||||
|
||||
require.NoError(err)
|
||||
require.Len(chat.Messages, 2)
|
||||
|
||||
require.EqualValues(chat.Messages[1], ResponseChoiceMessage{
|
||||
Role: "assistant",
|
||||
Content: "some basic content",
|
||||
})
|
||||
}
|
@ -24,18 +24,18 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was
|
||||
|
||||
const NonExistantTool = "This tool does not exist"
|
||||
|
||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) {
|
||||
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentResponseMessage, error) {
|
||||
if len(toolCallMessage.ToolCalls) == 0 {
|
||||
return []AgentTextMessage{}, NoToolCallError
|
||||
return []AgentResponseMessage{}, NoToolCallError
|
||||
}
|
||||
|
||||
responses := make([]AgentTextMessage, len(toolCallMessage.ToolCalls))
|
||||
responses := make([]AgentResponseMessage, len(toolCallMessage.ToolCalls))
|
||||
|
||||
for i, toolCall := range toolCallMessage.ToolCalls {
|
||||
fnName := toolCall.Function.Name
|
||||
arguments := toolCall.Function.Arguments
|
||||
|
||||
responseMessage := AgentTextMessage{
|
||||
responseMessage := AgentResponseMessage{
|
||||
Role: "tool",
|
||||
Name: fnName,
|
||||
ToolCallId: toolCall.Id,
|
||||
|
@ -52,7 +52,7 @@ func (suite *ToolTestSuite) TestSingleToolCall() {
|
||||
|
||||
require.NoError(err, "Tool call shouldnt return an error")
|
||||
|
||||
assert.EqualValues(response, []AgentTextMessage{{
|
||||
assert.EqualValues(response, []AgentResponseMessage{{
|
||||
Role: "tool",
|
||||
Content: "\"return\"",
|
||||
ToolCallId: "1",
|
||||
@ -111,7 +111,7 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() {
|
||||
|
||||
require.NoError(err, "Tool call shouldnt return an error")
|
||||
|
||||
assert.EqualValues(response, []AgentTextMessage{
|
||||
assert.EqualValues(response, []AgentResponseMessage{
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "\"first-call\"",
|
||||
@ -169,7 +169,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
|
||||
|
||||
require.NoError(err, "Tool call shouldnt return an error")
|
||||
|
||||
assert.EqualValues(response, []AgentTextMessage{
|
||||
assert.EqualValues(response, []AgentResponseMessage{
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "I will always error",
|
||||
|
@ -184,7 +184,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
|
||||
}
|
||||
|
||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
||||
agentClient, err := client.CreateAgentClient(eventLocationPrompt)
|
||||
agentClient, err := client.CreateAgentClient()
|
||||
if err != nil {
|
||||
return EventLocationAgent{}, err
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
||||
}
|
||||
|
||||
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
|
||||
client, err := client.CreateAgentClient(noteAgentPrompt)
|
||||
client, err := client.CreateAgentClient()
|
||||
if err != nil {
|
||||
return NoteAgent{}, err
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
|
||||
}
|
||||
|
||||
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
|
||||
agent, err := client.CreateAgentClient(orchestratorPrompt)
|
||||
agent, err := client.CreateAgentClient()
|
||||
if err != nil {
|
||||
return OrchestratorAgent{}, err
|
||||
}
|
||||
|
Reference in New Issue
Block a user