refactor(agents): encapsulating prompt and calls inside factory method
This commit is contained in:
@ -76,15 +76,25 @@ type AgentClient struct {
|
|||||||
Reply string
|
Reply string
|
||||||
|
|
||||||
Do func(req *http.Request) (*http.Response, error)
|
Do func(req *http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
Options CreateAgentClientOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
const OPENAI_API_KEY = "OPENAI_API_KEY"
|
||||||
|
|
||||||
func CreateAgentClient(log *log.Logger) (AgentClient, error) {
|
type CreateAgentClientOptions struct {
|
||||||
|
Log *log.Logger
|
||||||
|
SystemPrompt string
|
||||||
|
JsonTools string
|
||||||
|
EndToolCall string
|
||||||
|
Query *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
|
||||||
apiKey := os.Getenv(OPENAI_API_KEY)
|
apiKey := os.Getenv(OPENAI_API_KEY)
|
||||||
|
|
||||||
if len(apiKey) == 0 {
|
if len(apiKey) == 0 {
|
||||||
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.")
|
panic("No api key")
|
||||||
}
|
}
|
||||||
|
|
||||||
return AgentClient{
|
return AgentClient{
|
||||||
@ -95,12 +105,14 @@ func CreateAgentClient(log *log.Logger) (AgentClient, error) {
|
|||||||
return client.Do(req)
|
return client.Do(req)
|
||||||
},
|
},
|
||||||
|
|
||||||
Log: log,
|
Log: options.Log,
|
||||||
|
|
||||||
ToolHandler: ToolsHandlers{
|
ToolHandler: ToolsHandlers{
|
||||||
handlers: map[string]ToolHandler{},
|
handlers: map[string]ToolHandler{},
|
||||||
},
|
},
|
||||||
}, nil
|
|
||||||
|
Options: options,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
|
||||||
@ -226,9 +238,9 @@ func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody)
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, query *string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
func (client AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
|
||||||
var tools any
|
var tools any
|
||||||
err := json.Unmarshal([]byte(jsonTools), &tools)
|
err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
|
||||||
|
|
||||||
toolChoice := "any"
|
toolChoice := "any"
|
||||||
|
|
||||||
@ -237,7 +249,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
|||||||
ToolChoice: &toolChoice,
|
ToolChoice: &toolChoice,
|
||||||
Model: "pixtral-12b-2409",
|
Model: "pixtral-12b-2409",
|
||||||
Temperature: 0.3,
|
Temperature: 0.3,
|
||||||
EndToolCall: endToolCall,
|
EndToolCall: client.Options.EndToolCall,
|
||||||
ResponseFormat: ResponseFormat{
|
ResponseFormat: ResponseFormat{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
},
|
},
|
||||||
@ -246,8 +258,8 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Chat.AddSystem(systemPrompt)
|
request.Chat.AddSystem(client.Options.SystemPrompt)
|
||||||
request.Chat.AddImage(imageName, imageData, query)
|
request.Chat.AddImage(imageName, imageData, client.Options.Query)
|
||||||
|
|
||||||
_, err = client.Request(&request)
|
_, err = client.Request(&request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -91,10 +91,12 @@ type linkContactArguments struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewContactAgent(log *log.Logger, contactModel models.ContactModel) (client.AgentClient, error) {
|
func NewContactAgent(log *log.Logger, contactModel models.ContactModel) (client.AgentClient, error) {
|
||||||
agentClient, err := client.CreateAgentClient(log)
|
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
if err != nil {
|
SystemPrompt: contactPrompt,
|
||||||
return client.AgentClient{}, err
|
JsonTools: contactTools,
|
||||||
}
|
Log: log,
|
||||||
|
EndToolCall: "finish",
|
||||||
|
})
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
return contactModel.List(context.Background(), info.UserId)
|
return contactModel.List(context.Background(), info.UserId)
|
||||||
|
@ -109,11 +109,12 @@ type linkEventArguments struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent client.AgentClient) (client.AgentClient, error) {
|
func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent client.AgentClient) (client.AgentClient, error) {
|
||||||
agentClient, err := client.CreateAgentClient(log)
|
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
|
SystemPrompt: eventPrompt,
|
||||||
if err != nil {
|
JsonTools: eventTools,
|
||||||
return client.AgentClient{}, err
|
Log: log,
|
||||||
}
|
EndToolCall: "finish",
|
||||||
|
})
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
return eventsModel.List(context.Background(), info.UserId)
|
return eventsModel.List(context.Background(), info.UserId)
|
||||||
@ -177,9 +178,10 @@ func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent
|
|||||||
})
|
})
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
query := "Can you get me the ID of the location present in this image?"
|
// TODO: reenable this when I'm creating the agent locally instead of getting it from above.
|
||||||
locationAgent.Log = log.With("Locations 📍", true)
|
// query := "Can you get me the ID of the location present in this image?"
|
||||||
locationAgent.RunAgent(locationPrompt, locationTools, "finish", &query, info.UserId, info.ImageId, info.ImageName, *info.Image)
|
// locationAgent.Log = log.With("Locations 📍", true)
|
||||||
|
// locationAgent.RunAgent(info.UserId, info.ImageId, info.ImageName, *info.Image)
|
||||||
|
|
||||||
return locationAgent.Reply, nil
|
return locationAgent.Reply, nil
|
||||||
})
|
})
|
||||||
|
@ -105,11 +105,12 @@ type linkLocationArguments struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) (client.AgentClient, error) {
|
func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) (client.AgentClient, error) {
|
||||||
agentClient, err := client.CreateAgentClient(log)
|
agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
|
SystemPrompt: locationPrompt,
|
||||||
if err != nil {
|
JsonTools: locationTools,
|
||||||
return client.AgentClient{}, err
|
Log: log,
|
||||||
}
|
EndToolCall: "finish",
|
||||||
|
})
|
||||||
|
|
||||||
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
return locationModel.List(context.Background(), info.UserId)
|
return locationModel.List(context.Background(), info.UserId)
|
||||||
|
@ -69,10 +69,9 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) (NoteAgent, error) {
|
func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) (NoteAgent, error) {
|
||||||
client, err := client.CreateAgentClient(log)
|
client := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
if err != nil {
|
SystemPrompt: noteAgentPrompt,
|
||||||
return NoteAgent{}, err
|
})
|
||||||
}
|
|
||||||
|
|
||||||
agent := NoteAgent{
|
agent := NoteAgent{
|
||||||
client: client,
|
client: client,
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const OrchestratorPrompt = `
|
const orchestratorPrompt = `
|
||||||
You are an Orchestrator for various AI agents.
|
You are an Orchestrator for various AI agents.
|
||||||
|
|
||||||
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
The user will send you images and you have to determine which agents you have to call, in order to best help the user.
|
||||||
@ -38,7 +38,7 @@ Always call agents in parallel if you need to call more than 1.
|
|||||||
Do not call the agent if you do not think it is relevant for the image.
|
Do not call the agent if you do not think it is relevant for the image.
|
||||||
`
|
`
|
||||||
|
|
||||||
const OrchestratorTools = `
|
const orchestratorTools = `
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@ -113,11 +113,12 @@ type Status struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) (client.AgentClient, error) {
|
func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) (client.AgentClient, error) {
|
||||||
agent, err := client.CreateAgentClient(log)
|
agent := client.CreateAgentClient(client.CreateAgentClientOptions{
|
||||||
|
SystemPrompt: orchestratorPrompt,
|
||||||
if err != nil {
|
JsonTools: orchestratorTools,
|
||||||
return client.AgentClient{}, err
|
Log: log,
|
||||||
}
|
EndToolCall: "noAction",
|
||||||
|
})
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
|
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
|
||||||
@ -128,7 +129,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
|
|||||||
})
|
})
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go contactAgent.RunAgent(contactPrompt, contactTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
|
go contactAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
return Status{
|
return Status{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
@ -136,7 +137,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
|
|||||||
})
|
})
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go locationAgent.RunAgent(locationPrompt, locationTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
|
go locationAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
return Status{
|
return Status{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
@ -144,7 +145,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
|
|||||||
})
|
})
|
||||||
|
|
||||||
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
|
||||||
go eventAgent.RunAgent(eventPrompt, eventTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
|
go eventAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
|
||||||
|
|
||||||
return Status{
|
return Status{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
|
@ -90,7 +90,7 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = orchestrator.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", nil, image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
err = orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
databaseEventLog.Error("Orchestrator failed", "error", err)
|
databaseEventLog.Error("Orchestrator failed", "error", err)
|
||||||
return
|
return
|
||||||
|
Reference in New Issue
Block a user