feat/inter-agent-communication #10

Merged
JohnCosta27 merged 9 commits from feat/inter-agent-communication into main 2025-04-17 17:51:57 +01:00
9 changed files with 221 additions and 176 deletions

View File

@ -65,8 +65,8 @@ func (m ChatUserMessage) MarshalJSON() ([]byte, error) {
}) })
case ArrayMessage: case ArrayMessage:
return json.Marshal(&struct { return json.Marshal(&struct {
Role UserRole `json:"role"` Role UserRole `json:"role"`
Content []ImageMessageContent `json:"content"` Content []MessageContentMessage `json:"content"`
}{ }{
Role: User, Role: User,
Content: t.Content, Content: t.Content,
@ -121,18 +121,35 @@ func (m SingleMessage) IsSingleMessage() bool {
} }
type ArrayMessage struct { type ArrayMessage struct {
Content []ImageMessageContent `json:"content"` Content []MessageContentMessage `json:"content"`
} }
func (m ArrayMessage) IsSingleMessage() bool { func (m ArrayMessage) IsSingleMessage() bool {
return false return false
} }
// MessageContentMessage is the element type of ArrayMessage.Content. It lets
// text and image parts share a single content slice.
type MessageContentMessage interface {
// IsImageMessage reports whether the element is an image part (true) or a
// text part (false).
IsImageMessage() bool
}
// TextMessageContent is a plain-text part of a multi-part message.
// TextType is serialized under the JSON key "type" and Text under "text".
type TextMessageContent struct {
	TextType string `json:"type"`
	Text     string `json:"text"`
}

// IsImageMessage always reports false: a text part never carries an image.
func (TextMessageContent) IsImageMessage() bool {
	return false
}
type ImageMessageContent struct { type ImageMessageContent struct {
ImageType string `json:"type"` ImageType string `json:"type"`
ImageUrl string `json:"image_url"` ImageUrl string `json:"image_url"`
} }
// IsImageMessage always reports true: this part carries an image.
func (ImageMessageContent) IsImageMessage() bool {
	return true
}
type ImageContentUrl struct { type ImageContentUrl struct {
Url string `json:"url"` Url string `json:"url"`
} }
@ -165,7 +182,7 @@ func (chat *Chat) AddSystem(prompt string) {
}) })
} }
func (chat *Chat) AddImage(imageName string, image []byte) error { func (chat *Chat) AddImage(imageName string, image []byte, query *string) error {
extension := filepath.Ext(imageName) extension := filepath.Ext(imageName)
if len(extension) == 0 { if len(extension) == 0 {
// TODO: could also validate for image types we support. // TODO: could also validate for image types we support.
@ -173,14 +190,28 @@ func (chat *Chat) AddImage(imageName string, image []byte) error {
} }
extension = extension[1:] extension = extension[1:]
encodedString := base64.StdEncoding.EncodeToString(image) encodedString := base64.StdEncoding.EncodeToString(image)
messageContent := ArrayMessage{ contentLength := 1
Content: make([]ImageMessageContent, 1), if query != nil {
contentLength = 2
} }
messageContent.Content[0] = ImageMessageContent{ messageContent := ArrayMessage{
Content: make([]MessageContentMessage, contentLength),
}
index := 0
if query != nil {
messageContent.Content[index] = TextMessageContent{
TextType: "text",
Text: *query,
}
index += 1
}
messageContent.Content[index] = ImageMessageContent{
ImageType: "image_url", ImageType: "image_url",
ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString), ImageUrl: fmt.Sprintf("data:image/%s;base64,%s", extension, encodedString),
} }

View File

@ -73,16 +73,28 @@ type AgentClient struct {
Log *log.Logger Log *log.Logger
Reply string
Do func(req *http.Request) (*http.Response, error) Do func(req *http.Request) (*http.Response, error)
Options CreateAgentClientOptions
} }
const OPENAI_API_KEY = "OPENAI_API_KEY" const OPENAI_API_KEY = "OPENAI_API_KEY"
func CreateAgentClient(log *log.Logger) (AgentClient, error) { type CreateAgentClientOptions struct {
Log *log.Logger
SystemPrompt string
JsonTools string
EndToolCall string
Query *string
}
func CreateAgentClient(options CreateAgentClientOptions) AgentClient {
apiKey := os.Getenv(OPENAI_API_KEY) apiKey := os.Getenv(OPENAI_API_KEY)
if len(apiKey) == 0 { if len(apiKey) == 0 {
return AgentClient{}, errors.New(OPENAI_API_KEY + " was not found.") panic("No api key")
} }
return AgentClient{ return AgentClient{
@ -93,12 +105,14 @@ func CreateAgentClient(log *log.Logger) (AgentClient, error) {
return client.Do(req) return client.Do(req)
}, },
Log: log, Log: options.Log,
ToolHandler: ToolsHandlers{ ToolHandler: ToolsHandlers{
handlers: map[string]ToolHandler{}, handlers: map[string]ToolHandler{},
}, },
}, nil
Options: options,
}
} }
func (client AgentClient) getRequest(body []byte) (*http.Request, error) { func (client AgentClient) getRequest(body []byte) (*http.Request, error) {
@ -146,8 +160,6 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.") return AgentResponse{}, errors.New("Unsupported. We currently only accept 1 choice from AI.")
} }
client.Log.SetLevel(log.DebugLevel)
msg := agentResponse.Choices[0].Message msg := agentResponse.Choices[0].Message
if len(msg.Content) > 0 { if len(msg.Content) > 0 {
@ -170,7 +182,7 @@ func (client AgentClient) Request(req *AgentRequestBody) (AgentResponse, error)
return agentResponse, nil return agentResponse, nil
} }
func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error { func (client *AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody) error {
for { for {
err := client.Process(info, req) err := client.Process(info, req)
if err != nil { if err != nil {
@ -186,7 +198,7 @@ func (client AgentClient) ToolLoop(info ToolHandlerInfo, req *AgentRequestBody)
var FinishedCall = errors.New("Last tool tool was called") var FinishedCall = errors.New("Last tool tool was called")
func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error { func (client *AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) error {
var err error var err error
message, err := req.Chat.GetLatest() message, err := req.Chat.GetLatest()
@ -211,7 +223,10 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
toolResponse := client.ToolHandler.Handle(info, toolCall) toolResponse := client.ToolHandler.Handle(info, toolCall)
client.Log.SetLevel(log.DebugLevel) if toolCall.Function.Name == "reply" {
client.Reply = toolCall.Function.Arguments
}
client.Log.Debugf("Response: %s", toolResponse.Content) client.Log.Debugf("Response: %s", toolResponse.Content)
req.Chat.AddToolResponse(toolResponse) req.Chat.AddToolResponse(toolResponse)
@ -220,9 +235,9 @@ func (client AgentClient) Process(info ToolHandlerInfo, req *AgentRequestBody) e
return err return err
} }
func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToolCall string, userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error { func (client *AgentClient) RunAgent(userId uuid.UUID, imageId uuid.UUID, imageName string, imageData []byte) error {
var tools any var tools any
err := json.Unmarshal([]byte(jsonTools), &tools) err := json.Unmarshal([]byte(client.Options.JsonTools), &tools)
toolChoice := "any" toolChoice := "any"
@ -231,7 +246,7 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
ToolChoice: &toolChoice, ToolChoice: &toolChoice,
Model: "pixtral-12b-2409", Model: "pixtral-12b-2409",
Temperature: 0.3, Temperature: 0.3,
EndToolCall: endToolCall, EndToolCall: client.Options.EndToolCall,
ResponseFormat: ResponseFormat{ ResponseFormat: ResponseFormat{
Type: "text", Type: "text",
}, },
@ -240,8 +255,8 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
}, },
} }
request.Chat.AddSystem(systemPrompt) request.Chat.AddSystem(client.Options.SystemPrompt)
request.Chat.AddImage(imageName, imageData) request.Chat.AddImage(imageName, imageData, client.Options.Query)
_, err = client.Request(&request) _, err = client.Request(&request)
if err != nil { if err != nil {
@ -249,8 +264,10 @@ func (client AgentClient) RunAgent(systemPrompt string, jsonTools string, endToo
} }
toolHandlerInfo := ToolHandlerInfo{ toolHandlerInfo := ToolHandlerInfo{
ImageId: imageId, ImageId: imageId,
UserId: userId, ImageName: imageName,
UserId: userId,
Image: &imageData,
} }
return client.ToolLoop(toolHandlerInfo, &request) return client.ToolLoop(toolHandlerInfo, &request)

View File

@ -8,8 +8,12 @@ import (
) )
type ToolHandlerInfo struct { type ToolHandlerInfo struct {
UserId uuid.UUID UserId uuid.UUID
ImageId uuid.UUID ImageId uuid.UUID
ImageName string
// Pointer because we don't want to copy this around too much.
Image *[]byte
} }
type ToolHandler struct { type ToolHandler struct {

View File

@ -3,11 +3,9 @@ package agents
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
@ -81,12 +79,6 @@ const contactTools = `
] ]
` `
type ContactAgent struct {
client client.AgentClient
contactModel models.ContactModel
}
type listContactsArguments struct{} type listContactsArguments struct{}
type createContactsArguments struct { type createContactsArguments struct {
Name string `json:"name"` Name string `json:"name"`
@ -98,23 +90,16 @@ type linkContactArguments struct {
ContactID string `json:"contactId"` ContactID string `json:"contactId"`
} }
func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) { func NewContactAgent(log *log.Logger, contactModel models.ContactModel) client.AgentClient {
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
ReportTimestamp: true, SystemPrompt: contactPrompt,
TimeFormat: time.Kitchen, JsonTools: contactTools,
Prefix: "Contacts 👥", Log: log,
})) EndToolCall: "finish",
if err != nil { })
return ContactAgent{}, err
}
agent := ContactAgent{
client: agentClient,
contactModel: contactModel,
}
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return agent.contactModel.List(context.Background(), info.UserId) return contactModel.List(context.Background(), info.UserId)
}) })
agentClient.ToolHandler.AddTool("createContact", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("createContact", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
@ -126,7 +111,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
ctx := context.Background() ctx := context.Background()
contact, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{ contact, err := contactModel.Save(ctx, info.UserId, model.Contacts{
Name: args.Name, Name: args.Name,
PhoneNumber: args.PhoneNumber, PhoneNumber: args.PhoneNumber,
Email: args.Email, Email: args.Email,
@ -136,7 +121,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
return model.Contacts{}, err return model.Contacts{}, err
} }
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, contact.ID) _, err = contactModel.SaveToImage(ctx, info.ImageId, contact.ID)
if err != nil { if err != nil {
return model.Contacts{}, err return model.Contacts{}, err
} }
@ -158,7 +143,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
return "", err return "", err
} }
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, contactUuid) _, err = contactModel.SaveToImage(ctx, info.ImageId, contactUuid)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -166,5 +151,5 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
return "Saved", nil return "Saved", nil
}) })
return agent, nil return agentClient
} }

View File

@ -3,7 +3,6 @@ package agents
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/models" "screenmark/screenmark/models"
@ -27,6 +26,9 @@ Lists the users already existing events.
createEvent createEvent
Use this to create a new events. Use this to create a new events.
getEventLocationId
Use this if the image contains a location or place. This tool will return the locationId.
finish finish
Call when there is nothing else to do. Call when there is nothing else to do.
` `
@ -63,11 +65,28 @@ const eventTools = `
"endDateTime": { "endDateTime": {
"type": "string", "type": "string",
"description": "The end time as an ISO string" "description": "The end time as an ISO string"
} },
"locationId": {
"type": "string",
"description": "The UUID of this location. You should use getEventLocationId to get this information, but only if you believe the event contains a location"
}
}, },
"required": ["name"] "required": ["name"]
} }
} }
},
{
"type": "function",
"function": {
"name": "getEventLocationId",
"description": "Get the ID of the location on the image, only use if the event contains a location or place.",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
}, },
{ {
"type": "function", "type": "function",
@ -83,12 +102,6 @@ const eventTools = `
} }
]` ]`
type EventAgent struct {
client client.AgentClient
eventsModel models.EventModel
}
type listEventArguments struct{} type listEventArguments struct{}
type createEventArguments struct { type createEventArguments struct {
Name string `json:"name"` Name string `json:"name"`
@ -100,24 +113,20 @@ type linkEventArguments struct {
EventID string `json:"eventId"` EventID string `json:"eventId"`
} }
func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) { func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationModel models.LocationModel) client.AgentClient {
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
ReportTimestamp: true, SystemPrompt: eventPrompt,
TimeFormat: time.Kitchen, JsonTools: eventTools,
Prefix: "Events 📍", Log: log,
})) EndToolCall: "finish",
})
if err != nil { locationAgent := NewLocationAgent(log.WithPrefix("Events 📅 > Locations 📍"), locationModel)
return EventAgent{}, err locationQuery := "Can you get me the ID of the location present in this image?"
} locationAgent.Options.Query = &locationQuery
agent := EventAgent{
client: agentClient,
eventsModel: eventsModel,
}
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return agent.eventsModel.List(context.Background(), info.UserId) return eventsModel.List(context.Background(), info.UserId)
}) })
agentClient.ToolHandler.AddTool("createEvent", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("createEvent", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
@ -141,7 +150,7 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
return model.Events{}, err return model.Events{}, err
} }
events, err := agent.eventsModel.Save(ctx, info.UserId, model.Events{ events, err := eventsModel.Save(ctx, info.UserId, model.Events{
Name: args.Name, Name: args.Name,
StartDateTime: &startTime, StartDateTime: &startTime,
EndDateTime: &endTime, EndDateTime: &endTime,
@ -151,7 +160,7 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
return model.Events{}, err return model.Events{}, err
} }
_, err = agent.eventsModel.SaveToImage(ctx, info.ImageId, events.ID) _, err = eventsModel.SaveToImage(ctx, info.ImageId, events.ID)
if err != nil { if err != nil {
return model.Events{}, err return model.Events{}, err
} }
@ -173,9 +182,17 @@ func NewEventAgent(eventsModel models.EventModel) (EventAgent, error) {
return "", err return "", err
} }
agent.eventsModel.SaveToImage(ctx, info.ImageId, contactUuid) eventsModel.SaveToImage(ctx, info.ImageId, contactUuid)
return "Saved", nil return "Saved", nil
}) })
return agent, nil agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
// TODO: reenable this when I'm creating the agent locally instead of getting it from above.
locationAgent.RunAgent(info.UserId, info.ImageId, info.ImageName, *info.Image)
log.Debugf("Reply from location %s\n", locationAgent.Reply)
return locationAgent.Reply, nil
})
return agentClient
} }

View File

@ -3,11 +3,9 @@ package agents
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
@ -27,6 +25,9 @@ Lists the users already existing locations.
createLocation createLocation
Use this to create a new location, when you don't see a matching one from listLocations call. Use this to create a new location, when you don't see a matching one from listLocations call.
reply
Use this only if the user has asked a question about a location.
finish finish
Call when there is nothing else to do. Call when there is nothing else to do.
` `
@ -63,6 +64,22 @@ const locationTools = `
"required": ["name"] "required": ["name"]
} }
} }
},
{
"type": "function",
"function": {
"name": "reply",
"description": "Reply to a user query, only if the user has asked something",
"parameters": {
"type": "object",
"properties": {
"locationId": {
"type": "string"
}
},
"required": ["locationId"]
}
}
}, },
{ {
"type": "function", "type": "function",
@ -78,12 +95,6 @@ const locationTools = `
} }
]` ]`
type LocationAgent struct {
client client.AgentClient
locationModel models.LocationModel
}
type listLocationArguments struct{} type listLocationArguments struct{}
type createLocationArguments struct { type createLocationArguments struct {
Name string `json:"name"` Name string `json:"name"`
@ -93,24 +104,16 @@ type linkLocationArguments struct {
LocationID string `json:"locationId"` LocationID string `json:"locationId"`
} }
func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error) { func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) client.AgentClient {
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{
ReportTimestamp: true, SystemPrompt: locationPrompt,
TimeFormat: time.Kitchen, JsonTools: locationTools,
Prefix: "Locations 📍", Log: log,
})) EndToolCall: "finish",
})
if err != nil {
return LocationAgent{}, err
}
agent := LocationAgent{
client: agentClient,
locationModel: locationModel,
}
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return agent.locationModel.List(context.Background(), info.UserId) return locationModel.List(context.Background(), info.UserId)
}) })
agentClient.ToolHandler.AddTool("createLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("createLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
@ -122,7 +125,7 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
ctx := context.Background() ctx := context.Background()
location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{ location, err := locationModel.Save(ctx, info.UserId, model.Locations{
Name: args.Name, Name: args.Name,
Address: args.Address, Address: args.Address,
}) })
@ -131,7 +134,7 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
return model.Locations{}, err return model.Locations{}, err
} }
_, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID) _, err = locationModel.SaveToImage(ctx, info.ImageId, location.ID)
if err != nil { if err != nil {
return model.Locations{}, err return model.Locations{}, err
} }
@ -153,9 +156,13 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
return "", err return "", err
} }
agent.locationModel.SaveToImage(ctx, info.ImageId, contactUuid) locationModel.SaveToImage(ctx, info.ImageId, contactUuid)
return "Saved", nil return "Saved", nil
}) })
return agent, nil agentClient.ToolHandler.AddTool("reply", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return "ok", nil
})
return agentClient
} }

View File

@ -2,11 +2,9 @@ package agents
import ( import (
"context" "context"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
@ -43,7 +41,7 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
} }
request.Chat.AddSystem(noteAgentPrompt) request.Chat.AddSystem(noteAgentPrompt)
request.Chat.AddImage(imageName, imageData) request.Chat.AddImage(imageName, imageData, nil)
resp, err := agent.client.Request(&request) resp, err := agent.client.Request(&request)
if err != nil { if err != nil {
@ -70,20 +68,16 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
return nil return nil
} }
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) { func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) NoteAgent {
client, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ client := client.CreateAgentClient(client.CreateAgentClientOptions{
ReportTimestamp: true, SystemPrompt: noteAgentPrompt,
TimeFormat: time.Kitchen, Log: log,
Prefix: "Notes 📝", })
}))
if err != nil {
return NoteAgent{}, err
}
agent := NoteAgent{ agent := NoteAgent{
client: client, client: client,
noteModel: noteModel, noteModel: noteModel,
} }
return agent, nil return agent
} }

View File

@ -2,14 +2,12 @@ package agents
import ( import (
"errors" "errors"
"os"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
) )
const OrchestratorPrompt = ` const orchestratorPrompt = `
You are an Orchestrator for various AI agents. You are an Orchestrator for various AI agents.
The user will send you images and you have to determine which agents you have to call, in order to best help the user. The user will send you images and you have to determine which agents you have to call, in order to best help the user.
@ -40,7 +38,7 @@ Always call agents in parallel if you need to call more than 1.
Do not call the agent if you do not think it is relevant for the image. Do not call the agent if you do not think it is relevant for the image.
` `
const OrchestratorTools = ` const orchestratorTools = `
[ [
{ {
"type": "function", "type": "function",
@ -114,16 +112,13 @@ type Status struct {
Ok bool `json:"ok"` Ok bool `json:"ok"`
} }
func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locationAgent LocationAgent, eventAgent EventAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) client.AgentClient {
agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agent := client.CreateAgentClient(client.CreateAgentClientOptions{
ReportTimestamp: true, SystemPrompt: orchestratorPrompt,
TimeFormat: time.Kitchen, JsonTools: orchestratorTools,
Prefix: "Orchestrator 🎼", Log: log,
})) EndToolCall: "noAction",
})
if err != nil {
return OrchestratorAgent{}, err
}
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData) go noteAgent.GetNotes(info.UserId, info.ImageId, imageName, imageData)
@ -134,7 +129,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
}) })
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go contactAgent.client.RunAgent(contactPrompt, contactTools, "finish", info.UserId, info.ImageId, imageName, imageData) go contactAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -142,7 +137,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
}) })
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go locationAgent.client.RunAgent(locationPrompt, locationTools, "finish", info.UserId, info.ImageId, imageName, imageData) go locationAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -150,7 +145,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
}) })
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go eventAgent.client.RunAgent(eventPrompt, eventTools, "finish", info.UserId, info.ImageId, imageName, imageData) go eventAgent.RunAgent(info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -165,7 +160,5 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
}, errors.New("Finished! Kinda bad return type but...") }, errors.New("Finished! Kinda bad return type but...")
}) })
return OrchestratorAgent{ return agent
Client: agent,
}, nil
} }

View File

@ -3,17 +3,28 @@ package main
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"log"
"os" "os"
"screenmark/screenmark/agents" "screenmark/screenmark/agents"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time" "time"
"github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lib/pq" "github.com/lib/pq"
) )
// createLogger returns a logger that writes to stdout, tags every entry with
// prefix, reports timestamps in time.Kitchen format, and has its level set to
// debug.
func createLogger(prefix string) *log.Logger {
	opts := log.Options{
		ReportTimestamp: true,
		TimeFormat:      time.Kitchen,
		Prefix:          prefix,
	}

	l := log.NewWithOptions(os.Stdout, opts)
	l.SetLevel(log.DebugLevel)

	return l
}
func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) { func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) { listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
if err != nil { if err != nil {
@ -28,6 +39,9 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
imageModel := models.NewImageModel(db) imageModel := models.NewImageModel(db)
contactModel := models.NewContactModel(db) contactModel := models.NewContactModel(db)
databaseEventLog := createLogger("Database Events 🤖")
databaseEventLog.SetLevel(log.DebugLevel)
err := listener.Listen("new_image") err := listener.Listen("new_image")
if err != nil { if err != nil {
panic(err) panic(err)
@ -39,55 +53,41 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
imageId := uuid.MustParse(parameters.Extra) imageId := uuid.MustParse(parameters.Extra)
eventManager.listeners[parameters.Extra] = make(chan string) eventManager.listeners[parameters.Extra] = make(chan string)
databaseEventLog.Debug("Starting processing image", "ImageID", imageId)
ctx := context.Background() ctx := context.Background()
go func() { go func() {
noteAgent, err := agents.NewNoteAgent(noteModel) noteAgent := agents.NewNoteAgent(createLogger("Notes 📝"), noteModel)
if err != nil { contactAgent := agents.NewContactAgent(createLogger("Contacts 👥"), contactModel)
panic(err) locationAgent := agents.NewLocationAgent(createLogger("Locations 📍"), locationModel)
} eventAgent := agents.NewEventAgent(createLogger("Events 📅"), eventModel, locationModel)
contactAgent, err := agents.NewContactAgent(contactModel)
if err != nil {
panic(err)
}
locationAgent, err := agents.NewLocationAgent(locationModel)
if err != nil {
panic(err)
}
eventAgent, err := agents.NewEventAgent(eventModel)
if err != nil {
panic(err)
}
image, err := imageModel.GetToProcessWithData(ctx, imageId) image, err := imageModel.GetToProcessWithData(ctx, imageId)
if err != nil { if err != nil {
log.Println("Failed to GetToProcessWithData") databaseEventLog.Error("Failed to GetToProcessWithData", "error", err)
log.Println(err)
return return
} }
if err := imageModel.StartProcessing(ctx, image.ID); err != nil { if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
log.Println("Failed to FinishProcessing") databaseEventLog.Error("Failed to FinishProcessing", "error", err)
log.Println(err)
return return
} }
orchestrator, err := agents.NewOrchestratorAgent(noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image) orchestrator := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼"), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
err = orchestrator.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil { if err != nil {
panic(err) databaseEventLog.Error("Orchestrator failed", "error", err)
return
} }
// Still need to find some way to hide this complexity away. _, err = imageModel.FinishProcessing(ctx, image.ID)
// I don't think wrapping agents in structs actually works too well.
err = orchestrator.Client.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil { if err != nil {
log.Println(err) databaseEventLog.Error("Failed to finish processing", "ImageID", imageId)
return
} }
imageModel.FinishProcessing(ctx, image.ID) databaseEventLog.Debug("Starting processing image", "ImageID", imageId)
}() }()
} }
} }
@ -122,9 +122,6 @@ func ListenProcessingImageStatus(db *sql.DB, eventManager *EventManager) {
stringUuid := data.Extra[0:36] stringUuid := data.Extra[0:36]
status := data.Extra[36:] status := data.Extra[36:]
fmt.Printf("UUID: %s\n", stringUuid)
fmt.Printf("Receiving :s\n", data.Extra)
imageListener, exists := eventManager.listeners[stringUuid] imageListener, exists := eventManager.listeners[stringUuid]
if !exists { if !exists {
continue continue