refactor(agent): main agent loop extracted away
Still not super sure how to represent these agents in code. It doesn't make the most amount of sense to keep them in structs. A curried function is more like it, with system prompt and tooling. Maybe that's what I'll end up doing.
This commit is contained in:
@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ResponseFormat struct {
|
||||
@ -217,3 +218,39 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
var tools any
|
||||
err := json.Unmarshal([]byte(jsonTools), &tools)
|
||||
|
||||
toolChoice := "any"
|
||||
|
||||
request := AgentRequestBody{
|
||||
Tools: &tools,
|
||||
ToolChoice: &toolChoice,
|
||||
Model: "pixtral-12b-2409",
|
||||
Temperature: 0.3,
|
||||
EndToolCall: endToolCall,
|
||||
ResponseFormat: ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
Chat: &Chat{
|
||||
Messages: make([]ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(systemPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
|
||||
_, err = client.Request(&request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
toolHandlerInfo := ToolHandlerInfo{
|
||||
ImageId: imageId,
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
return client.ToolLoop(toolHandlerInfo, &request)
|
||||
}
|
||||
|
@ -116,43 +116,6 @@ type linkContactArguments struct {
|
||||
ContactID string `json:"contactId"`
|
||||
}
|
||||
|
||||
// Yeah this is just a copy of the other one.
|
||||
func (agent ContactAgent) GetContacts(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
var tools any
|
||||
err := json.Unmarshal([]byte(contactTools), &tools)
|
||||
|
||||
toolChoice := "any"
|
||||
|
||||
request := client.AgentRequestBody{
|
||||
Tools: &tools,
|
||||
ToolChoice: &toolChoice,
|
||||
Model: "pixtral-12b-2409",
|
||||
Temperature: 0.3,
|
||||
EndToolCall: "finish",
|
||||
ResponseFormat: client.ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
Chat: &client.Chat{
|
||||
Messages: make([]client.ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(eventLocationPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
|
||||
_, err = agent.client.Request(&request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
toolHandlerInfo := client.ToolHandlerInfo{
|
||||
ImageId: imageId,
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
||||
}
|
||||
|
||||
func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
|
@ -37,7 +37,7 @@ Always prioritize the creation of locations and organizers before events. Ensure
|
||||
`
|
||||
|
||||
// TODO: this should be read directly from a file on load.
|
||||
const TOOLS = `
|
||||
const eventLocationTools = `
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
@ -151,42 +151,6 @@ type CreateEventArguments struct {
|
||||
OrganizerName string `json:"organizerName"`
|
||||
}
|
||||
|
||||
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
var tools any
|
||||
err := json.Unmarshal([]byte(TOOLS), &tools)
|
||||
|
||||
toolChoice := "any"
|
||||
|
||||
request := client.AgentRequestBody{
|
||||
Tools: &tools,
|
||||
ToolChoice: &toolChoice,
|
||||
Model: "pixtral-12b-2409",
|
||||
Temperature: 0.3,
|
||||
EndToolCall: "finish",
|
||||
ResponseFormat: client.ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
Chat: &client.Chat{
|
||||
Messages: make([]client.ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(eventLocationPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
|
||||
_, err = agent.client.Request(&request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
toolHandlerInfo := client.ToolHandlerInfo{
|
||||
ImageId: imageId,
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
||||
}
|
||||
|
||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
|
@ -1,17 +1,15 @@
|
||||
package agents
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"screenmark/screenmark/agents/client"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const orchestratorPrompt = `
|
||||
const OrchestratorPrompt = `
|
||||
You are an Orchestrator for various AI agents.
|
||||
|
||||
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
||||
@ -41,7 +39,7 @@ Always call agents in parallel if you need to call more than 1.
|
||||
Do not call the agent if you do not think it is relevant for the image.
|
||||
`
|
||||
|
||||
const MY_TOOLS = `
|
||||
const OrchestratorTools = `
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
@ -94,7 +92,7 @@ const MY_TOOLS = `
|
||||
]`
|
||||
|
||||
type OrchestratorAgent struct {
|
||||
client client.AgentClient
|
||||
Client client.AgentClient
|
||||
|
||||
log log.Logger
|
||||
}
|
||||
@ -103,48 +101,6 @@ type Status struct {
|
||||
Ok bool `json:"ok"`
|
||||
}
|
||||
|
||||
// TODO: the primary function of the agent could be extracted outwards.
|
||||
// This is basically the same function as we have in the `event_location_agent.go`
|
||||
func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
toolChoice := "any"
|
||||
|
||||
var tools any
|
||||
err := json.Unmarshal([]byte(MY_TOOLS), &tools)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request := client.AgentRequestBody{
|
||||
Model: "pixtral-12b-2409",
|
||||
Temperature: 0.4,
|
||||
ResponseFormat: client.ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
ToolChoice: &toolChoice,
|
||||
Tools: &tools,
|
||||
|
||||
EndToolCall: "noAction",
|
||||
|
||||
Chat: &client.Chat{
|
||||
Messages: make([]client.ChatMessage, 0),
|
||||
},
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(orchestratorPrompt)
|
||||
request.Chat.AddImage(imageName, imageData)
|
||||
|
||||
if _, err := agent.client.Request(&request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
toolHandlerInfo := client.ToolHandlerInfo{
|
||||
ImageId: imageId,
|
||||
UserId: userId,
|
||||
}
|
||||
|
||||
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
||||
}
|
||||
|
||||
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, contactAgent ContactAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
|
||||
agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
@ -160,7 +116,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
|
||||
// We need a way to keep track of this async?
|
||||
// Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish.
|
||||
|
||||
go eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData)
|
||||
go eventLocationAgent.client.RunAgent(eventLocationPrompt, eventLocationTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
||||
|
||||
return Status{
|
||||
Ok: true,
|
||||
@ -176,7 +132,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go contactAgent.GetContacts(info.UserId, info.ImageId, imageName, imageData)
|
||||
go contactAgent.client.RunAgent(contactPrompt, contactTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
||||
|
||||
return Status{
|
||||
Ok: true,
|
||||
@ -192,6 +148,6 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
|
||||
})
|
||||
|
||||
return OrchestratorAgent{
|
||||
client: agent,
|
||||
Client: agent,
|
||||
}, nil
|
||||
}
|
||||
|
@ -74,7 +74,9 @@ func ListenNewImageEvents(db *sql.DB) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
// Still need to find some way to hide this complexity away.
|
||||
// I don't think wrapping agents in structs actually works too well.
|
||||
err = orchestrator.Client.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user