From 61c158d5b64a009db962552adccb85f5627d44f4 Mon Sep 17 00:00:00 2001 From: John Costa Date: Thu, 17 Apr 2025 10:58:19 +0100 Subject: [PATCH] refactor(agents): encapsulating prompt and calls inside factory method --- backend/agents/client/client.go | 30 +++++++++++++++++++++--------- backend/agents/contact_agent.go | 10 ++++++---- backend/agents/event_agent.go | 18 ++++++++++-------- backend/agents/location_agent.go | 11 ++++++----- backend/agents/note_agent.go | 7 +++---- backend/agents/orchestrator.go | 21 +++++++++++---------- backend/events.go | 2 +- 7 files changed, 58 insertions(+), 41 deletions(-) diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index c08b65d..29370c1 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -76,15 +76,25 @@ type AgentClient struct { Reply string Do func(req *http.Request) (*http.Response, error) + + Options CreateAgentClientOptions } const OPENAI_API_KEY = "OPENAI_API_KEY" -func CreateAgentClient(log *log.Logger) (AgentClient, error) { +type CreateAgentClientOptions struct { + Log *log.Logger + SystemPrompt string + JsonTools string + EndToolCall string + Query *string +} + +func CreateAgentClient(options CreateAgentClientOptions) AgentClient { apiKey := os.Getenv(OPENAI_API_KEY) if len(apiKey) == 0 { - return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") + panic("No api key") } return AgentClient{ @@ -95,12 +105,14 @@ func CreateAgentClient(log *log.Logger) (AgentClient, error) { return client.Do(req) }, - Log: log, + Log: options.Log, ToolHandler: ToolsHandlers{ handlers: map[string]ToolHandler{}, }, - }, nil + + Options: options, + } } func (client AgentClient) getRequest(body []byte) (*http.Request, error) { @@ -226,9 +238,9 @@ func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) return err } -func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, query *string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { +func (client AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { var tools any - err := json.Unmarshal([]byte(jsonTools), &tools) + err := json.Unmarshal([]byte(client.Options.JsonTools), &tools) toolChoice := "any" @@ -237,7 +249,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo ToolChoice: &toolChoice, Model: "pixtral-12b-2409", Temperature: 0.3, - EndToolCall: endToolCall, + EndToolCall: client.Options.EndToolCall, ResponseFormat: ResponseFormat{ Type: "text", }, @@ -246,8 +258,8 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo }, } - request.Chat.AddSystem(systemPrompt) - request.Chat.AddImage(imageName, imageData, query) + request.Chat.AddSystem(client.Options.SystemPrompt) + request.Chat.AddImage(imageName, imageData, client.Options.Query) _, err = client.Request(&request) if err != nil { diff --git a/backend/agents/contact_agent.go b/backend/agents/contact_agent.go index eacbb28..2b105df 100644 --- a/backend/agents/contact_agent.go +++ b/backend/agents/contact_agent.go @@ -91,10 +91,12 @@ type linkContactArguments struct { } func NewContactAgent(log *log.Logger, contactModel models.ContactModel) (client.AgentClient, error) { - agentClient, err := client.CreateAgentClient(log) - if err != nil { - return client.AgentClient{}, err - } + agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{ + SystemPrompt: contactPrompt, + JsonTools: contactTools, + Log: log, + EndToolCall: "finish", + }) agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { return contactModel.List(context.Background(), info.UserId) diff --git a/backend/agents/event_agent.go b/backend/agents/event_agent.go index 712690d..150c77e 100644 --- a/backend/agents/event_agent.go +++ b/backend/agents/event_agent.go @@ -109,11 +109,12 @@ type linkEventArguments struct { } func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent client.AgentClient) (client.AgentClient, error) { - agentClient, err := client.CreateAgentClient(log) - - if err != nil { - return client.AgentClient{}, err - } + agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{ + SystemPrompt: eventPrompt, + JsonTools: eventTools, + Log: log, + EndToolCall: "finish", + }) agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { return eventsModel.List(context.Background(), info.UserId) @@ -177,9 +178,10 @@ func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent }) agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { - query := "Can you get me the ID of the location present in this image?" - locationAgent.Log = log.With("Locations 📍", true) - locationAgent.RunAgent(locationPrompt, locationTools, "finish", &query, info.UserId, info.ImageId, info.ImageName, *info.Image) + // TODO: reenable this when I'm creating the agent locally instead of getting it from above. + // query := "Can you get me the ID of the location present in this image?" + // locationAgent.Log = log.With("Locations 📍", true) + // locationAgent.RunAgent(info.UserId, info.ImageId, info.ImageName, *info.Image) return locationAgent.Reply, nil }) diff --git a/backend/agents/location_agent.go b/backend/agents/location_agent.go index 0c1b4be..209d58c 100644 --- a/backend/agents/location_agent.go +++ b/backend/agents/location_agent.go @@ -105,11 +105,12 @@ type linkLocationArguments struct { } func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) (client.AgentClient, error) { - agentClient, err := client.CreateAgentClient(log) - - if err != nil { - return client.AgentClient{}, err - } + agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{ + SystemPrompt: locationPrompt, + JsonTools: locationTools, + Log: log, + EndToolCall: "finish", + }) agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { return locationModel.List(context.Background(), info.UserId) diff --git a/backend/agents/note_agent.go b/backend/agents/note_agent.go index 81a2d57..fc26b8a 100644 --- a/backend/agents/note_agent.go +++ b/backend/agents/note_agent.go @@ -69,10 +69,9 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s } func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) (NoteAgent, error) { - client, err := client.CreateAgentClient(log) - if err != nil { - return NoteAgent{}, err - } + client := client.CreateAgentClient(client.CreateAgentClientOptions{ + SystemPrompt: noteAgentPrompt, + }) agent := NoteAgent{ client: client, diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index d3e072e..b5f68e4 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -7,7 +7,7 @@ import ( "github.com/charmbracelet/log" ) -const OrchestratorPrompt = ` +const orchestratorPrompt = ` You are an Orchestrator for various AI agents. The user will send you images and you have to determine which agents you have to call, in order to best help the user. @@ -38,7 +38,7 @@ Always call agents in parallel if you need to call more than 1. Do not call the agent if you do not think it is relevant for the image. ` -const OrchestratorTools = ` +const orchestratorTools = ` [ { "type": "function", @@ -113,11 +113,12 @@ type Status struct { } func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) (client.AgentClient, error) { - agent, err := client.CreateAgentClient(log) - - if err != nil { - return client.AgentClient{}, err - } + agent := client.CreateAgentClient(client.CreateAgentClientOptions{ + SystemPrompt: orchestratorPrompt, + JsonTools: orchestratorTools, + Log: log, + EndToolCall: "noAction", + }) agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData) @@ -128,7 +129,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli }) agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { - go contactAgent.RunAgent(contactPrompt, contactTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) + go contactAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData) return Status{ Ok: true, @@ -136,7 +137,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli }) agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { - go locationAgent.RunAgent(locationPrompt, locationTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) + go locationAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData) return Status{ Ok: true, @@ -144,7 +145,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli }) agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { - go eventAgent.RunAgent(eventPrompt, eventTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) + go eventAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData) return Status{ Ok: true, diff --git a/backend/events.go b/backend/events.go index 367cdd1..e82bb8c 100644 --- a/backend/events.go +++ b/backend/events.go @@ -90,7 +90,7 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) { panic(err) } - err = orchestrator.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", nil, image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) + err = orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) if err != nil { databaseEventLog.Error("Orchestrator failed", "error", err) return