refactor(agent): main agent loop extracted away
Still not sure how best to represent these agents in code. Keeping them in structs doesn't make much sense; a curried function that closes over the system prompt and tooling fits better. Maybe that's what I'll end up doing.
This commit is contained in:
@ -9,6 +9,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
@ -217,3 +218,39 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
|
|||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
|
var tools any
|
||||||
|
err := json.Unmarshal([]byte(jsonTools), &tools)
|
||||||
|
|
||||||
|
toolChoice := "any"
|
||||||
|
|
||||||
|
request := AgentRequestBody{
|
||||||
|
Tools: &tools,
|
||||||
|
ToolChoice: &toolChoice,
|
||||||
|
Model: "pixtral-12b-2409",
|
||||||
|
Temperature: 0.3,
|
||||||
|
EndToolCall: endToolCall,
|
||||||
|
ResponseFormat: ResponseFormat{
|
||||||
|
Type: "text",
|
||||||
|
},
|
||||||
|
Chat: &Chat{
|
||||||
|
Messages: make([]ChatMessage, 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Chat.AddSystem(systemPrompt)
|
||||||
|
request.Chat.AddImage(imageName, imageData)
|
||||||
|
|
||||||
|
_, err = client.Request(&request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
toolHandlerInfo := ToolHandlerInfo{
|
||||||
|
ImageId: imageId,
|
||||||
|
UserId: userId,
|
||||||
|
}
|
||||||
|
|
||||||
|
return client.ToolLoop(toolHandlerInfo, &request)
|
||||||
|
}
|
||||||
|
@ -116,43 +116,6 @@ type linkContactArguments struct {
|
|||||||
ContactID string `json:"contactId"`
|
ContactID string `json:"contactId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Yeah this is just a copy of the other one.
|
|
||||||
func (agent ContactAgent) GetContacts(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
|
||||||
var tools any
|
|
||||||
err := json.Unmarshal([]byte(contactTools), &tools)
|
|
||||||
|
|
||||||
toolChoice := "any"
|
|
||||||
|
|
||||||
request := client.AgentRequestBody{
|
|
||||||
Tools: &tools,
|
|
||||||
ToolChoice: &toolChoice,
|
|
||||||
Model: "pixtral-12b-2409",
|
|
||||||
Temperature: 0.3,
|
|
||||||
EndToolCall: "finish",
|
|
||||||
ResponseFormat: client.ResponseFormat{
|
|
||||||
Type: "text",
|
|
||||||
},
|
|
||||||
Chat: &client.Chat{
|
|
||||||
Messages: make([]client.ChatMessage, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Chat.AddSystem(eventLocationPrompt)
|
|
||||||
request.Chat.AddImage(imageName, imageData)
|
|
||||||
|
|
||||||
_, err = agent.client.Request(&request)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolHandlerInfo := client.ToolHandlerInfo{
|
|
||||||
ImageId: imageId,
|
|
||||||
UserId: userId,
|
|
||||||
}
|
|
||||||
|
|
||||||
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
|
||||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||||
ReportTimestamp: true,
|
ReportTimestamp: true,
|
||||||
|
@ -37,7 +37,7 @@ Always prioritize the creation of locations and organizers before events. Ensure
|
|||||||
`
|
`
|
||||||
|
|
||||||
// TODO: this should be read directly from a file on load.
|
// TODO: this should be read directly from a file on load.
|
||||||
const TOOLS = `
|
const eventLocationTools = `
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@ -151,42 +151,6 @@ type CreateEventArguments struct {
|
|||||||
OrganizerName string `json:"organizerName"`
|
OrganizerName string `json:"organizerName"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (agent EventLocationAgent) GetLocations(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
|
||||||
var tools any
|
|
||||||
err := json.Unmarshal([]byte(TOOLS), &tools)
|
|
||||||
|
|
||||||
toolChoice := "any"
|
|
||||||
|
|
||||||
request := client.AgentRequestBody{
|
|
||||||
Tools: &tools,
|
|
||||||
ToolChoice: &toolChoice,
|
|
||||||
Model: "pixtral-12b-2409",
|
|
||||||
Temperature: 0.3,
|
|
||||||
EndToolCall: "finish",
|
|
||||||
ResponseFormat: client.ResponseFormat{
|
|
||||||
Type: "text",
|
|
||||||
},
|
|
||||||
Chat: &client.Chat{
|
|
||||||
Messages: make([]client.ChatMessage, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Chat.AddSystem(eventLocationPrompt)
|
|
||||||
request.Chat.AddImage(imageName, imageData)
|
|
||||||
|
|
||||||
_, err = agent.client.Request(&request)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolHandlerInfo := client.ToolHandlerInfo{
|
|
||||||
ImageId: imageId,
|
|
||||||
UserId: userId,
|
|
||||||
}
|
|
||||||
|
|
||||||
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
func NewLocationEventAgent(locationModel models.LocationModel, eventModel models.EventModel, contactModel models.ContactModel) (EventLocationAgent, error) {
|
||||||
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||||
ReportTimestamp: true,
|
ReportTimestamp: true,
|
||||||
|
@ -1,17 +1,15 @@
|
|||||||
package agents
|
package agents
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"screenmark/screenmark/agents/client"
|
"screenmark/screenmark/agents/client"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const orchestratorPrompt = `
|
const OrchestratorPrompt = `
|
||||||
You are an Orchestrator for various AI agents.
|
You are an Orchestrator for various AI agents.
|
||||||
|
|
||||||
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
||||||
@ -41,7 +39,7 @@ Always call agents in parallel if you need to call more than 1.
|
|||||||
Do not call the agent if you do not think it is relevant for the image.
|
Do not call the agent if you do not think it is relevant for the image.
|
||||||
`
|
`
|
||||||
|
|
||||||
const MY_TOOLS = `
|
const OrchestratorTools = `
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@ -94,7 +92,7 @@ const MY_TOOLS = `
|
|||||||
]`
|
]`
|
||||||
|
|
||||||
type OrchestratorAgent struct {
|
type OrchestratorAgent struct {
|
||||||
client client.AgentClient
|
Client client.AgentClient
|
||||||
|
|
||||||
log log.Logger
|
log log.Logger
|
||||||
}
|
}
|
||||||
@ -103,48 +101,6 @@ type Status struct {
|
|||||||
Ok bool `json:"ok"`
|
Ok bool `json:"ok"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: the primary function of the agent could be extracted outwards.
|
|
||||||
// This is basically the same function as we have in the `event_location_agent.go`
|
|
||||||
func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
|
||||||
toolChoice := "any"
|
|
||||||
|
|
||||||
var tools any
|
|
||||||
err := json.Unmarshal([]byte(MY_TOOLS), &tools)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
request := client.AgentRequestBody{
|
|
||||||
Model: "pixtral-12b-2409",
|
|
||||||
Temperature: 0.4,
|
|
||||||
ResponseFormat: client.ResponseFormat{
|
|
||||||
Type: "text",
|
|
||||||
},
|
|
||||||
ToolChoice: &toolChoice,
|
|
||||||
Tools: &tools,
|
|
||||||
|
|
||||||
EndToolCall: "noAction",
|
|
||||||
|
|
||||||
Chat: &client.Chat{
|
|
||||||
Messages: make([]client.ChatMessage, 0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
request.Chat.AddSystem(orchestratorPrompt)
|
|
||||||
request.Chat.AddImage(imageName, imageData)
|
|
||||||
|
|
||||||
if _, err := agent.client.Request(&request); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolHandlerInfo := client.ToolHandlerInfo{
|
|
||||||
ImageId: imageId,
|
|
||||||
UserId: userId,
|
|
||||||
}
|
|
||||||
|
|
||||||
return agent.client.ToolLoop(toolHandlerInfo, &request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, contactAgent ContactAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
|
func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, contactAgent ContactAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
|
||||||
agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{
|
||||||
ReportTimestamp: true,
|
ReportTimestamp: true,
|
||||||
@ -160,7 +116,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
|
|||||||
// We need a way to keep track of this async?
|
// We need a way to keep track of this async?
|
||||||
// Probably just a DB, because we don't want to wait. The orchestrator shouldn't wait for this stuff to finish.
|
// Probably just a DB, because we don't want to wait. The orchestrator shouldn't wait for this stuff to finish.
|
||||||
|
|
||||||
go eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData)
|
go eventLocationAgent.client.RunAgent(eventLocationPrompt, eventLocationTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
return Status{
|
return Status{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
@ -176,7 +132,7 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
|
|||||||
})
|
})
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go contactAgent.GetContacts(info.UserId, info.ImageId, imageName, imageData)
|
go contactAgent.client.RunAgent(contactPrompt, contactTools, "finish", info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
return Status{
|
return Status{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
@ -192,6 +148,6 @@ func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteA
|
|||||||
})
|
})
|
||||||
|
|
||||||
return OrchestratorAgent{
|
return OrchestratorAgent{
|
||||||
client: agent,
|
Client: agent,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -74,7 +74,9 @@ func ListenNewImageEvents(db *sql.DB) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
// Still need to find some way to hide this complexity away.
|
||||||
|
// I don't think wrapping agents in structs actually works too well.
|
||||||
|
err = orchestrator.Client.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user