From 7b0c84e88e5bd7523dd9b5ea912f5dc9581f6bf8 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 10 May 2025 14:29:48 +0100 Subject: [PATCH] feat: adding allowed keys to only send when we have a userId This is to prevent users that aren't connected to the socket (somehow), to not fill up memory with buffered messages we'll never need. --- backend/notifications.go | 32 +++++++++++++++++++++++++------- backend/notifications_test.go | 32 +++++++++++++++++++++++++++----- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/backend/notifications.go b/backend/notifications.go index 076367b..d4cf4c0 100644 --- a/backend/notifications.go +++ b/backend/notifications.go @@ -8,6 +8,8 @@ type Notifier[TNotification any] struct { bufferSize int Listeners map[string]chan TNotification + + AllowedKeys map[string]bool } func (n *Notifier[TNotification]) Create(id string) error { @@ -15,15 +17,23 @@ func (n *Notifier[TNotification]) Create(id string) error { return errors.New("This listener already exists") } + if _, exists := n.AllowedKeys[id]; !exists { + return errors.New("This key cannot be used to create a listener") + } + n.Listeners[id] = make(chan TNotification, n.bufferSize) return nil } +var ChannelFullErr = errors.New("Channel is full") + // Ensures the listener exists before sending func (n *Notifier[TNotification]) SendAndCreate(id string, notification TNotification) error { if _, exists := n.Listeners[id]; !exists { - n.Create(id) + if err := n.Create(id); err != nil { + return err + } } ch := n.Listeners[id] @@ -32,7 +42,7 @@ func (n *Notifier[TNotification]) SendAndCreate(id string, notification TNotific case ch <- notification: return nil default: - return errors.New("Channel is full") + return ChannelFullErr } } @@ -46,10 +56,18 @@ func (n *Notifier[TNotification]) Delete(id string) error { return nil } -func NewNotifier[TNotification any](bufferSize int) Notifier[TNotification] { - return Notifier[TNotification]{ - bufferSize: bufferSize, - Listeners: make(map[string]chan TNotification), - } +func (n *Notifier[TNotification]) AddKey(id string) { + n.AllowedKeys[id] = true } +func (n *Notifier[TNotification]) RemoveKey(id string) { + delete(n.AllowedKeys, id) +} + +func NewNotifier[TNotification any](bufferSize int) Notifier[TNotification] { + return Notifier[TNotification]{ + bufferSize: bufferSize, + Listeners: make(map[string]chan TNotification), + AllowedKeys: make(map[string]bool), + } +} diff --git a/backend/notifications_test.go b/backend/notifications_test.go index b2c8cfd..d986bb1 100644 --- a/backend/notifications_test.go +++ b/backend/notifications_test.go @@ -4,16 +4,25 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSendingNotifications(t *testing.T) { assert := assert.New(t) + require := require.New(t) notifier := NewNotifier[string](3) - notifier.SendAndCreate("1", "a") - notifier.SendAndCreate("1", "b") - notifier.SendAndCreate("1", "c") + notifier.AddKey("1") + + err := notifier.SendAndCreate("1", "a") + require.NoError(err) + + err = notifier.SendAndCreate("1", "b") + require.NoError(err) + + err = notifier.SendAndCreate("1", "c") + require.NoError(err) ch := notifier.Listeners["1"] @@ -28,11 +37,24 @@ func TestSendingNotifications(t *testing.T) { func TestFullBuffer(t *testing.T) { assert := assert.New(t) + require := require.New(t) notifier := NewNotifier[string](1) - notifier.SendAndCreate("1", "a") - err := notifier.SendAndCreate("1", "b") + notifier.AddKey("1") + + err := notifier.SendAndCreate("1", "a") + require.NoError(err) + + err = notifier.SendAndCreate("1", "b") assert.Error(err) } + +func TestNoAllowedKey(t *testing.T) { + require := require.New(t) + notifier := NewNotifier[string](1) + + err := notifier.SendAndCreate("1", "a") + require.Error(err) +}