feat: adding allowed keys to only send when we have a userId

This is to prevent users that aren't connected to the socket (somehow),
to not fill up memory with buffered messages we'll never need.
This commit is contained in:
2025-05-10 14:29:48 +01:00
parent ff7960e2dd
commit 7b0c84e88e
2 changed files with 52 additions and 12 deletions

View File

@ -8,6 +8,8 @@ type Notifier[TNotification any] struct {
bufferSize int bufferSize int
Listeners map[string]chan TNotification Listeners map[string]chan TNotification
AllowedKeys map[string]bool
} }
func (n *Notifier[TNotification]) Create(id string) error { func (n *Notifier[TNotification]) Create(id string) error {
@ -15,15 +17,23 @@ func (n *Notifier[TNotification]) Create(id string) error {
return errors.New("This listener already exists") return errors.New("This listener already exists")
} }
if _, exists := n.AllowedKeys[id]; !exists {
return errors.New("This key cannot be used to create a listener")
}
n.Listeners[id] = make(chan TNotification, n.bufferSize) n.Listeners[id] = make(chan TNotification, n.bufferSize)
return nil return nil
} }
var ChannelFullErr = errors.New("Channel is full")
// Ensures the listener exists before sending // Ensures the listener exists before sending
func (n *Notifier[TNotification]) SendAndCreate(id string, notification TNotification) error { func (n *Notifier[TNotification]) SendAndCreate(id string, notification TNotification) error {
if _, exists := n.Listeners[id]; !exists { if _, exists := n.Listeners[id]; !exists {
n.Create(id) if err := n.Create(id); err != nil {
return err
}
} }
ch := n.Listeners[id] ch := n.Listeners[id]
@ -32,7 +42,7 @@ func (n *Notifier[TNotification]) SendAndCreate(id string, notification TNotific
case ch <- notification: case ch <- notification:
return nil return nil
default: default:
return errors.New("Channel is full") return ChannelFullErr
} }
} }
@ -46,10 +56,18 @@ func (n *Notifier[TNotification]) Delete(id string) error {
return nil return nil
} }
func NewNotifier[TNotification any](bufferSize int) Notifier[TNotification] { func (n *Notifier[TNotification]) AddKey(id string) {
return Notifier[TNotification]{ n.AllowedKeys[id] = true
bufferSize: bufferSize,
Listeners: make(map[string]chan TNotification),
}
} }
func (n *Notifier[TNotification]) RemoveKey(id string) {
delete(n.AllowedKeys, id)
}
func NewNotifier[TNotification any](bufferSize int) Notifier[TNotification] {
return Notifier[TNotification]{
bufferSize: bufferSize,
Listeners: make(map[string]chan TNotification),
AllowedKeys: make(map[string]bool),
}
}

View File

@ -4,16 +4,25 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestSendingNotifications(t *testing.T) { func TestSendingNotifications(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t)
notifier := NewNotifier[string](3) notifier := NewNotifier[string](3)
notifier.SendAndCreate("1", "a") notifier.AddKey("1")
notifier.SendAndCreate("1", "b")
notifier.SendAndCreate("1", "c") err := notifier.SendAndCreate("1", "a")
require.NoError(err)
err = notifier.SendAndCreate("1", "b")
require.NoError(err)
err = notifier.SendAndCreate("1", "c")
require.NoError(err)
ch := notifier.Listeners["1"] ch := notifier.Listeners["1"]
@ -28,11 +37,24 @@ func TestSendingNotifications(t *testing.T) {
func TestFullBuffer(t *testing.T) { func TestFullBuffer(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t)
notifier := NewNotifier[string](1) notifier := NewNotifier[string](1)
notifier.SendAndCreate("1", "a") notifier.AddKey("1")
err := notifier.SendAndCreate("1", "b")
err := notifier.SendAndCreate("1", "a")
require.NoError(err)
err = notifier.SendAndCreate("1", "b")
assert.Error(err) assert.Error(err)
} }
func TestNoAllowedKey(t *testing.T) {
require := require.New(t)
notifier := NewNotifier[string](1)
err := notifier.SendAndCreate("1", "a")
require.Error(err)
}