diff --git a/backend/agents/create_list_agent.go b/backend/agents/create_list_agent.go index 2d43548..ed3e708 100644 --- a/backend/agents/create_list_agent.go +++ b/backend/agents/create_list_agent.go @@ -76,10 +76,10 @@ type createNewListArguments struct { type CreateListAgent struct { client client.AgentClient - listModel models.StackModel + stackModel models.StackModel } -func (agent *CreateListAgent) CreateList(log *log.Logger, userID uuid.UUID, userReq string) error { +func (agent *CreateListAgent) CreateList(log *log.Logger, userID uuid.UUID, title string, userReq string) error { request := client.AgentRequestBody{ Model: "policy/images", Temperature: 0.3, @@ -93,7 +93,10 @@ func (agent *CreateListAgent) CreateList(log *log.Logger, userID uuid.UUID, user } request.Chat.AddSystem(agent.client.Options.SystemPrompt) - request.Chat.AddUser(userReq) + + req := fmt.Sprintf("List title: %s | Users list description: %s", title, userReq) + + request.Chat.AddUser(req) resp, err := agent.client.Request(&request) if err != nil { @@ -120,12 +123,12 @@ func (agent *CreateListAgent) CreateList(log *log.Logger, userID uuid.UUID, user }) } - _, err = agent.listModel.Save(ctx, userID, createListArgs.Title, createListArgs.Description, model.Progress_Complete) + _, err = agent.stackModel.Save(ctx, userID, createListArgs.Title, createListArgs.Description, model.Progress_Complete) if err != nil { return fmt.Errorf("creating list agent, saving list: %w", err) } - err = agent.listModel.SaveItems(ctx, schemaItems) + err = agent.stackModel.SaveItems(ctx, schemaItems) if err != nil { return fmt.Errorf("creating list agent, saving items: %w", err) } diff --git a/backend/agents/list_agent.go b/backend/agents/list_agent.go index edfd437..81221dd 100644 --- a/backend/agents/list_agent.go +++ b/backend/agents/list_agent.go @@ -176,7 +176,7 @@ type addToListArguments struct { Schema []models.IDValue } -func NewListAgent(log *log.Logger, stackModel models.StackModel, limitsMethods limits.LimitsManagerMethods) client.AgentClient { +func NewStackAgent(log *log.Logger, stackModel models.StackModel, limitsMethods limits.LimitsManagerMethods) client.AgentClient { agentClient := client.CreateAgentClient(client.CreateAgentClientOptions{ SystemPrompt: listPrompt, JsonTools: listTools, diff --git a/backend/models/stacks.go b/backend/models/stacks.go index 6d01ef7..1864e71 100644 --- a/backend/models/stacks.go +++ b/backend/models/stacks.go @@ -142,6 +142,20 @@ func (m StackModel) SaveSchemaItems(ctx context.Context, imageID uuid.UUID, item return err } +// ======================================== +// UPDATE methods +// ======================================== + +func (m StackModel) UpdateProcess(ctx context.Context, stackID uuid.UUID, process model.Progress) error { + updateStackProgressStmt := Stacks.UPDATE(Stacks.Status). + SET(process). + WHERE(Stacks.ID.EQ(UUID(stackID))) + + _, err := updateStackProgressStmt.ExecContext(ctx, m.dbPool) + + return err +} + // ======================================== // DELETE methods // ======================================== diff --git a/backend/notifications/image_notification.go b/backend/notifications/image_notification.go index 4ec47a4..de2ca74 100644 --- a/backend/notifications/image_notification.go +++ b/backend/notifications/image_notification.go @@ -9,7 +9,7 @@ import ( const ( IMAGE_TYPE = "image" - LIST_TYPE = "list" + STACK_TYPE = "list" ) type ImageNotification struct { @@ -21,18 +21,18 @@ type ImageNotification struct { Status string } -type ListNotification struct { +type StackNotification struct { Type string - ListID uuid.UUID - Name string + StackID uuid.UUID + Name string Status string } type Notification struct { image *ImageNotification - list *ListNotification + stack *StackNotification } func GetImageNotification(image ImageNotification) Notification { @@ -41,9 +41,9 @@ func GetImageNotification(image ImageNotification) Notification { } } -func GetListNotification(list ListNotification) Notification { +func GetStackNotification(list StackNotification) Notification { return Notification{ - list: &list, + stack: &list, } } @@ -52,8 +52,8 @@ func (n Notification) MarshalJSON() ([]byte, error) { return json.Marshal(n.image) } - if n.list != nil { - return json.Marshal(n.list) + if n.stack != nil { + return json.Marshal(n.stack) } return nil, fmt.Errorf("no image or list present") diff --git a/backend/processor/image.go b/backend/processor/image.go index 46574f2..dc3b108 100644 --- a/backend/processor/image.go +++ b/backend/processor/image.go @@ -23,11 +23,8 @@ type ImageProcessor struct { descriptionAgent agents.DescriptionAgent stackAgent client.AgentClient - // TODO: add the notifier here - Processor *Processor[model.Image] - - notifier *notifications.Notifier[notifications.Notification] + notifier *notifications.Notifier[notifications.Notification] } func (p *ImageProcessor) setImageToProcess(ctx context.Context, image model.Image) { @@ -128,7 +125,7 @@ func NewImageProcessor( } descriptionAgent := agents.NewDescriptionAgent(logger, imageModel) - stackAgent := agents.NewListAgent(logger, listModel, limitsManager) + stackAgent := agents.NewStackAgent(logger, listModel, limitsManager) imageProcessor := ImageProcessor{ imageModel: imageModel, diff --git a/backend/processor/stack.go b/backend/processor/stack.go new file mode 100644 index 0000000..63df3d0 --- /dev/null +++ b/backend/processor/stack.go @@ -0,0 +1,127 @@ +package processor + +import ( + "context" + "fmt" + "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/agents" + "screenmark/screenmark/models" + "screenmark/screenmark/notifications" + "sync" + + "github.com/charmbracelet/log" +) + +const STACK_PROCESS_AT_A_TIME = 10 + +// TODO: +// This processor contains a lot of shared stuff. +// If we ever want to do more generic stuff with "in-progress" and stuff +// we can extract that into a common thing +// +// However, this will require a pretty big DB shuffle. + +type StackProcessor struct { + stackModel models.StackModel + logger *log.Logger + + stackAgent agents.CreateListAgent + + Processor *Processor[model.Stacks] + + notifier *notifications.Notifier[notifications.Notification] +} + +func (p *StackProcessor) setStackToProcess(ctx context.Context, stack model.Stacks) { + err := p.stackModel.UpdateProcess(ctx, stack.ID, model.Progress_InProgress) + if err != nil { + // TODO: what can we actually do here for the errors? + // We can't stop the work for the others + + p.logger.Error("failed to update stack", "err", err) + + // TODO: we can use context here to actually pass some information through + return + } +} + +func (p *StackProcessor) extractInfo(ctx context.Context, stack model.Stacks) { + err := p.stackAgent.CreateList(p.logger, stack.UserID, stack.Name, stack.Description) + if err != nil { + // Again, wtf do we do? + // Although i think the agent actually returns an error when it's finished + p.logger.Error("failed to process image", "err", err) + return + } +} + +func (p *StackProcessor) processImage(stack model.Stacks) { + p.logger.Info("Processing image", "ID", stack.ID) + + ctx := context.Background() + + p.setStackToProcess(ctx, stack) + + var wg sync.WaitGroup + + // Future proofing! + wg.Add(1) + + stackNotification := notifications.GetStackNotification(notifications.StackNotification{ + Type: notifications.STACK_TYPE, + Status: string(model.Progress_InProgress), + StackID: stack.ID, + Name: stack.Name, + }) + + err := p.notifier.SendAndCreate(stack.UserID.String(), stackNotification) + if err != nil { + p.logger.Error("sending in progress notification", "err", err) + return + } + + go func() { + p.extractInfo(ctx, stack) + wg.Done() + }() + + wg.Wait() + + // TODO: there is some repeated code here. The ergonomicts of the notifications, + // isn't the best. + stackNotification = notifications.GetStackNotification(notifications.StackNotification{ + Type: notifications.STACK_TYPE, + Status: string(model.Progress_Complete), + StackID: stack.ID, + Name: stack.Name, + }) + + err = p.notifier.SendAndCreate(stack.UserID.String(), stackNotification) + if err != nil { + p.logger.Error("sending done notification", "err", err) + return + } +} + +func NewStackProcessor( + logger *log.Logger, + stackModel models.StackModel, + notifier *notifications.Notifier[notifications.Notification], +) (StackProcessor, error) { + if notifier == nil { + return StackProcessor{}, fmt.Errorf("notifier is nil") + } + + stackAgent := agents.NewCreateListAgent(logger, stackModel) + + imageProcessor := StackProcessor{ + logger: logger, + stackModel: stackModel, + stackAgent: stackAgent, + notifier: notifier, + } + + imageProcessor.Processor = NewProcessor(int(IMAGE_PROCESS_AT_A_TIME), imageProcessor.processImage) + + return imageProcessor, nil +} diff --git a/backend/router.go b/backend/router.go index 33c0213..e6107ed 100644 --- a/backend/router.go +++ b/backend/router.go @@ -41,7 +41,14 @@ func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) (chi.Router, return nil, fmt.Errorf("processor: %w", err) } + stackProcessorLog := createLogger("Stack0 Processor", os.Stdout) + stackProcessor, err := processor.NewStackProcessor(stackProcessorLog, stackModel, ¬ifier) + if err != nil { + return nil, fmt.Errorf("processor: %w", err) + } + go imageProcessor.Processor.Work() + go stackProcessor.Processor.Work() stackHandler := stacks.CreateStackHandler(db, limitsManager, jwtManager) authHandler := auth.CreateAuthHandler(db, jwtManager) diff --git a/backend/stacks/handler.go b/backend/stacks/handler.go index 76c0b4a..f5ebb89 100644 --- a/backend/stacks/handler.go +++ b/backend/stacks/handler.go @@ -209,14 +209,14 @@ func (h *StackHandler) createStack(body CreateStackBody, w http.ResponseWriter, } // TODO: Add the stack processor here - _, err = h.stackModel.Save(ctx, userID, body.Title, body.Description, model.Progress_NotStarted) + stack, err := h.stackModel.Save(ctx, userID, body.Title, body.Description, model.Progress_NotStarted) if err != nil { h.logger.Warn("could not save stack", "err", err) w.WriteHeader(http.StatusInternalServerError) return } - w.WriteHeader(http.StatusOK) + middleware.WriteJsonOrError(h.logger, stack, w) } func (h *StackHandler) CreateRoutes(r chi.Router) {