feat/inter-agent-communication #10

Merged
JohnCosta27 merged 9 commits from feat/inter-agent-communication into main 2025-04-17 17:51:57 +01:00
9 changed files with 210 additions and 155 deletions
Showing only changes of commit c4569e925b - Show all commits

View File

@ -76,15 +76,25 @@ type AgentClient struct {
Reply string Reply string
Do func(req *http.Request) (*http.Response, error) Do func(req *http.Request) (*http.Response, error)
Options CreateAgentClientOptions
} }
const OPENAI_API_KEY = "OPENAI_API_KEY" const OPENAI_API_KEY = "OPENAI_API_KEY"
func CreateAgentClient(log *log.Logger) (AgentClient, error) { type CreateAgentClientOptions struct {
Log *log.Logger
SystemPrompt string
JsonTools string
EndToolCall string
Query *string
}
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
apiKey := os.Getenv(OPENAI_API_KEY) apiKey := os.Getenv(OPENAI_API_KEY)
if len(apiKey) == 0 { if len(apiKey) == 0 {
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") panic("No api key")
} }
return AgentClient{ return AgentClient{
@ -95,12 +105,14 @@ func CreateAgentClient(log *log.Logger) (AgentClient, error) {
return client.Do(req) return client.Do(req)
}, },
Log: log, Log: options.Log,
ToolHandler: ToolsHandlers{ ToolHandler: ToolsHandlers{
handlers: map[string]ToolHandler{}, handlers: map[string]ToolHandler{},
}, },
}, nil
Options: options,
}
} }
func (client AgentClient) getRequest(body []byte) (*http.Request, error) { func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
@ -226,9 +238,9 @@ func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody)
return err return err
} }
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, query *string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { func (client AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
var tools any var tools any
err := json.Unmarshal([]byte(jsonTools), &tools) err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
toolChoice := "any" toolChoice := "any"
@ -237,7 +249,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
ToolChoice: &toolChoice, ToolChoice: &toolChoice,
Model: "pixtral-12b-2409", Model: "pixtral-12b-2409",
Temperature: 0.3, Temperature: 0.3,
EndToolCall: endToolCall, EndToolCall: client.Options.EndToolCall,
ResponseFormat: ResponseFormat{ ResponseFormat: ResponseFormat{
Type: "text", Type: "text",
}, },
@ -246,8 +258,8 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
}, },
} }
request.Chat.AddSystem(systemPrompt) request.Chat.AddSystem(client.Options.SystemPrompt)
request.Chat.AddImage(imageName, imageData, query) request.Chat.AddImage(imageName, imageData, client.Options.Query)
_, err = client.Request(&request) _, err = client.Request(&request)
if err != nil { if err != nil {

View File

@ -91,10 +91,12 @@ type linkContactArguments struct {
} }
func NewContactAgent(log *log.Logger, contactModel models.ContactModel) (client.AgentClient, error) { func NewContactAgent(log *log.Logger, contactModel models.ContactModel) (client.AgentClient, error) {
agentClient, err := client.CreateAgentClient(log) agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
if err != nil { SystemPrompt: contactPrompt,
return client.AgentClient{}, err JsonTools: contactTools,
} Log: log,
EndToolCall: "finish",
})
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return contactModel.List(context.Background(), info.UserId) return contactModel.List(context.Background(), info.UserId)

View File

@ -109,11 +109,12 @@ type linkEventArguments struct {
} }
func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent client.AgentClient) (client.AgentClient, error) { func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent client.AgentClient) (client.AgentClient, error) {
agentClient, err := client.CreateAgentClient(log) agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
SystemPrompt: eventPrompt,
if err != nil { JsonTools: eventTools,
return client.AgentClient{}, err Log: log,
} EndToolCall: "finish",
})
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return eventsModel.List(context.Background(), info.UserId) return eventsModel.List(context.Background(), info.UserId)
@ -177,9 +178,10 @@ func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent
}) })
agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
query := "Can you get me the ID of the location present in this image?" // TODO: re-enable this once the location agent is created locally here instead of being passed in as the locationAgent parameter.
locationAgent.Log = log.With("Locations 📍", true) // query := "Can you get me the ID of the location present in this image?"
locationAgent.RunAgent(locationPrompt, locationTools, "finish", &query, info.UserId, info.ImageId, info.ImageName, *info.Image) // locationAgent.Log = log.With("Locations 📍", true)
// locationAgent.RunAgent(info.UserId, info.ImageId, info.ImageName, *info.Image)
return locationAgent.Reply, nil return locationAgent.Reply, nil
}) })

View File

@ -105,11 +105,12 @@ type linkLocationArguments struct {
} }
func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) (client.AgentClient, error) { func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) (client.AgentClient, error) {
agentClient, err := client.CreateAgentClient(log) agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
SystemPrompt: locationPrompt,
if err != nil { JsonTools: locationTools,
return client.AgentClient{}, err Log: log,
} EndToolCall: "finish",
})
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return locationModel.List(context.Background(), info.UserId) return locationModel.List(context.Background(), info.UserId)

View File

@ -69,10 +69,9 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
} }
func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) (NoteAgent, error) { func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) (NoteAgent, error) {
client, err := client.CreateAgentClient(log) client := client.CreateAgentClient(client.CreateAgentClientOptions{
if err != nil { SystemPrompt: noteAgentPrompt,
return NoteAgent{}, err })
}
agent := NoteAgent{ agent := NoteAgent{
client: client, client: client,

View File

@ -7,7 +7,7 @@ import (
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
) )
const OrchestratorPrompt = ` const orchestratorPrompt = `
You are an Orchestrator for various AI agents. You are an Orchestrator for various AI agents.
The user will send you images and you have to determine which agents you have to call, in order to best help the user. The user will send you images and you have to determine which agents you have to call, in order to best help the user.
@ -38,7 +38,7 @@ Always call agents in parallel if you need to call more than 1.
Do not call the agent if you do not think it is relevant for the image. Do not call the agent if you do not think it is relevant for the image.
` `
const OrchestratorTools = ` const orchestratorTools = `
[ [
{ {
"type": "function", "type": "function",
@ -113,11 +113,12 @@ type Status struct {
} }
func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) (client.AgentClient, error) { func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) (client.AgentClient, error) {
agent, err := client.CreateAgentClient(log) agent := client.CreateAgentClient(client.CreateAgentClientOptions{
SystemPrompt: orchestratorPrompt,
if err != nil { JsonTools: orchestratorTools,
return client.AgentClient{}, err Log: log,
} EndToolCall: "noAction",
})
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData) go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
@ -128,7 +129,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
}) })
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go contactAgent.RunAgent(contactPrompt, contactTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) go contactAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -136,7 +137,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
}) })
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go locationAgent.RunAgent(locationPrompt, locationTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) go locationAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -144,7 +145,7 @@ func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent cli
}) })
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go eventAgent.RunAgent(eventPrompt, eventTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) go eventAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,

View File

@ -90,7 +90,7 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
panic(err) panic(err)
} }
err = orchestrator.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", nil, image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) err = orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil { if err != nil {
databaseEventLog.Error("Orchestrator failed", "error", err) databaseEventLog.Error("Orchestrator failed", "error", err)
return return