From 959b741fcb9c8023dba5dae0679723dbb2fb137d Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 12 Apr 2025 14:39:16 +0100 Subject: [PATCH] refactor(agent): main agent loop extracted away Still not super sure how to represent these agents in code. It doesn't make much sense to keep them in structs. A curried function is more like it, with system prompt and tooling. Maybe that's what I'll end up doing. --- backend/agents/client/client.go | 37 +++++++++++++++++ backend/agents/contact_agent.go | 37 ----------------- backend/agents/event_location_agent.go | 38 +---------------- backend/agents/orchestrator.go | 56 +++----------------------- backend/events.go | 4 +- 5 files changed, 47 insertions(+), 125 deletions(-) diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go index 6be9f96..3e540c7 100644 --- a/backend/agents/client/client.go +++ b/backend/agents/client/client.go @@ -9,6 +9,7 @@ import ( "os" "github.com/charmbracelet/log" + "github.com/google/uuid" ) type ResponseFormat struct { @@ -217,3 +218,39 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e return err } + +func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { + var tools any + err := json.Unmarshal([]byte(jsonTools), &tools) + + toolChoice := "any" + + request := AgentRequestBody{ + Tools: &tools, + ToolChoice: &toolChoice, + Model: "pixtral-12b-2409", + Temperature: 0.3, + EndToolCall: endToolCall, + ResponseFormat: ResponseFormat{ + Type: "text", + }, + Chat: &Chat{ + Messages: make([]ChatMessage, 0), + }, + } + + request.Chat.AddSystem(systemPrompt) + request.Chat.AddImage(imageName, imageData) + + _, err = client.Request(&request) + if err != nil { + return err + } + + toolHandlerInfo := ToolHandlerInfo{ + ImageId: imageId, + UserId: userId, + } + + return client.ToolLoop(toolHandlerInfo, &request) +} diff 
--git a/backend/agents/contact_agent.go b/backend/agents/contact_agent.go index 2e41c21..a649889 100644 --- a/backend/agents/contact_agent.go +++ b/backend/agents/contact_agent.go @@ -116,43 +116,6 @@ type linkContactArguments struct { ContactID string `json:"contactId"` } -// Yeah this is just a copy of the other one. -func (agent ContactAgent) GetContacts(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { - var tools any - err := json.Unmarshal([]byte(contactTools), &tools) - - toolChoice := "any" - - request := client.AgentRequestBody{ - Tools: &tools, - ToolChoice: &toolChoice, - Model: "pixtral-12b-2409", - Temperature: 0.3, - EndToolCall: "finish", - ResponseFormat: client.ResponseFormat{ - Type: "text", - }, - Chat: &client.Chat{ - Messages: make([]client.ChatMessage, 0), - }, - } - - request.Chat.AddSystem(eventLocationPrompt) - request.Chat.AddImage(imageName, imageData) - - _, err = agent.client.Request(&request) - if err != nil { - return err - } - - toolHandlerInfo := client.ToolHandlerInfo{ - ImageId: imageId, - UserId: userId, - } - - return agent.client.ToolLoop(toolHandlerInfo, &request) -} - func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) { agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ ReportTimestamp: true, diff --git a/backend/agents/event_location_agent.go b/backend/agents/event_location_agent.go index 979edcf..68a07b0 100644 --- a/backend/agents/event_location_agent.go +++ b/backend/agents/event_location_agent.go @@ -37,7 +37,7 @@ Always prioritize the creation of locations and organizers before events. Ensure ` // TODO: this should be read directly from a file on load. 
-const TOOLS = ` +const eventLocationTools = ` [ { "type": "function", @@ -151,42 +151,6 @@ type CreateEventArguments struct { OrganizerName string `json:"organizerName"` } -func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { - var tools any - err := json.Unmarshal([]byte(TOOLS), &tools) - - toolChoice := "any" - - request := client.AgentRequestBody{ - Tools: &tools, - ToolChoice: &toolChoice, - Model: "pixtral-12b-2409", - Temperature: 0.3, - EndToolCall: "finish", - ResponseFormat: client.ResponseFormat{ - Type: "text", - }, - Chat: &client.Chat{ - Messages: make([]client.ChatMessage, 0), - }, - } - - request.Chat.AddSystem(eventLocationPrompt) - request.Chat.AddImage(imageName, imageData) - - _, err = agent.client.Request(&request) - if err != nil { - return err - } - - toolHandlerInfo := client.ToolHandlerInfo{ - ImageId: imageId, - UserId: userId, - } - - return agent.client.ToolLoop(toolHandlerInfo, &request) -} - func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ ReportTimestamp: true, diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go index 721ce22..9868bef 100644 --- a/backend/agents/orchestrator.go +++ b/backend/agents/orchestrator.go @@ -1,17 +1,15 @@ package agents import ( - "encoding/json" "errors" "os" "screenmark/screenmark/agents/client" "time" "github.com/charmbracelet/log" - "github.com/google/uuid" ) -const orchestratorPrompt = ` +const OrchestratorPrompt = ` You are an Orchestrator for various AI agents. The user will send you images and you have to determine which agents you have to call, in order to best help the user. @@ -41,7 +39,7 @@ Always call agents in parallel if you need to call more than 1. 
Do not call the agent if you do not think it is relevant for the image. ` -const MY_TOOLS = ` +const OrchestratorTools = ` [ { "type": "function", @@ -94,7 +92,7 @@ const MY_TOOLS = ` ]` type OrchestratorAgent struct { - client client.AgentClient + Client client.AgentClient log log.Logger } @@ -103,48 +101,6 @@ type Status struct { Ok bool `json:"ok"` } -// TODO: the primary function of the agent could be extracted outwards. -// This is basically the same function as we have in the `event_location_agent.go` -func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { - toolChoice := "any" - - var tools any - err := json.Unmarshal([]byte(MY_TOOLS), &tools) - if err != nil { - return err - } - - request := client.AgentRequestBody{ - Model: "pixtral-12b-2409", - Temperature: 0.4, - ResponseFormat: client.ResponseFormat{ - Type: "text", - }, - ToolChoice: &toolChoice, - Tools: &tools, - - EndToolCall: "noAction", - - Chat: &client.Chat{ - Messages: make([]client.ChatMessage, 0), - }, - } - - request.Chat.AddSystem(orchestratorPrompt) - request.Chat.AddImage(imageName, imageData) - - if _, err := agent.client.Request(&request); err != nil { - return err - } - - toolHandlerInfo := client.ToolHandlerInfo{ - ImageId: imageId, - UserId: userId, - } - - return agent.client.ToolLoop(toolHandlerInfo, &request) -} - func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, contactAgent ContactAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ ReportTimestamp: true, @@ -160,7 +116,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA // We need a way to keep track of this async? // Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish. 
- go eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData) + go eventLocationAgent.client.RunAgent(eventLocationPrompt, eventLocationTools, "finish", info.UserId, info.ImageId, imageName, imageData) return Status{ Ok: true, @@ -176,7 +132,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA }) agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { - go contactAgent.GetContacts(info.UserId, info.ImageId, imageName, imageData) + go contactAgent.client.RunAgent(contactPrompt, contactTools, "finish", info.UserId, info.ImageId, imageName, imageData) return Status{ Ok: true, @@ -192,6 +148,6 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA }) return OrchestratorAgent{ - client: agent, + Client: agent, }, nil } diff --git a/backend/events.go b/backend/events.go index c736995..47821f2 100644 --- a/backend/events.go +++ b/backend/events.go @@ -74,7 +74,9 @@ func ListenNewImageEvents(db *sql.DB) { panic(err) } - err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) + // Still need to find some way to hide this complexity away. + // I don't think wrapping agents in structs actually works too well. + err = orchestrator.Client.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) if err != nil { log.Println(err) }