refactor(agent): main agent loop extracted away

Still not super sure how to represent these agents in code.
It doesn't make the most amount of sense to keep them in structs. A
curried function is more like it, with system prompt and tooling.

Maybe that's what I'll end up doing.
This commit is contained in:
2025-04-12 14:39:16 +01:00
parent 91cc54aaec
commit 959b741fcb
5 changed files with 47 additions and 125 deletions

View File

@ -9,6 +9,7 @@ import (
"os" "os"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid"
) )
type ResponseFormat struct { type ResponseFormat struct {
@ -217,3 +218,39 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
return err return err
} }
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
var tools any
err := json.Unmarshal([]byte(jsonTools), &tools)
toolChoice := "any"
request := AgentRequestBody{
Tools: &tools,
ToolChoice: &toolChoice,
Model: "pixtral-12b-2409",
Temperature: 0.3,
EndToolCall: endToolCall,
ResponseFormat: ResponseFormat{
Type: "text",
},
Chat: &Chat{
Messages: make([]ChatMessage, 0),
},
}
request.Chat.AddSystem(systemPrompt)
request.Chat.AddImage(imageName, imageData)
_, err = client.Request(&request)
if err != nil {
return err
}
toolHandlerInfo := ToolHandlerInfo{
ImageId: imageId,
UserId: userId,
}
return client.ToolLoop(toolHandlerInfo, &request)
}

View File

@ -116,43 +116,6 @@ type linkContactArguments struct {
ContactID string `json:"contactId"` ContactID string `json:"contactId"`
} }
// Yeah this is just a copy of the other one.
func (agent ContactAgent) GetContacts(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
var tools any
err := json.Unmarshal([]byte(contactTools), &tools)
toolChoice := "any"
request := client.AgentRequestBody{
Tools: &tools,
ToolChoice: &toolChoice,
Model: "pixtral-12b-2409",
Temperature: 0.3,
EndToolCall: "finish",
ResponseFormat: client.ResponseFormat{
Type: "text",
},
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
request.Chat.AddSystem(eventLocationPrompt)
request.Chat.AddImage(imageName, imageData)
_, err = agent.client.Request(&request)
if err != nil {
return err
}
toolHandlerInfo := client.ToolHandlerInfo{
ImageId: imageId,
UserId: userId,
}
return agent.client.ToolLoop(toolHandlerInfo, &request)
}
func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) { func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true, ReportTimestamp: true,

View File

@ -37,7 +37,7 @@ Always prioritize the creation of locations and organizers before events. Ensure
` `
// TODO: this should be read directly from a file on load. // TODO: this should be read directly from a file on load.
const TOOLS = ` const eventLocationTools = `
[ [
{ {
"type": "function", "type": "function",
@ -151,42 +151,6 @@ type CreateEventArguments struct {
OrganizerName string `json:"organizerName"` OrganizerName string `json:"organizerName"`
} }
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
var tools any
err := json.Unmarshal([]byte(TOOLS), &tools)
toolChoice := "any"
request := client.AgentRequestBody{
Tools: &tools,
ToolChoice: &toolChoice,
Model: "pixtral-12b-2409",
Temperature: 0.3,
EndToolCall: "finish",
ResponseFormat: client.ResponseFormat{
Type: "text",
},
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
request.Chat.AddSystem(eventLocationPrompt)
request.Chat.AddImage(imageName, imageData)
_, err = agent.client.Request(&request)
if err != nil {
return err
}
toolHandlerInfo := client.ToolHandlerInfo{
ImageId: imageId,
UserId: userId,
}
return agent.client.ToolLoop(toolHandlerInfo, &request)
}
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) { func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true, ReportTimestamp: true,

View File

@ -1,17 +1,15 @@
package agents package agents
import ( import (
"encoding/json"
"errors" "errors"
"os" "os"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"time" "time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid"
) )
const orchestratorPrompt = ` const OrchestratorPrompt = `
You are an Orchestrator for various AI agents. You are an Orchestrator for various AI agents.
The user will send you images and you have to determine which agents you have to call, in order to best help the user. The user will send you images and you have to determine which agents you have to call, in order to best help the user.
@ -41,7 +39,7 @@ Always call agents in parallel if you need to call more than 1.
Do not call the agent if you do not think it is relevant for the image. Do not call the agent if you do not think it is relevant for the image.
` `
const MY_TOOLS = ` const OrchestratorTools = `
[ [
{ {
"type": "function", "type": "function",
@ -94,7 +92,7 @@ const MY_TOOLS = `
]` ]`
type OrchestratorAgent struct { type OrchestratorAgent struct {
client client.AgentClient Client client.AgentClient
log log.Logger log log.Logger
} }
@ -103,48 +101,6 @@ type Status struct {
Ok bool `json:"ok"` Ok bool `json:"ok"`
} }
// TODO: the primary function of the agent could be extracted outwards.
// This is basically the same function as we have in the `event_location_agent.go`
func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
toolChoice := "any"
var tools any
err := json.Unmarshal([]byte(MY_TOOLS), &tools)
if err != nil {
return err
}
request := client.AgentRequestBody{
Model: "pixtral-12b-2409",
Temperature: 0.4,
ResponseFormat: client.ResponseFormat{
Type: "text",
},
ToolChoice: &toolChoice,
Tools: &tools,
EndToolCall: "noAction",
Chat: &client.Chat{
Messages: make([]client.ChatMessage, 0),
},
}
request.Chat.AddSystem(orchestratorPrompt)
request.Chat.AddImage(imageName, imageData)
if _, err := agent.client.Request(&request); err != nil {
return err
}
toolHandlerInfo := client.ToolHandlerInfo{
ImageId: imageId,
UserId: userId,
}
return agent.client.ToolLoop(toolHandlerInfo, &request)
}
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, contactAgent ContactAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, contactAgent ContactAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true, ReportTimestamp: true,
@ -160,7 +116,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
// We need a way to keep track of this async? // We need a way to keep track of this async?
// Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish. // Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish.
go eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData) go eventLocationAgent.client.RunAgent(eventLocationPrompt, eventLocationTools, "finish", info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -176,7 +132,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
}) })
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go contactAgent.GetContacts(info.UserId, info.ImageId, imageName, imageData) go contactAgent.client.RunAgent(contactPrompt, contactTools, "finish", info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -192,6 +148,6 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
}) })
return OrchestratorAgent{ return OrchestratorAgent{
client: agent, Client: agent,
}, nil }, nil
} }

View File

@ -74,7 +74,9 @@ func ListenNewImageEvents(db *sql.DB) {
panic(err) panic(err)
} }
err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) // Still need to find some way to hide this complexity away.
// I don't think wrapping agents in structs actually works too well.
err = orchestrator.Client.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }