fixup(chat): better way to organize agent messages and tool calls

This commit is contained in:
2025-04-06 20:24:40 +01:00
parent 5c5df168ad
commit 26c6edb6ba
7 changed files with 87 additions and 23 deletions

View File

@ -49,14 +49,16 @@ type AgentMessage interface {
MessageToJson() ([]byte, error) MessageToJson() ([]byte, error)
} }
type AgentTextMessage struct { type AgentResponseMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
// Not sure I need this field.
ToolCallId string `json:"tool_call_id,omitempty"` ToolCallId string `json:"tool_call_id,omitempty"`
ToolCalls *[]ToolCall `json:"tool_calls,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
} }
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) { func (textContent AgentResponseMessage) MessageToJson() ([]byte, error) {
// TODO: Validate the `Role`. // TODO: Validate the `Role`.
return json.Marshal(textContent) return json.Marshal(textContent)
} }
@ -91,7 +93,7 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
return json.Marshal(arrayContent) return json.Marshal(arrayContent)
} }
func (content *AgentMessages) AddText(message AgentTextMessage) { func (content *AgentMessages) AddResponse(message ResponseChoiceMessage) {
content.Messages = append(content.Messages, message) content.Messages = append(content.Messages, message)
} }
@ -126,7 +128,7 @@ func (content *AgentMessages) AddSystem(prompt string) error {
return errors.New("You can only add a system prompt at the beginning") return errors.New("You can only add a system prompt at the beginning")
} }
content.Messages = append(content.Messages, AgentTextMessage{ content.Messages = append(content.Messages, AgentResponseMessage{
Role: ROLE_SYSTEM, Role: ROLE_SYSTEM,
Content: prompt, Content: prompt,
}) })
@ -134,6 +136,26 @@ func (content *AgentMessages) AddSystem(prompt string) error {
return nil return nil
} }
// TODO: `AgentMessages` is not really a good name.
// It's a step above that, like a real chat. AgentChat or something.
func (chat *AgentMessages) HandleResponse(response AgentResponse) error {
if len(chat.Messages) == 0 {
return errors.New("This chat doesnt contain any messages therefore cannot be handled.")
}
for _, choice := range response.Choices {
// TOOD
// if len(choice.Message.ToolCalls) > 0 {
// for _, toolCall := choice.Message.ToolCalls {
// chat.AddToolCall()
// }
// }
chat.AddResponse(choice.Message)
}
return nil
}
type AgentContent interface { type AgentContent interface {
ToJson() ([]byte, error) ToJson() ([]byte, error)
} }
@ -158,6 +180,10 @@ type ResponseChoiceMessage struct {
ToolCalls []ToolCall `json:"tool_calls"` ToolCalls []ToolCall `json:"tool_calls"`
} }
func (choice ResponseChoiceMessage) MessageToJson() ([]byte, error) {
return json.Marshal(choice)
}
type ResponseChoice struct { type ResponseChoice struct {
Index int `json:"index"` Index int `json:"index"`
Message ResponseChoiceMessage `json:"message"` Message ResponseChoiceMessage `json:"message"`
@ -174,7 +200,6 @@ type AgentResponse struct {
type AgentClient struct { type AgentClient struct {
url string url string
apiKey string apiKey string
systemPrompt string
responseFormat string responseFormat string
ToolHandler ToolsHandlers ToolHandler ToolsHandlers
@ -187,7 +212,7 @@ const ROLE_USER = "user"
const ROLE_SYSTEM = "system" const ROLE_SYSTEM = "system"
const IMAGE_TYPE = "image_url" const IMAGE_TYPE = "image_url"
func CreateAgentClient(prompt string) (AgentClient, error) { func CreateAgentClient() (AgentClient, error) {
apiKey := os.Getenv(OPENAI_API_KEY) apiKey := os.Getenv(OPENAI_API_KEY)
if len(apiKey) == 0 { if len(apiKey) == 0 {
@ -197,7 +222,6 @@ func CreateAgentClient(prompt string) (AgentClient, error) {
return AgentClient{ return AgentClient{
apiKey: apiKey, apiKey: apiKey,
url: "https://api.mistral.ai/v1/chat/completions", url: "https://api.mistral.ai/v1/chat/completions",
systemPrompt: prompt,
Do: func(req *http.Request) (*http.Response, error) { Do: func(req *http.Request) (*http.Response, error) {
client := &http.Client{} client := &http.Client{}
return client.Do(req) return client.Do(req)

View File

@ -0,0 +1,40 @@
package client
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSimpleResponse(t *testing.T) {
// assert := assert.New(t)
require := require.New(t)
chat := AgentMessages{
Messages: make([]AgentMessage, 0),
}
chat.AddSystem("system message")
err := chat.HandleResponse(AgentResponse{
Id: "0",
Object: "chat.completion",
Created: 1,
Choices: []ResponseChoice{{
Index: 0,
Message: ResponseChoiceMessage{
Role: "assistant",
Content: "some basic content",
},
FinishReason: "",
}},
})
require.NoError(err)
require.Len(chat.Messages, 2)
require.EqualValues(chat.Messages[1], ResponseChoiceMessage{
Role: "assistant",
Content: "some basic content",
})
}

View File

@ -24,18 +24,18 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was
const NonExistantTool = "This tool does not exist" const NonExistantTool = "This tool does not exist"
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) { func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentResponseMessage, error) {
if len(toolCallMessage.ToolCalls) == 0 { if len(toolCallMessage.ToolCalls) == 0 {
return []AgentTextMessage{}, NoToolCallError return []AgentResponseMessage{}, NoToolCallError
} }
responses := make([]AgentTextMessage, len(toolCallMessage.ToolCalls)) responses := make([]AgentResponseMessage, len(toolCallMessage.ToolCalls))
for i, toolCall := range toolCallMessage.ToolCalls { for i, toolCall := range toolCallMessage.ToolCalls {
fnName := toolCall.Function.Name fnName := toolCall.Function.Name
arguments := toolCall.Function.Arguments arguments := toolCall.Function.Arguments
responseMessage := AgentTextMessage{ responseMessage := AgentResponseMessage{
Role: "tool", Role: "tool",
Name: fnName, Name: fnName,
ToolCallId: toolCall.Id, ToolCallId: toolCall.Id,

View File

@ -52,7 +52,7 @@ func (suite *ToolTestSuite) TestSingleToolCall() {
require.NoError(err, "Tool call shouldnt return an error") require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentTextMessage{{ assert.EqualValues(response, []AgentResponseMessage{{
Role: "tool", Role: "tool",
Content: "\"return\"", Content: "\"return\"",
ToolCallId: "1", ToolCallId: "1",
@ -111,7 +111,7 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() {
require.NoError(err, "Tool call shouldnt return an error") require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentTextMessage{ assert.EqualValues(response, []AgentResponseMessage{
{ {
Role: "tool", Role: "tool",
Content: "\"first-call\"", Content: "\"first-call\"",
@ -169,7 +169,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
require.NoError(err, "Tool call shouldnt return an error") require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentTextMessage{ assert.EqualValues(response, []AgentResponseMessage{
{ {
Role: "tool", Role: "tool",
Content: "I will always error", Content: "I will always error",

View File

@ -184,7 +184,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
} }
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
agentClient, err := client.CreateAgentClient(eventLocationPrompt) agentClient, err := client.CreateAgentClient()
if err != nil { if err != nil {
return EventLocationAgent{}, err return EventLocationAgent{}, err
} }

View File

@ -66,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
} }
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) { func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
client, err := client.CreateAgentClient(noteAgentPrompt) client, err := client.CreateAgentClient()
if err != nil { if err != nil {
return NoteAgent{}, err return NoteAgent{}, err
} }

View File

@ -123,7 +123,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
} }
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
agent, err := client.CreateAgentClient(orchestratorPrompt) agent, err := client.CreateAgentClient()
if err != nil { if err != nil {
return OrchestratorAgent{}, err return OrchestratorAgent{}, err
} }