refactor(agents): no need to wrap them in another struct

This commit is contained in:
2025-04-17 10:36:11 +01:00
parent fa486153b4
commit e42aa75639
6 changed files with 66 additions and 118 deletions

View File

@ -3,11 +3,9 @@ package agents
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
@ -81,12 +79,6 @@ const contactTools = `
] ]
` `
type ContactAgent struct {
client client.AgentClient
contactModel models.ContactModel
}
type listContactsArguments struct{} type listContactsArguments struct{}
type createContactsArguments struct { type createContactsArguments struct {
Name string `json:"name"` Name string `json:"name"`
@ -98,23 +90,14 @@ type linkContactArguments struct {
ContactID string `json:"contactId"` ContactID string `json:"contactId"`
} }
func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) { func NewContactAgent(log *log.Logger, contactModel models.ContactModel) (client.AgentClient, error) {
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agentClient, err := client.CreateAgentClient(log)
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: "Contacts 👥",
}))
if err != nil { if err != nil {
return ContactAgent{}, err return client.AgentClient{}, err
}
agent := ContactAgent{
client: agentClient,
contactModel: contactModel,
} }
agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listContacts", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return agent.contactModel.List(context.Background(), info.UserId) return contactModel.List(context.Background(), info.UserId)
}) })
agentClient.ToolHandler.AddTool("createContact", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("createContact", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
@ -126,7 +109,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
ctx := context.Background() ctx := context.Background()
contact, err := agent.contactModel.Save(ctx, info.UserId, model.Contacts{ contact, err := contactModel.Save(ctx, info.UserId, model.Contacts{
Name: args.Name, Name: args.Name,
PhoneNumber: args.PhoneNumber, PhoneNumber: args.PhoneNumber,
Email: args.Email, Email: args.Email,
@ -136,7 +119,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
return model.Contacts{}, err return model.Contacts{}, err
} }
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, contact.ID) _, err = contactModel.SaveToImage(ctx, info.ImageId, contact.ID)
if err != nil { if err != nil {
return model.Contacts{}, err return model.Contacts{}, err
} }
@ -158,7 +141,7 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
return "", err return "", err
} }
_, err = agent.contactModel.SaveToImage(ctx, info.ImageId, contactUuid) _, err = contactModel.SaveToImage(ctx, info.ImageId, contactUuid)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -166,5 +149,5 @@ func NewContactAgent(contactModel models.ContactModel) (ContactAgent, error) {
return "Saved", nil return "Saved", nil
}) })
return agent, nil return agentClient, nil
} }

View File

@ -3,7 +3,6 @@ package agents
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/models" "screenmark/screenmark/models"
@ -98,14 +97,6 @@ const eventTools = `
} }
]` ]`
type EventAgent struct {
client client.AgentClient
eventsModel models.EventModel
locationAgent LocationAgent
}
type listEventArguments struct{} type listEventArguments struct{}
type createEventArguments struct { type createEventArguments struct {
Name string `json:"name"` Name string `json:"name"`
@ -117,25 +108,15 @@ type linkEventArguments struct {
EventID string `json:"eventId"` EventID string `json:"eventId"`
} }
func NewEventAgent(eventsModel models.EventModel, locationAgent LocationAgent) (EventAgent, error) { func NewEventAgent(log *log.Logger, eventsModel models.EventModel, locationAgent client.AgentClient) (client.AgentClient, error) {
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agentClient, err := client.CreateAgentClient(log)
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: "Events 📍",
}))
if err != nil { if err != nil {
return EventAgent{}, err return client.AgentClient{}, err
}
agent := EventAgent{
client: agentClient,
eventsModel: eventsModel,
locationAgent: locationAgent,
} }
agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listEvents", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return agent.eventsModel.List(context.Background(), info.UserId) return eventsModel.List(context.Background(), info.UserId)
}) })
agentClient.ToolHandler.AddTool("createEvent", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("createEvent", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
@ -159,7 +140,7 @@ func NewEventAgent(eventsModel models.EventModel, locationAgent LocationAgent) (
return model.Events{}, err return model.Events{}, err
} }
events, err := agent.eventsModel.Save(ctx, info.UserId, model.Events{ events, err := eventsModel.Save(ctx, info.UserId, model.Events{
Name: args.Name, Name: args.Name,
StartDateTime: &startTime, StartDateTime: &startTime,
EndDateTime: &endTime, EndDateTime: &endTime,
@ -169,7 +150,7 @@ func NewEventAgent(eventsModel models.EventModel, locationAgent LocationAgent) (
return model.Events{}, err return model.Events{}, err
} }
_, err = agent.eventsModel.SaveToImage(ctx, info.ImageId, events.ID) _, err = eventsModel.SaveToImage(ctx, info.ImageId, events.ID)
if err != nil { if err != nil {
return model.Events{}, err return model.Events{}, err
} }
@ -191,16 +172,16 @@ func NewEventAgent(eventsModel models.EventModel, locationAgent LocationAgent) (
return "", err return "", err
} }
agent.eventsModel.SaveToImage(ctx, info.ImageId, contactUuid) eventsModel.SaveToImage(ctx, info.ImageId, contactUuid)
return "Saved", nil return "Saved", nil
}) })
agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("getEventLocationId", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
query := "Can you get me the ID of the location present in this image?" query := "Can you get me the ID of the location present in this image?"
locationAgent.client.RunAgent(locationPrompt, locationTools, "finish", &query, info.UserId, info.ImageId, info.ImageName, *info.Image) locationAgent.RunAgent(locationPrompt, locationTools, "finish", &query, info.UserId, info.ImageId, info.ImageName, *info.Image)
return locationAgent.client.Reply, nil return locationAgent.Reply, nil
}) })
return agent, nil return agentClient, nil
} }

