refactor(agents): encapsulating prompt and calls inside factory method
This commit is contained in:
@ -76,15 +76,25 @@ type AgentClient struct {
|
||||
Reply string
|
||||
|
||||
Do func(req *http.Request) (*http.Response, error)
|
||||
|
||||
Options CreateAgentClientOptions
|
||||
}
|
||||
|
||||
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
||||
|
||||
func CreateAgentClient(log *log.Logger) (AgentClient, error) {
|
||||
type CreateAgentClientOptions struct {
|
||||
Log *log.Logger
|
||||
SystemPrompt string
|
||||
JsonTools string
|
||||
EndToolCall string
|
||||
Query *string
|
||||
}
|
||||
|
||||
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
|
||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||
|
||||
if len(apiKey) == 0 {
|
||||
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
||||
panic("No api key")
|
||||
}
|
||||
|
||||
return AgentClient{
|
||||
@ -95,12 +105,14 @@ func CreateAgentClient(log *log.Logger) (AgentClient, error) {
|
||||
return client.Do(req)
|
||||
},
|
||||
|
||||
Log: log,
|
||||
Log: options.Log,
|
||||
|
||||
ToolHandler: ToolsHandlers{
|
||||
handlers: map[string]ToolHandler{},
|
||||
},
|
||||
}, nil
|
||||
|
||||
Options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||
@ -226,9 +238,9 @@ func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody)
|
||||
return err
|
||||
}
|
||||
|
||||
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, query *string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
func (client AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||
var tools any
|
||||
err := json.Unmarshal([]byte(jsonTools), &tools)
|
||||
err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
|
||||
|
||||
toolChoice := "any"
|
||||
|
||||
@ -237,7 +249,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
||||
ToolChoice: &toolChoice,
|
||||
Model: "pixtral-12b-2409",
|
||||
Temperature: 0.3,
|
||||
EndToolCall: endToolCall,
|
||||
EndToolCall: client.Options.EndToolCall,
|
||||
ResponseFormat: ResponseFormat{
|
||||
Type: "text",
|
||||
},
|
||||
@ -246,8 +258,8 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
||||
},
|
||||
}
|
||||
|
||||
request.Chat.AddSystem(systemPrompt)
|
||||
request.Chat.AddImage(imageName, imageData, query)
|
||||
request.Chat.AddSystem(client.Options.SystemPrompt)
|
||||
request.Chat.AddImage(imageName, imageData, client.Options.Query)
|
||||
|
||||
_, err = client.Request(&request)
|
||||
if err != nil {
|
||||
|
@ -91,10 +91,12 @@ type linkContactArguments struct {
|
||||
}
|
||||
|
||||
func NewContactAgent(log *log.Logger, contactModel models.ContactModel) (client.AgentClient, error) {
|
||||
agentClient, err := client.CreateAgentClient(log)
|
||||
if err != nil {
|
||||
return client.AgentClient{}, err
|
||||
}
|
||||
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: contactPrompt,
|
||||
JsonTools: contactTools,
|
||||
Log: log,
|
||||
EndToolCall: "finish",
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
return contactModel.List(context.Background(), info.UserId)
|
||||
|
@ -109,11 +109,12 @@ type linkEventArguments struct {
|
||||
}
|
||||
|
||||
func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent client.AgentClient) (client.AgentClient, error) {
|
||||
agentClient, err := client.CreateAgentClient(log)
|
||||
|
||||
if err != nil {
|
||||
return client.AgentClient{}, err
|
||||
}
|
||||
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: eventPrompt,
|
||||
JsonTools: eventTools,
|
||||
Log: log,
|
||||
EndToolCall: "finish",
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
return eventsModel.List(context.Background(), info.UserId)
|
||||
@ -177,9 +178,10 @@ func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
query := "Can you get me the ID of the location present in this image?"
|
||||
locationAgent.Log = log.With("Locations 📍", true)
|
||||
locationAgent.RunAgent(locationPrompt, locationTools, "finish", &query, info.UserId, info.ImageId, info.ImageName, *info.Image)
|
||||
// TODO: reenable this when I'm creating the agent locally instead of getting it from above.
|
||||
// query := "Can you get me the ID of the location present in this image?"
|
||||
// locationAgent.Log = log.With("Locations 📍", true)
|
||||
// locationAgent.RunAgent(info.UserId, info.ImageId, info.ImageName, *info.Image)
|
||||
|
||||
return locationAgent.Reply, nil
|
||||
})
|
||||
|
@ -105,11 +105,12 @@ type linkLocationArguments struct {
|
||||
}
|
||||
|
||||
func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) (client.AgentClient, error) {
|
||||
agentClient, err := client.CreateAgentClient(log)
|
||||
|
||||
if err != nil {
|
||||
return client.AgentClient{}, err
|
||||
}
|
||||
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: locationPrompt,
|
||||
JsonTools: locationTools,
|
||||
Log: log,
|
||||
EndToolCall: "finish",
|
||||
})
|
||||
|
||||
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
return locationModel.List(context.Background(), info.UserId)
|
||||
|
@ -69,10 +69,9 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
||||
}
|
||||
|
||||
func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) (NoteAgent, error) {
|
||||
client, err := client.CreateAgentClient(log)
|
||||
if err != nil {
|
||||
return NoteAgent{}, err
|
||||
}
|
||||
client := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: noteAgentPrompt,
|
||||
})
|
||||
|
||||
agent := NoteAgent{
|
||||
client: client,
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
const OrchestratorPrompt = `
|
||||
const orchestratorPrompt = `
|
||||
You are an Orchestrator for various AI agents.
|
||||
|
||||
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
||||
@ -38,7 +38,7 @@ Always call agents in parallel if you need to call more than 1.
|
||||
Do not call the agent if you do not think it is relevant for the image.
|
||||
`
|
||||
|
||||
const OrchestratorTools = `
|
||||
const orchestratorTools = `
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
@ -113,11 +113,12 @@ type Status struct {
|
||||
}
|
||||
|
||||
func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) (client.AgentClient, error) {
|
||||
agent, err := client.CreateAgentClient(log)
|
||||
|
||||
if err != nil {
|
||||
return client.AgentClient{}, err
|
||||
}
|
||||
agent := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||
SystemPrompt: orchestratorPrompt,
|
||||
JsonTools: orchestratorTools,
|
||||
Log: log,
|
||||
EndToolCall: "noAction",
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
|
||||
@ -128,7 +129,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go contactAgent.RunAgent(contactPrompt, contactTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
|
||||
go contactAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||
|
||||
return Status{
|
||||
Ok: true,
|
||||
@ -136,7 +137,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go locationAgent.RunAgent(locationPrompt, locationTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
|
||||
go locationAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||
|
||||
return Status{
|
||||
Ok: true,
|
||||
@ -144,7 +145,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
|
||||
})
|
||||
|
||||
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||
go eventAgent.RunAgent(eventPrompt, eventTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
|
||||
go eventAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||
|
||||
return Status{
|
||||
Ok: true,
|
||||
|
@ -90,7 +90,7 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = orchestrator.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", nil, image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
err = orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||
if err != nil {
|
||||
databaseEventLog.Error("Orchestrator failed", "error", err)
|
||||
return
|
||||
|
Reference in New Issue
Block a user