From 980b42aa44dfca8decb1f3a32364228ac43f28c4 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sun, 5 Oct 2025 12:10:06 +0100 Subject: [PATCH] fix: notification system --- backend/events.go | 63 ++---------------- backend/images/handler.go | 6 +- backend/integration_test.go | 6 +- backend/main.go | 5 +- backend/notifications/channel_splitter.go | 38 +++++++++++ backend/notifications/image_notification.go | 64 +++++++++++++++++++ backend/{ => notifications}/notifications.go | 41 +----------- .../{ => notifications}/notifications_test.go | 2 +- backend/processor/image.go | 48 +++++++++++++- backend/router.go | 16 +++-- frontend/src/contexts/Notifications.tsx | 4 +- frontend/src/network/index.ts | 22 +++---- 12 files changed, 191 insertions(+), 124 deletions(-) create mode 100644 backend/notifications/channel_splitter.go create mode 100644 backend/notifications/image_notification.go rename backend/{ => notifications}/notifications.go (59%) rename backend/{ => notifications}/notifications_test.go (97%) diff --git a/backend/events.go b/backend/events.go index 09c0c82..0421d7f 100644 --- a/backend/events.go +++ b/backend/events.go @@ -5,77 +5,22 @@ import ( "fmt" "net/http" "screenmark/screenmark/middleware" + "screenmark/screenmark/notifications" "strconv" "github.com/google/uuid" ) -const ( - IMAGE_TYPE = "image" - LIST_TYPE = "list" -) - -type imageNotification struct { - Type string - - ImageID uuid.UUID - ImageName string - - Status string -} - -type listNotification struct { - Type string - - ListID uuid.UUID - Name string - - Status string -} - -type Notification struct { - image *imageNotification - list *listNotification -} - -func getImageNotification(image imageNotification) Notification { - return Notification{ - image: &image, - } -} - -func getListNotification(list listNotification) Notification { - return Notification{ - list: &list, - } -} - -func (n Notification) MarshalJSON() ([]byte, error) { - if n.image != nil { - return json.Marshal(n.image) - } - - if n.list != nil { - return json.Marshal(n.list) - } - - return nil, fmt.Errorf("no image or list present") -} - -func (n *Notification) UnmarshalJSON(data []byte) error { - return fmt.Errorf("unimplemented") -} - /* * TODO: We have channels open every a user sends an image. * We never close these channels. * * What is a reasonable default? Close the channel after 1 minute of inactivity? */ -func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc { +func CreateEventsHandler(notifier *notifications.Notifier[notifications.Notification]) http.HandlerFunc { counter := 0 - userSplitters := make(map[string]*ChannelSplitter[Notification]) + userSplitters := make(map[string]*notifications.ChannelSplitter[notifications.Notification]) return func(w http.ResponseWriter, r *http.Request) { _userId := r.Context().Value(middleware.USER_ID).(uuid.UUID) @@ -98,7 +43,7 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc { userNotifications := notifier.Listeners[userId] if _, exists := userSplitters[userId]; !exists { - splitter := NewChannelSplitter(userNotifications) + splitter := notifications.NewChannelSplitter(userNotifications) userSplitters[userId] = &splitter splitter.Listen() diff --git a/backend/images/handler.go b/backend/images/handler.go index 61eda7b..9e4e83a 100644 --- a/backend/images/handler.go +++ b/backend/images/handler.go @@ -154,7 +154,11 @@ func (h *ImageHandler) uploadImage(w http.ResponseWriter, r *http.Request) { h.logger.Info("About to add image") h.processor.Add(newImage) - w.WriteHeader(http.StatusOK) + // We nullify the image's data, so we're not transferring all that + // data back to the frontend. + newImage.Image = nil + + middleware.WriteJsonOrError(h.logger, newImage, w) } func (h *ImageHandler) deleteImage(w http.ResponseWriter, r *http.Request) { diff --git a/backend/integration_test.go b/backend/integration_test.go index 948bfee..e7d980c 100644 --- a/backend/integration_test.go +++ b/backend/integration_test.go @@ -181,7 +181,11 @@ func setupTestContext(t *testing.T) *TestContext { } jwtManager := middleware.NewJwtManager([]byte("test-jwt-secret")) - router := setupRouter(db, jwtManager) + router, err := setupRouter(db, jwtManager) + if err != nil { + panic(err) + } + server := httptest.NewServer(router) tc.db = db diff --git a/backend/main.go b/backend/main.go index 193dc8e..89f0937 100644 --- a/backend/main.go +++ b/backend/main.go @@ -28,7 +28,10 @@ func main() { panic(err) } - router := setupRouter(db, jwtManager) + router, err := setupRouter(db, jwtManager) + if err != nil { + panic(err) + } port, exists := os.LookupEnv("PORT") if !exists { diff --git a/backend/notifications/channel_splitter.go b/backend/notifications/channel_splitter.go new file mode 100644 index 0000000..1027033 --- /dev/null +++ b/backend/notifications/channel_splitter.go @@ -0,0 +1,38 @@ +package notifications + +type ChannelSplitter[TNotification any] struct { + ch chan TNotification + + Listeners map[string]chan TNotification +} + +func (s *ChannelSplitter[TNotification]) Listen() { + go func() { + for { + select { + case msg := <-s.ch: + for _, v := range s.Listeners { + v <- msg + } + } + } + }() +} + +func (s *ChannelSplitter[TNotification]) Add(id string) chan TNotification { + ch := make(chan TNotification) + s.Listeners[id] = ch + + return ch +} + +func (s *ChannelSplitter[TNotification]) Remove(id string) { + delete(s.Listeners, id) +} + +func NewChannelSplitter[TNotification any](ch chan TNotification) ChannelSplitter[TNotification] { + return ChannelSplitter[TNotification]{ + ch: ch, + Listeners: make(map[string]chan TNotification), + } +} diff --git a/backend/notifications/image_notification.go b/backend/notifications/image_notification.go new file mode 100644 index 0000000..4ec47a4 --- /dev/null +++ b/backend/notifications/image_notification.go @@ -0,0 +1,64 @@ +package notifications + +import ( + "encoding/json" + "fmt" + + "github.com/google/uuid" +) + +const ( + IMAGE_TYPE = "image" + LIST_TYPE = "list" +) + +type ImageNotification struct { + Type string + + ImageID uuid.UUID + ImageName string + + Status string +} + +type ListNotification struct { + Type string + + ListID uuid.UUID + Name string + + Status string +} + +type Notification struct { + image *ImageNotification + list *ListNotification +} + +func GetImageNotification(image ImageNotification) Notification { + return Notification{ + image: &image, + } +} + +func GetListNotification(list ListNotification) Notification { + return Notification{ + list: &list, + } +} + +func (n Notification) MarshalJSON() ([]byte, error) { + if n.image != nil { + return json.Marshal(n.image) + } + + if n.list != nil { + return json.Marshal(n.list) + } + + return nil, fmt.Errorf("no image or list present") +} + +func (n *Notification) UnmarshalJSON(data []byte) error { + return fmt.Errorf("unimplemented") +} diff --git a/backend/notifications.go b/backend/notifications/notifications.go similarity index 59% rename from backend/notifications.go rename to backend/notifications/notifications.go index dc3f5ca..72b1fcd 100644 --- a/backend/notifications.go +++ b/backend/notifications/notifications.go @@ -1,4 +1,4 @@ -package main +package notifications import ( "errors" @@ -56,42 +56,3 @@ func NewNotifier[TNotification any](bufferSize int) Notifier[TNotification] { Listeners: make(map[string]chan TNotification), } } - -// ---------------------------------- - -type ChannelSplitter[TNotification any] struct { - ch chan TNotification - - Listeners map[string]chan TNotification -} - -func (s *ChannelSplitter[TNotification]) Listen() { - go func() { - for { - select { - case msg := <-s.ch: - for _, v := range s.Listeners { - v <- msg - } - } - } - }() -} - -func (s *ChannelSplitter[TNotification]) Add(id string) chan TNotification { - ch := make(chan TNotification) - s.Listeners[id] = ch - - return ch -} - -func (s *ChannelSplitter[TNotification]) Remove(id string) { - delete(s.Listeners, id) -} - -func NewChannelSplitter[TNotification any](ch chan TNotification) ChannelSplitter[TNotification] { - return ChannelSplitter[TNotification]{ - ch: ch, - Listeners: make(map[string]chan TNotification), - } -} diff --git a/backend/notifications_test.go b/backend/notifications/notifications_test.go similarity index 97% rename from backend/notifications_test.go rename to backend/notifications/notifications_test.go index 2e25752..bd4ca0c 100644 --- a/backend/notifications_test.go +++ b/backend/notifications/notifications_test.go @@ -1,4 +1,4 @@ -package main +package notifications import ( "testing" diff --git a/backend/processor/image.go b/backend/processor/image.go index ab5a302..46574f2 100644 --- a/backend/processor/image.go +++ b/backend/processor/image.go @@ -2,11 +2,13 @@ package processor import ( "context" + "fmt" "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/agents" "screenmark/screenmark/agents/client" "screenmark/screenmark/limits" "screenmark/screenmark/models" + "screenmark/screenmark/notifications" "sync" "github.com/charmbracelet/log" @@ -24,6 +26,8 @@ type ImageProcessor struct { // TODO: add the notifier here Processor *Processor[model.Image] + + notifier *notifications.Notifier[notifications.Notification] } func (p *ImageProcessor) setImageToProcess(ctx context.Context, image model.Image) { @@ -71,6 +75,19 @@ func (p *ImageProcessor) processImage(image model.Image) { var wg sync.WaitGroup wg.Add(2) + imageNotification := notifications.GetImageNotification(notifications.ImageNotification{ + Type: notifications.IMAGE_TYPE, + ImageID: image.ID, + ImageName: image.ImageName, + Status: string(model.Progress_InProgress), + }) + + err := p.notifier.SendAndCreate(image.UserID.String(), imageNotification) + if err != nil { + p.logger.Error("sending in progress notification", "err", err) + return + } + go func() { p.describe(ctx, image) wg.Done() @@ -82,9 +99,34 @@ func (p *ImageProcessor) processImage(image model.Image) { }() wg.Wait() + + // TODO: there is some repeated code here. The ergonomicts of the notifications, + // isn't the best. + imageNotification = notifications.GetImageNotification(notifications.ImageNotification{ + Type: notifications.IMAGE_TYPE, + ImageID: image.ID, + ImageName: image.ImageName, + Status: string(model.Progress_Complete), + }) + + err = p.notifier.SendAndCreate(image.UserID.String(), imageNotification) + if err != nil { + p.logger.Error("sending done notification", "err", err) + return + } } -func NewImageProcessor(logger *log.Logger, imageModel models.ImageModel, listModel models.StackModel, limitsManager limits.LimitsManagerMethods) ImageProcessor { +func NewImageProcessor( + logger *log.Logger, + imageModel models.ImageModel, + listModel models.StackModel, + limitsManager limits.LimitsManagerMethods, + notifier *notifications.Notifier[notifications.Notification], +) (ImageProcessor, error) { + if notifier == nil { + return ImageProcessor{}, fmt.Errorf("notifier is nil") + } + descriptionAgent := agents.NewDescriptionAgent(logger, imageModel) stackAgent := agents.NewListAgent(logger, listModel, limitsManager) @@ -93,9 +135,11 @@ func NewImageProcessor(logger *log.Logger, imageModel models.ImageModel, listMod logger: logger, descriptionAgent: descriptionAgent, stackAgent: stackAgent, + + notifier: notifier, } imageProcessor.Processor = NewProcessor(int(IMAGE_PROCESS_AT_A_TIME), imageProcessor.processImage) - return imageProcessor + return imageProcessor, nil } diff --git a/backend/router.go b/backend/router.go index 1661b93..33c0213 100644 --- a/backend/router.go +++ b/backend/router.go @@ -2,12 +2,14 @@ package main import ( "database/sql" + "fmt" "os" "screenmark/screenmark/agents/client" "screenmark/screenmark/auth" "screenmark/screenmark/images" "screenmark/screenmark/limits" "screenmark/screenmark/models" + "screenmark/screenmark/notifications" "screenmark/screenmark/processor" "screenmark/screenmark/stacks" @@ -25,22 +27,26 @@ func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (cli return client.ImageInfo, nil } -func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) chi.Router { +func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) (chi.Router, error) { limitsManager := limits.CreateLimitsManager(db) imageModel := models.NewImageModel(db) stackModel := models.NewStackModel(db) + notifier := notifications.NewNotifier[notifications.Notification](10) + imageProcessorLogger := createLogger("Image Processor", os.Stdout) - imageProcessor := processor.NewImageProcessor(imageProcessorLogger, imageModel, stackModel, limitsManager) + imageProcessor, err := processor.NewImageProcessor(imageProcessorLogger, imageModel, stackModel, limitsManager, ¬ifier) + if err != nil { + return nil, fmt.Errorf("processor: %w", err) + } + go imageProcessor.Processor.Work() stackHandler := stacks.CreateStackHandler(db, limitsManager, jwtManager) authHandler := auth.CreateAuthHandler(db, jwtManager) imageHandler := images.CreateImageHandler(db, limitsManager, jwtManager, imageProcessor.Processor) - notifier := NewNotifier[Notification](10) - r := chi.NewRouter() r.Use(middleware.Logger) @@ -56,5 +62,5 @@ func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) chi.Router { r.Get("/", CreateEventsHandler(¬ifier)) }) - return r + return r, nil } diff --git a/frontend/src/contexts/Notifications.tsx b/frontend/src/contexts/Notifications.tsx index d7ea796..5133b86 100644 --- a/frontend/src/contexts/Notifications.tsx +++ b/frontend/src/contexts/Notifications.tsx @@ -39,6 +39,8 @@ export const Notifications = (onCompleteImage: () => void) => { const [accessToken] = createResource(getAccessToken); const dataEventListener = (e: MessageEvent) => { + debugger; + if (typeof e.data !== "string") { console.error("Error type is not string"); return; @@ -98,7 +100,7 @@ export const Notifications = (onCompleteImage: () => void) => { upsertImageProcessing( Object.fromEntries( - images.filter(i => i.Status !== 'complete').map((i) => [ + images.filter(i => i.Status === 'complete').map((i) => [ i.ID, { Type: "image", diff --git a/frontend/src/network/index.ts b/frontend/src/network/index.ts index 882a0b0..796c6c3 100644 --- a/frontend/src/network/index.ts +++ b/frontend/src/network/index.ts @@ -78,17 +78,10 @@ const getBaseAuthorizedRequest = async ({ method, }); }; -const sendImageResponseValidator = strictObject({ - ID: pipe(string(), uuid()), - ImageID: pipe(string(), uuid()), - UserID: pipe(string(), uuid()), - Status: string(), -}); - export const sendImageFile = async ( imageName: string, file: File, -): Promise> => { +): Promise> => { const request = await getBaseAuthorizedRequest({ path: `images/${imageName}`, body: file, @@ -98,7 +91,7 @@ export const sendImageFile = async ( request.headers.set("Content-Type", "application/oclet-stream"); const res = await fetch(request).then((res) => res.json()); - const parsedRes = safeParse(sendImageResponseValidator, res); + const parsedRes = safeParse(imageValidator, res); if (!parsedRes.success) { console.log(parsedRes.issues) @@ -146,7 +139,7 @@ export class ImageLimitReached extends Error { export const sendImage = async ( imageName: string, base64Image: string, -): Promise> => { +): Promise> => { const request = await getBaseAuthorizedRequest({ path: `images/${imageName}`, body: base64Image, @@ -162,16 +155,16 @@ export const sendImage = async ( const res = await rawRes.json(); - const parsedRes = safeParse(sendImageResponseValidator, res); + const parsedRes = safeParse(imageValidator, res); if (!parsedRes.success) { - console.log(parsedRes.issues) + console.log("Parsing issues: ", parsedRes.issues) throw new Error(JSON.stringify(parsedRes.issues)); } return parsedRes.output; }; -const userImageValidator = strictObject({ +const imageValidator = strictObject({ ID: pipe(string(), uuid()), CreatedAt: string(), UserID: pipe(string(), uuid()), @@ -181,7 +174,10 @@ const userImageValidator = strictObject({ ImageName: string(), Status: union([literal('not-started'), literal('in-progress'), literal('complete')]), +}) +const userImageValidator = strictObject({ + ...imageValidator.entries, ImageStacks: pipe(nullable(array( strictObject({ ID: pipe(string(), uuid()),