fixup(chat): better way to organize agent messages and tool calls

This commit is contained in:
2025-04-06 20:24:40 +01:00
parent 5c5df168ad
commit 26c6edb6ba
7 changed files with 87 additions and 23 deletions

View File

@ -49,14 +49,16 @@ type AgentMessage interface {
MessageToJson() ([]byte, error)
}
type AgentTextMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCallId string `json:"tool_call_id,omitempty"`
Name string `json:"name,omitempty"`
type AgentResponseMessage struct {
Role string `json:"role"`
Content string `json:"content"`
// Not sure I need this field.
ToolCallId string `json:"tool_call_id,omitempty"`
ToolCalls *[]ToolCall `json:"tool_calls,omitempty"`
Name string `json:"name,omitempty"`
}
func (textContent AgentTextMessage) MessageToJson() ([]byte, error) {
func (textContent AgentResponseMessage) MessageToJson() ([]byte, error) {
// TODO: Validate the `Role`.
return json.Marshal(textContent)
}
@ -91,7 +93,7 @@ func (arrayContent AgentArrayMessage) MessageToJson() ([]byte, error) {
return json.Marshal(arrayContent)
}
func (content *AgentMessages) AddText(message AgentTextMessage) {
func (content *AgentMessages) AddResponse(message ResponseChoiceMessage) {
content.Messages = append(content.Messages, message)
}
@ -126,7 +128,7 @@ func (content *AgentMessages) AddSystem(prompt string) error {
return errors.New("You can only add a system prompt at the beginning")
}
content.Messages = append(content.Messages, AgentTextMessage{
content.Messages = append(content.Messages, AgentResponseMessage{
Role: ROLE_SYSTEM,
Content: prompt,
})
@ -134,6 +136,26 @@ func (content *AgentMessages) AddSystem(prompt string) error {
return nil
}
// TODO: `AgentMessages` is not really a good name.
// It's a step above that, like a real chat. AgentChat or something.
func (chat *AgentMessages) HandleResponse(response AgentResponse) error {
if len(chat.Messages) == 0 {
return errors.New("This chat doesnt contain any messages therefore cannot be handled.")
}
for _, choice := range response.Choices {
// TOOD
// if len(choice.Message.ToolCalls) > 0 {
// for _, toolCall := choice.Message.ToolCalls {
// chat.AddToolCall()
// }
// }
chat.AddResponse(choice.Message)
}
return nil
}
type AgentContent interface {
ToJson() ([]byte, error)
}
@ -158,6 +180,10 @@ type ResponseChoiceMessage struct {
ToolCalls []ToolCall `json:"tool_calls"`
}
func (choice ResponseChoiceMessage) MessageToJson() ([]byte, error) {
return json.Marshal(choice)
}
type ResponseChoice struct {
Index int `json:"index"`
Message ResponseChoiceMessage `json:"message"`
@ -174,7 +200,6 @@ type AgentResponse struct {
type AgentClient struct {
url string
apiKey string
systemPrompt string
responseFormat string
ToolHandler ToolsHandlers
@ -187,7 +212,7 @@ const ROLE_USER = "user"
const ROLE_SYSTEM = "system"
const IMAGE_TYPE = "image_url"
func CreateAgentClient(prompt string) (AgentClient, error) {
func CreateAgentClient() (AgentClient, error) {
apiKey := os.Getenv(OPENAI_API_KEY)
if len(apiKey) == 0 {
@ -195,9 +220,8 @@ func CreateAgentClient(prompt string) (AgentClient, error) {
}
return AgentClient{
apiKey: apiKey,
url: "https://api.mistral.ai/v1/chat/completions",
systemPrompt: prompt,
apiKey: apiKey,
url: "https://api.mistral.ai/v1/chat/completions",
Do: func(req *http.Request) (*http.Response, error) {
client := &http.Client{}
return client.Do(req)

View File

@ -0,0 +1,40 @@
package client
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSimpleResponse(t *testing.T) {
// assert := assert.New(t)
require := require.New(t)
chat := AgentMessages{
Messages: make([]AgentMessage, 0),
}
chat.AddSystem("system message")
err := chat.HandleResponse(AgentResponse{
Id: "0",
Object: "chat.completion",
Created: 1,
Choices: []ResponseChoice{{
Index: 0,
Message: ResponseChoiceMessage{
Role: "assistant",
Content: "some basic content",
},
FinishReason: "",
}},
})
require.NoError(err)
require.Len(chat.Messages, 2)
require.EqualValues(chat.Messages[1], ResponseChoiceMessage{
Role: "assistant",
Content: "some basic content",
})
}

View File

@ -24,18 +24,18 @@ var NoToolCallError = errors.New("An assistant tool call with no tool calls was
const NonExistantTool = "This tool does not exist"
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentTextMessage, error) {
func (handler ToolsHandlers) Handle(info ToolHandlerInfo, toolCallMessage AgentAssistantToolCall) ([]AgentResponseMessage, error) {
if len(toolCallMessage.ToolCalls) == 0 {
return []AgentTextMessage{}, NoToolCallError
return []AgentResponseMessage{}, NoToolCallError
}
responses := make([]AgentTextMessage, len(toolCallMessage.ToolCalls))
responses := make([]AgentResponseMessage, len(toolCallMessage.ToolCalls))
for i, toolCall := range toolCallMessage.ToolCalls {
fnName := toolCall.Function.Name
arguments := toolCall.Function.Arguments
responseMessage := AgentTextMessage{
responseMessage := AgentResponseMessage{
Role: "tool",
Name: fnName,
ToolCallId: toolCall.Id,

View File

@ -52,7 +52,7 @@ func (suite *ToolTestSuite) TestSingleToolCall() {
require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentTextMessage{{
assert.EqualValues(response, []AgentResponseMessage{{
Role: "tool",
Content: "\"return\"",
ToolCallId: "1",
@ -111,7 +111,7 @@ func (suite *ToolTestSuite) TestMultipleToolCalls() {
require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentTextMessage{
assert.EqualValues(response, []AgentResponseMessage{
{
Role: "tool",
Content: "\"first-call\"",
@ -169,7 +169,7 @@ func (suite *ToolTestSuite) TestMultipleToolCallsWithErrors() {
require.NoError(err, "Tool call shouldnt return an error")
assert.EqualValues(response, []AgentTextMessage{
assert.EqualValues(response, []AgentResponseMessage{
{
Role: "tool",
Content: "I will always error",

View File

@ -184,7 +184,7 @@ func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
agentClient, err := client.CreateAgentClient(eventLocationPrompt)
agentClient, err := client.CreateAgentClient()
if err != nil {
return EventLocationAgent{}, err
}

View File

@ -66,7 +66,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
}
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) {
client, err := client.CreateAgentClient(noteAgentPrompt)
client, err := client.CreateAgentClient()
if err != nil {
return NoteAgent{}, err
}

View File

@ -123,7 +123,7 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
}
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
agent, err := client.CreateAgentClient(orchestratorPrompt)
agent, err := client.CreateAgentClient()
if err != nil {
return OrchestratorAgent{}, err
}