From ff7960e2dd3abf0a307e3d1482ae3b6c0640dead Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 10 May 2025 14:21:18 +0100 Subject: [PATCH] feat: notifier module with buffered channels --- backend/events.go | 31 ++----------- backend/main.go | 82 +++++++++++++++-------------------- backend/notifications.go | 55 +++++++++++++++++++++++ backend/notifications_test.go | 38 ++++++++++++++++ 4 files changed, 132 insertions(+), 74 deletions(-) create mode 100644 backend/notifications.go create mode 100644 backend/notifications_test.go diff --git a/backend/events.go b/backend/events.go index 6eda913..83afb62 100644 --- a/backend/events.go +++ b/backend/events.go @@ -13,7 +13,7 @@ import ( "github.com/lib/pq" ) -func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) { +func ListenNewImageEvents(db *sql.DB, notifier *Notifier[string]) { listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) { if err != nil { panic(err) @@ -39,7 +39,6 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) { select { case parameters := <-listener.Notify: imageId := uuid.MustParse(parameters.Extra) - eventManager.listeners[parameters.Extra] = make(chan string) databaseEventLog.Debug("Starting processing image", "ImageID", imageId) @@ -78,18 +77,7 @@ func ListenNewImageEvents(db *sql.DB, eventManager *EventManager) { } } -type EventManager struct { - // Maps processing image UUID to a channel - listeners map[string]chan string -} - -func NewEventManager() EventManager { - return EventManager{ - listeners: make(map[string]chan string), - } -} - -func ListenProcessingImageStatus(db *sql.DB, eventManager *EventManager) { +func ListenProcessingImageStatus(db *sql.DB, notifier *Notifier[string]) { listener := pq.NewListener(os.Getenv("DB_CONNECTION"), time.Second, time.Second, func(event pq.ListenerEventType, err error) { if err != nil { panic(err) @@ -111,20 +99,7 @@ func ListenProcessingImageStatus(db *sql.DB, eventManager *EventManager) { logger.Info("Update", "id", stringUuid, "status", status) - imageListener, exists := eventManager.listeners[stringUuid] - if !exists { - continue - } - - logger.Info("Sending...") - imageListener <- status - - if status != "complete" { - continue - } - - close(imageListener) - delete(eventManager.listeners, stringUuid) + notifier.SendAndCreate(stringUuid, status) } } } diff --git a/backend/main.go b/backend/main.go index 037bc29..87f50c6 100644 --- a/backend/main.go +++ b/backend/main.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "context" "encoding/base64" "encoding/json" "fmt" @@ -49,10 +48,10 @@ func main() { auth := CreateAuth(mail) - eventManager := NewEventManager() + notifier := NewNotifier[string](10) - go ListenNewImageEvents(db, &eventManager) - go ListenProcessingImageStatus(db, &eventManager) + go ListenNewImageEvents(db, ¬ifier) + go ListenProcessingImageStatus(db, ¬ifier) r := chi.NewRouter() @@ -270,48 +269,39 @@ func main() { r.Get("/image-events/{id}", func(w http.ResponseWriter, r *http.Request) { // TODO: authentication :) - - id := r.PathValue("id") - - // TODO: get the current status of the image and send it across. - ctx, cancel := context.WithCancel(r.Context()) - - imageNotifier, exists := eventManager.listeners[id] - if !exists { - fmt.Println("Not found!") - w.WriteHeader(http.StatusNotFound) - w.(http.Flusher).Flush() - cancel() - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.(http.Flusher).Flush() - - for { - select { - case <-ctx.Done(): - fmt.Fprint(w, "event: close\ndata: Connection closed\n\n") - w.(http.Flusher).Flush() - cancel() - return - case data := <-imageNotifier: - if data == "" { - cancel() - continue - } - - fmt.Printf("Status received: %s\n", data) - fmt.Fprintf(w, "event: data\ndata: %s\n\n", data) - w.(http.Flusher).Flush() - - if data == "complete" { - cancel() - } - } - } + // + // id := r.PathValue("id") + // + // // TODO: get the current status of the image and send it across. + // ctx, cancel := context.WithCancel(r.Context()) + // + // w.Header().Set("Content-Type", "text/event-stream") + // w.Header().Set("Cache-Control", "no-cache") + // w.Header().Set("Connection", "keep-alive") + // w.(http.Flusher).Flush() + // + // for { + // select { + // case <-ctx.Done(): + // fmt.Fprint(w, "event: close\ndata: Connection closed\n\n") + // w.(http.Flusher).Flush() + // cancel() + // return + // case data := <-imageNotifier: + // if data == "" { + // cancel() + // continue + // } + // + // fmt.Printf("Status received: %s\n", data) + // fmt.Fprintf(w, "event: data\ndata: %s\n\n", data) + // w.(http.Flusher).Flush() + // + // if data == "complete" { + // cancel() + // } + // } + // } }) r.Post("/login", func(w http.ResponseWriter, r *http.Request) { diff --git a/backend/notifications.go b/backend/notifications.go new file mode 100644 index 0000000..076367b --- /dev/null +++ b/backend/notifications.go @@ -0,0 +1,55 @@ +package main + +import ( + "errors" +) + +type Notifier[TNotification any] struct { + bufferSize int + + Listeners map[string]chan TNotification +} + +func (n *Notifier[TNotification]) Create(id string) error { + if _, exists := n.Listeners[id]; exists { + return errors.New("This listener already exists") + } + + n.Listeners[id] = make(chan TNotification, n.bufferSize) + + return nil +} + +// Ensures the listener exists before sending +func (n *Notifier[TNotification]) SendAndCreate(id string, notification TNotification) error { + if _, exists := n.Listeners[id]; !exists { + n.Create(id) + } + + ch := n.Listeners[id] + + select { + case ch <- notification: + return nil + default: + return errors.New("Channel is full") + } +} + +func (n *Notifier[TNotification]) Delete(id string) error { + if _, exists := n.Listeners[id]; !exists { + return errors.New("This listener does not exists") + } + + delete(n.Listeners, id) + + return nil +} + +func NewNotifier[TNotification any](bufferSize int) Notifier[TNotification] { + return Notifier[TNotification]{ + bufferSize: bufferSize, + Listeners: make(map[string]chan TNotification), + } +} + diff --git a/backend/notifications_test.go b/backend/notifications_test.go new file mode 100644 index 0000000..b2c8cfd --- /dev/null +++ b/backend/notifications_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSendingNotifications(t *testing.T) { + assert := assert.New(t) + + notifier := NewNotifier[string](3) + + notifier.SendAndCreate("1", "a") + notifier.SendAndCreate("1", "b") + notifier.SendAndCreate("1", "c") + + ch := notifier.Listeners["1"] + + a := <-ch + b := <-ch + c := <-ch + + assert.Equal(a, "a") + assert.Equal(b, "b") + assert.Equal(c, "c") +} + +func TestFullBuffer(t *testing.T) { + assert := assert.New(t) + + notifier := NewNotifier[string](1) + + notifier.SendAndCreate("1", "a") + err := notifier.SendAndCreate("1", "b") + + assert.Error(err) +}