View File

@ -3,11 +3,9 @@ package agents
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
@ -97,12 +95,6 @@ const locationTools = `
} }
]` ]`
type LocationAgent struct {
client client.AgentClient
locationModel models.LocationModel
}
type listLocationArguments struct{} type listLocationArguments struct{}
type createLocationArguments struct { type createLocationArguments struct {
Name string `json:"name"` Name string `json:"name"`
@ -112,24 +104,15 @@ type linkLocationArguments struct {
LocationID string `json:"locationId"` LocationID string `json:"locationId"`
} }
func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error) { func NewLocationAgent(log *log.Logger, locationModel models.LocationModel) (client.AgentClient, error) {
agentClient, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agentClient, err := client.CreateAgentClient(log)
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: "Locations 📍",
}))
if err != nil { if err != nil {
return LocationAgent{}, err return client.AgentClient{}, err
}
agent := LocationAgent{
client: agentClient,
locationModel: locationModel,
} }
agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("listLocations", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
return agent.locationModel.List(context.Background(), info.UserId) return locationModel.List(context.Background(), info.UserId)
}) })
agentClient.ToolHandler.AddTool("createLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("createLocation", func(info client.ToolHandlerInfo, _args string, call client.ToolCall) (any, error) {
@ -141,7 +124,7 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
ctx := context.Background() ctx := context.Background()
location, err := agent.locationModel.Save(ctx, info.UserId, model.Locations{ location, err := locationModel.Save(ctx, info.UserId, model.Locations{
Name: args.Name, Name: args.Name,
Address: args.Address, Address: args.Address,
}) })
@ -150,7 +133,7 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
return model.Locations{}, err return model.Locations{}, err
} }
_, err = agent.locationModel.SaveToImage(ctx, info.ImageId, location.ID) _, err = locationModel.SaveToImage(ctx, info.ImageId, location.ID)
if err != nil { if err != nil {
return model.Locations{}, err return model.Locations{}, err
} }
@ -172,14 +155,14 @@ func NewLocationAgent(locationModel models.LocationModel) (LocationAgent, error)
return "", err return "", err
} }
agent.locationModel.SaveToImage(ctx, info.ImageId, contactUuid) locationModel.SaveToImage(ctx, info.ImageId, contactUuid)
return "Saved", nil return "Saved", nil
}) })
agentClient.ToolHandler.AddTool("reply", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agentClient.ToolHandler.AddTool("reply", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
agent.client.Log.Debug(args) agentClient.Log.Debug(args)
return "ok", nil return "ok", nil
}) })
return agent, nil return agentClient, nil
} }

View File

@ -2,11 +2,9 @@ package agents
import ( import (
"context" "context"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
@ -70,12 +68,8 @@ func (agent NoteAgent) GetNotes(userId uuid.UUID, imageId uuid.UUID, imageName s
return nil return nil
} }
func NewNoteAgent(noteModel models.NoteModel) (NoteAgent, error) { func NewNoteAgent(log *log.Logger, noteModel models.NoteModel) (NoteAgent, error) {
client, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ client, err := client.CreateAgentClient(log)
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: "Notes 📝",
}))
if err != nil { if err != nil {
return NoteAgent{}, err return NoteAgent{}, err
} }

View File

@ -2,9 +2,7 @@ package agents
import ( import (
"errors" "errors"
"os"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"time"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
) )
@ -114,15 +112,11 @@ type Status struct {
Ok bool `json:"ok"` Ok bool `json:"ok"`
} }
func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locationAgent LocationAgent, eventAgent EventAgent, imageName string, imageData []byte) (OrchestratorAgent, error) { func NewOrchestratorAgent(log *log.Logger, noteAgent NoteAgent, contactAgent client.AgentClient, locationAgent client.AgentClient, eventAgent client.AgentClient, imageName string, imageData []byte) (client.AgentClient, error) {
agent, err := client.CreateAgentClient(log.NewWithOptions(os.Stdout, log.Options{ agent, err := client.CreateAgentClient(log)
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: "Orchestrator 🎼",
}))
if err != nil { if err != nil {
return OrchestratorAgent{}, err return client.AgentClient{}, err
} }
agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("noteAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
@ -134,7 +128,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
}) })
agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("contactAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go contactAgent.client.RunAgent(contactPrompt, contactTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) go contactAgent.RunAgent(contactPrompt, contactTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -142,7 +136,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
}) })
agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("locationAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go locationAgent.client.RunAgent(locationPrompt, locationTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) go locationAgent.RunAgent(locationPrompt, locationTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -150,7 +144,7 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
}) })
agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) { agent.ToolHandler.AddTool("eventAgent", func(info client.ToolHandlerInfo, args string, call client.ToolCall) (any, error) {
go eventAgent.client.RunAgent(eventPrompt, eventTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData) go agent.RunAgent(eventPrompt, eventTools, "finish", nil, info.UserId, info.ImageId, imageName, imageData)
return Status{ return Status{
Ok: true, Ok: true,
@ -165,7 +159,5 @@ func NewOrchestratorAgent(noteAgent NoteAgent, contactAgent ContactAgent, locati
}, errors.New("Finished! Kinda bad return type but...") }, errors.New("Finished! Kinda bad return type but...")
}) })
return OrchestratorAgent{ return agent, nil
Client: agent,
}, nil
} }

View File

@ -4,16 +4,24 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"log"
"os" "os"
"screenmark/screenmark/agents" "screenmark/screenmark/agents"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"time" "time"
"github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lib/pq" "github.com/lib/pq"
) )
func createLogger(prefix string) *log.Logger {
return log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: prefix,
})
}
func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) { func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) { listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) {
if err != nil { if err != nil {
@ -28,6 +36,8 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
imageModel := models.NewImageModel(db) imageModel := models.NewImageModel(db)
contactModel := models.NewContactModel(db) contactModel := models.NewContactModel(db)
databaseEventLog := createLogger("Database Events 🤖")
err := listener.Listen("new_image") err := listener.Listen("new_image")
if err != nil { if err != nil {
panic(err) panic(err)
@ -39,55 +49,60 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) {
imageId := uuid.MustParse(parameters.Extra) imageId := uuid.MustParse(parameters.Extra)
eventManager.listeners[parameters.Extra] = make(chan string) eventManager.listeners[parameters.Extra] = make(chan string)
databaseEventLog.Debug("Starting processing image", "ImageID", imageId)
ctx := context.Background() ctx := context.Background()
go func() { go func() {
noteAgent, err := agents.NewNoteAgent(noteModel) noteAgent, err := agents.NewNoteAgent(createLogger("Notes 📝"), noteModel)
if err != nil { if err != nil {
panic(err) panic(err)
} }
contactAgent, err := agents.NewContactAgent(contactModel) contactAgent, err := agents.NewContactAgent(createLogger("Contacts 👥"), contactModel)
if err != nil { if err != nil {
panic(err) panic(err)
} }
locationAgent, err := agents.NewLocationAgent(locationModel) locationAgent, err := agents.NewLocationAgent(createLogger("Locations 📍"), locationModel)
if err != nil { if err != nil {
panic(err) panic(err)
} }
eventAgent, err := agents.NewEventAgent(eventModel, locationAgent) eventAgent, err := agents.NewEventAgent(createLogger("Events 📅"), eventModel, locationAgent)
if err != nil { if err != nil {
panic(err) panic(err)
} }
image, err := imageModel.GetToProcessWithData(ctx, imageId) image, err := imageModel.GetToProcessWithData(ctx, imageId)
if err != nil { if err != nil {
log.Println("Failed to GetToProcessWithData") log.Error("Failed to GetToProcessWithData", "error", err)
log.Println(err)
return return
} }
if err := imageModel.StartProcessing(ctx, image.ID); err != nil { if err := imageModel.StartProcessing(ctx, image.ID); err != nil {
log.Println("Failed to FinishProcessing") log.Error("Failed to FinishProcessing", "error", err)
log.Println(err)
return return
} }
orchestrator, err := agents.NewOrchestratorAgent(noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image) orchestrator, err := agents.NewOrchestratorAgent(createLogger("Orchestrator 🎼"), noteAgent, contactAgent, locationAgent, eventAgent, image.Image.ImageName, image.Image.Image)
if err != nil { if err != nil {
panic(err) panic(err)
} }
// Still need to find some way to hide this complexity away. err = orchestrator.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", nil, image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
// I don't think wrapping agents in structs actually works too well.
err = orchestrator.Client.RunAgent(agents.OrchestratorPrompt, agents.OrchestratorTools, "noAction", nil, image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image)
if err != nil { if err != nil {
log.Println(err) log.Error("Orchestrator failed", "error", "err")
return
} }
imageModel.FinishProcessing(ctx, image.ID) _, err = imageModel.FinishProcessing(ctx, image.ID)
if err != nil {
log.Error("Failed to finish processing", "ImageID", imageId)
return
}
databaseEventLog.Debug("Starting processing image", "ImageID", imageId)
}() }()
} }
} }