Review notes (free text ahead of the first "diff --git" header; patch tools skip it).

This patch was recovered from a copy whose newlines had been collapsed. Line
breaks were restored so that every hunk matches the counts in its "@@" header;
leading tabs and gofmt column alignment are a best-effort reconstruction, so
use "git apply --ignore-whitespace" if a context line fails to match.

What the patch does:
  * client.go: removes the AiClient interface and the log.Println of the raw
    response in AgentClient.Request.
  * orchestrator.go: swaps the fmt import for errors, adds the Status type,
    makes Orchestrate finish by calling agent.client.Process, and changes
    NewOrchestratorAgent to take both sub-agents plus the image name/data and
    register the eventLocationAgent, noteAgent and defaultAgent tools.
  * main.go: constructs the orchestrator after the image is loaded and drops
    the direct locationAgent/noteAgent calls.

Points to confirm before merging:
  * The eventLocationAgent and noteAgent handlers discard whatever
    GetLocations/GetNotes return and always report Status{Ok: true}; the
    removed main.go code at least logged that error.
  * The defaultAgent handler returns a non-nil error to signal completion.
  * main.go still ignores the error returned by Orchestrate.

diff --git a/backend/agents/client/client.go b/backend/agents/client/client.go
index 34c463b..464dba2 100644
--- a/backend/agents/client/client.go
+++ b/backend/agents/client/client.go
@@ -152,10 +152,6 @@ func (imageMessage AgentImage) ToJson() ([]byte, error) {
 	return json.Marshal(imageMessage)
 }
 
-type AiClient interface {
-	GetImageInfo(imageName string, imageData []byte) (ImageInfo, error)
-}
-
 type ResponseChoiceMessage struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
@@ -253,8 +249,6 @@ func (client AgentClient) Request(request *AgentRequestBody) (AgentResponse, err
 		return AgentResponse{}, err
 	}
 
-	log.Println(string(response))
-
 	toolCalls := agentResponse.Choices[0].Message.ToolCalls
 	if len(toolCalls) > 0 {
 		// Should for sure be more flexible.
diff --git a/backend/agents/orchestrator.go b/backend/agents/orchestrator.go
index cff08b3..4fe7d3a 100644
--- a/backend/agents/orchestrator.go
+++ b/backend/agents/orchestrator.go
@@ -2,7 +2,7 @@ package agents
 
 import (
 	"encoding/json"
-	"fmt"
+	"errors"
 	"screenmark/screenmark/agents/client"
 
 	"github.com/google/uuid"
@@ -78,6 +78,12 @@ type OrchestratorAgent struct {
 	client client.AgentClient
 }
 
+type Status struct {
+	Ok bool `json:"ok"`
+}
+
+// TODO: the primary function of the agent could be extracted outwards.
+// This is basically the same function as we have in the `event_location_agent.go`
 func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
 	toolChoice := "any"
 
@@ -103,22 +109,52 @@ func (agent OrchestratorAgent) Orchestrate(userId uuid.UUID, imageId uuid.UUID,
 	}
 
 	request.AddImage(imageName, imageData)
-	resp, err := agent.client.Request(&request)
+	_, err = agent.client.Request(&request)
 	if err != nil {
 		return err
 	}
 
-	fmt.Println(resp)
+	toolHandlerInfo := client.ToolHandlerInfo{
+		ImageId: imageId,
+		UserId:  userId,
+	}
 
-	return nil
+	return agent.client.Process(toolHandlerInfo, request)
 }
 
-func NewOrchestratorAgent() (OrchestratorAgent, error) {
+func NewOrchestratorAgent(eventLocationAgent EventLocationAgent, noteAgent NoteAgent, imageName string, imageData []byte) (OrchestratorAgent, error) {
 	agent, err := client.CreateAgentClient(orchestratorPrompt)
 	if err != nil {
 		return OrchestratorAgent{}, err
 	}
 
+	agent.ToolHandler.AddTool("eventLocationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
+		// We need a way to keep track of this async?
+		// Probably just a DB, because we don't want to wait. The orchistrator shouldnt wait for this stuff to finish.
+
+		eventLocationAgent.GetLocations(info.UserId, info.ImageId, imageName, imageData)
+
+		return Status{
+			Ok: true,
+		}, nil
+	})
+
+	agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
+		noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
+
+		return Status{
+			Ok: true,
+		}, nil
+	})
+
+	agent.ToolHandler.AddTool("defaultAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
+		// To nothing
+
+		return Status{
+			Ok: true,
+		}, errors.New("Finished! Kinda bad return type but...")
+	})
+
 	return OrchestratorAgent{
 		client: agent,
 	}, nil
diff --git a/backend/main.go b/backend/main.go
index 8e4230a..205815f 100644
--- a/backend/main.go
+++ b/backend/main.go
@@ -85,11 +85,6 @@ func main() {
 				panic(err)
 			}
 
-			orchestrator, err := agents.NewOrchestratorAgent()
-			if err != nil {
-				panic(err)
-			}
-
 			image, err := imageModel.GetToProcessWithData(ctx, imageId)
 			if err != nil {
 				log.Println("Failed to GetToProcessWithData")
@@ -104,17 +99,12 @@ func main() {
 				return
 			}
 
+			orchestrator, err := agents.NewOrchestratorAgent(locationAgent, noteAgent, image.Image.ImageName, image.Image.Image)
+			if err != nil {
+				panic(err)
+			}
+
 			orchestrator.Orchestrate(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
-
-			// TODO: this can very much be parallel
-
-			log.Println("Calling locationAgent!")
-			err = locationAgent.GetLocations(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
-			log.Println(err)
-
-			log.Println("Calling noteAgent!")
-			err = noteAgent.GetNotes(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
-			log.Println(err)
 		}()
 	}
 }