fix: notification system

This commit is contained in:
2025-10-05 12:10:06 +01:00
parent 649cfe0b02
commit 980b42aa44
12 changed files with 191 additions and 124 deletions

View File

@ -5,77 +5,22 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"screenmark/screenmark/middleware" "screenmark/screenmark/middleware"
"screenmark/screenmark/notifications"
"strconv" "strconv"
"github.com/google/uuid" "github.com/google/uuid"
) )
const (
IMAGE_TYPE = "image"
LIST_TYPE = "list"
)
type imageNotification struct {
Type string
ImageID uuid.UUID
ImageName string
Status string
}
type listNotification struct {
Type string
ListID uuid.UUID
Name string
Status string
}
type Notification struct {
image *imageNotification
list *listNotification
}
func getImageNotification(image imageNotification) Notification {
return Notification{
image: &image,
}
}
func getListNotification(list listNotification) Notification {
return Notification{
list: &list,
}
}
func (n Notification) MarshalJSON() ([]byte, error) {
if n.image != nil {
return json.Marshal(n.image)
}
if n.list != nil {
return json.Marshal(n.list)
}
return nil, fmt.Errorf("no image or list present")
}
func (n *Notification) UnmarshalJSON(data []byte) error {
return fmt.Errorf("unimplemented")
}
/* /*
* TODO: We have channels open every a user sends an image. * TODO: We have channels open every a user sends an image.
* We never close these channels. * We never close these channels.
* *
* What is a reasonable default? Close the channel after 1 minute of inactivity? * What is a reasonable default? Close the channel after 1 minute of inactivity?
*/ */
func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc { func CreateEventsHandler(notifier *notifications.Notifier[notifications.Notification]) http.HandlerFunc {
counter := 0 counter := 0
userSplitters := make(map[string]*ChannelSplitter[Notification]) userSplitters := make(map[string]*notifications.ChannelSplitter[notifications.Notification])
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
_userId := r.Context().Value(middleware.USER_ID).(uuid.UUID) _userId := r.Context().Value(middleware.USER_ID).(uuid.UUID)
@ -98,7 +43,7 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc {
userNotifications := notifier.Listeners[userId] userNotifications := notifier.Listeners[userId]
if _, exists := userSplitters[userId]; !exists { if _, exists := userSplitters[userId]; !exists {
splitter := NewChannelSplitter(userNotifications) splitter := notifications.NewChannelSplitter(userNotifications)
userSplitters[userId] = &splitter userSplitters[userId] = &splitter
splitter.Listen() splitter.Listen()

View File

@ -154,7 +154,11 @@ func (h *ImageHandler) uploadImage(w http.ResponseWriter, r *http.Request) {
h.logger.Info("About to add image") h.logger.Info("About to add image")
h.processor.Add(newImage) h.processor.Add(newImage)
w.WriteHeader(http.StatusOK) // We nullify the image's data, so we're not transferring all that
// data back to the frontend.
newImage.Image = nil
middleware.WriteJsonOrError(h.logger, newImage, w)
} }
func (h *ImageHandler) deleteImage(w http.ResponseWriter, r *http.Request) { func (h *ImageHandler) deleteImage(w http.ResponseWriter, r *http.Request) {

View File

@ -181,7 +181,11 @@ func setupTestContext(t *testing.T) *TestContext {
} }
jwtManager := middleware.NewJwtManager([]byte("test-jwt-secret")) jwtManager := middleware.NewJwtManager([]byte("test-jwt-secret"))
router := setupRouter(db, jwtManager) router, err := setupRouter(db, jwtManager)
if err != nil {
panic(err)
}
server := httptest.NewServer(router) server := httptest.NewServer(router)
tc.db = db tc.db = db

View File

@ -28,7 +28,10 @@ func main() {
panic(err) panic(err)
} }
router := setupRouter(db, jwtManager) router, err := setupRouter(db, jwtManager)
if err != nil {
panic(err)
}
port, exists := os.LookupEnv("PORT") port, exists := os.LookupEnv("PORT")
if !exists { if !exists {

View File

@ -0,0 +1,38 @@
package notifications
type ChannelSplitter[TNotification any] struct {
ch chan TNotification
Listeners map[string]chan TNotification
}
func (s *ChannelSplitter[TNotification]) Listen() {
go func() {
for {
select {
case msg := <-s.ch:
for _, v := range s.Listeners {
v <- msg
}
}
}
}()
}
func (s *ChannelSplitter[TNotification]) Add(id string) chan TNotification {
ch := make(chan TNotification)
s.Listeners[id] = ch
return ch
}
func (s *ChannelSplitter[TNotification]) Remove(id string) {
delete(s.Listeners, id)
}
func NewChannelSplitter[TNotification any](ch chan TNotification) ChannelSplitter[TNotification] {
return ChannelSplitter[TNotification]{
ch: ch,
Listeners: make(map[string]chan TNotification),
}
}

View File

@ -0,0 +1,64 @@
package notifications
import (
"encoding/json"
"fmt"
"github.com/google/uuid"
)
const (
IMAGE_TYPE = "image"
LIST_TYPE = "list"
)
type ImageNotification struct {
Type string
ImageID uuid.UUID
ImageName string
Status string
}
type ListNotification struct {
Type string
ListID uuid.UUID
Name string
Status string
}
type Notification struct {
image *ImageNotification
list *ListNotification
}
func GetImageNotification(image ImageNotification) Notification {
return Notification{
image: &image,
}
}
func GetListNotification(list ListNotification) Notification {
return Notification{
list: &list,
}
}
func (n Notification) MarshalJSON() ([]byte, error) {
if n.image != nil {
return json.Marshal(n.image)
}
if n.list != nil {
return json.Marshal(n.list)
}
return nil, fmt.Errorf("no image or list present")
}
func (n *Notification) UnmarshalJSON(data []byte) error {
return fmt.Errorf("unimplemented")
}

View File

@ -1,4 +1,4 @@
package main package notifications
import ( import (
"errors" "errors"
@ -56,42 +56,3 @@ func NewNotifier[TNotification any](bufferSize int) Notifier[TNotification] {
Listeners: make(map[string]chan TNotification), Listeners: make(map[string]chan TNotification),
} }
} }
// ----------------------------------
type ChannelSplitter[TNotification any] struct {
ch chan TNotification
Listeners map[string]chan TNotification
}
func (s *ChannelSplitter[TNotification]) Listen() {
go func() {
for {
select {
case msg := <-s.ch:
for _, v := range s.Listeners {
v <- msg
}
}
}
}()
}
func (s *ChannelSplitter[TNotification]) Add(id string) chan TNotification {
ch := make(chan TNotification)
s.Listeners[id] = ch
return ch
}
func (s *ChannelSplitter[TNotification]) Remove(id string) {
delete(s.Listeners, id)
}
func NewChannelSplitter[TNotification any](ch chan TNotification) ChannelSplitter[TNotification] {
return ChannelSplitter[TNotification]{
ch: ch,
Listeners: make(map[string]chan TNotification),
}
}

View File

@ -1,4 +1,4 @@
package main package notifications
import ( import (
"testing" "testing"

View File

@ -2,11 +2,13 @@ package processor
import ( import (
"context" "context"
"fmt"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents" "screenmark/screenmark/agents"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/limits" "screenmark/screenmark/limits"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"screenmark/screenmark/notifications"
"sync" "sync"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
@ -24,6 +26,8 @@ type ImageProcessor struct {
// TODO: add the notifier here // TODO: add the notifier here
Processor *Processor[model.Image] Processor *Processor[model.Image]
notifier *notifications.Notifier[notifications.Notification]
} }
func (p *ImageProcessor) setImageToProcess(ctx context.Context, image model.Image) { func (p *ImageProcessor) setImageToProcess(ctx context.Context, image model.Image) {
@ -71,6 +75,19 @@ func (p *ImageProcessor) processImage(image model.Image) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
imageNotification := notifications.GetImageNotification(notifications.ImageNotification{
Type: notifications.IMAGE_TYPE,
ImageID: image.ID,
ImageName: image.ImageName,
Status: string(model.Progress_InProgress),
})
err := p.notifier.SendAndCreate(image.UserID.String(), imageNotification)
if err != nil {
p.logger.Error("sending in progress notification", "err", err)
return
}
go func() { go func() {
p.describe(ctx, image) p.describe(ctx, image)
wg.Done() wg.Done()
@ -82,9 +99,34 @@ func (p *ImageProcessor) processImage(image model.Image) {
}() }()
wg.Wait() wg.Wait()
// TODO: there is some repeated code here. The ergonomicts of the notifications,
// isn't the best.
imageNotification = notifications.GetImageNotification(notifications.ImageNotification{
Type: notifications.IMAGE_TYPE,
ImageID: image.ID,
ImageName: image.ImageName,
Status: string(model.Progress_Complete),
})
err = p.notifier.SendAndCreate(image.UserID.String(), imageNotification)
if err != nil {
p.logger.Error("sending done notification", "err", err)
return
}
}
func NewImageProcessor(
logger *log.Logger,
imageModel models.ImageModel,
listModel models.StackModel,
limitsManager limits.LimitsManagerMethods,
notifier *notifications.Notifier[notifications.Notification],
) (ImageProcessor, error) {
if notifier == nil {
return ImageProcessor{}, fmt.Errorf("notifier is nil")
} }
func NewImageProcessor(logger *log.Logger, imageModel models.ImageModel, listModel models.StackModel, limitsManager limits.LimitsManagerMethods) ImageProcessor {
descriptionAgent := agents.NewDescriptionAgent(logger, imageModel) descriptionAgent := agents.NewDescriptionAgent(logger, imageModel)
stackAgent := agents.NewListAgent(logger, listModel, limitsManager) stackAgent := agents.NewListAgent(logger, listModel, limitsManager)
@ -93,9 +135,11 @@ func NewImageProcessor(logger *log.Logger, imageModel models.ImageModel, listMod
logger: logger, logger: logger,
descriptionAgent: descriptionAgent, descriptionAgent: descriptionAgent,
stackAgent: stackAgent, stackAgent: stackAgent,
notifier: notifier,
} }
imageProcessor.Processor = NewProcessor(int(IMAGE_PROCESS_AT_A_TIME), imageProcessor.processImage) imageProcessor.Processor = NewProcessor(int(IMAGE_PROCESS_AT_A_TIME), imageProcessor.processImage)
return imageProcessor return imageProcessor, nil
} }

View File

@ -2,12 +2,14 @@ package main
import ( import (
"database/sql" "database/sql"
"fmt"
"os" "os"
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/auth" "screenmark/screenmark/auth"
"screenmark/screenmark/images" "screenmark/screenmark/images"
"screenmark/screenmark/limits" "screenmark/screenmark/limits"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"screenmark/screenmark/notifications"
"screenmark/screenmark/processor" "screenmark/screenmark/processor"
"screenmark/screenmark/stacks" "screenmark/screenmark/stacks"
@ -25,22 +27,26 @@ func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (cli
return client.ImageInfo, nil return client.ImageInfo, nil
} }
func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) chi.Router { func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) (chi.Router, error) {
limitsManager := limits.CreateLimitsManager(db) limitsManager := limits.CreateLimitsManager(db)
imageModel := models.NewImageModel(db) imageModel := models.NewImageModel(db)
stackModel := models.NewStackModel(db) stackModel := models.NewStackModel(db)
notifier := notifications.NewNotifier[notifications.Notification](10)
imageProcessorLogger := createLogger("Image Processor", os.Stdout) imageProcessorLogger := createLogger("Image Processor", os.Stdout)
imageProcessor := processor.NewImageProcessor(imageProcessorLogger, imageModel, stackModel, limitsManager) imageProcessor, err := processor.NewImageProcessor(imageProcessorLogger, imageModel, stackModel, limitsManager, &notifier)
if err != nil {
return nil, fmt.Errorf("processor: %w", err)
}
go imageProcessor.Processor.Work() go imageProcessor.Processor.Work()
stackHandler := stacks.CreateStackHandler(db, limitsManager, jwtManager) stackHandler := stacks.CreateStackHandler(db, limitsManager, jwtManager)
authHandler := auth.CreateAuthHandler(db, jwtManager) authHandler := auth.CreateAuthHandler(db, jwtManager)
imageHandler := images.CreateImageHandler(db, limitsManager, jwtManager, imageProcessor.Processor) imageHandler := images.CreateImageHandler(db, limitsManager, jwtManager, imageProcessor.Processor)
notifier := NewNotifier[Notification](10)
r := chi.NewRouter() r := chi.NewRouter()
r.Use(middleware.Logger) r.Use(middleware.Logger)
@ -56,5 +62,5 @@ func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) chi.Router {
r.Get("/", CreateEventsHandler(&notifier)) r.Get("/", CreateEventsHandler(&notifier))
}) })
return r return r, nil
} }

View File

@ -39,6 +39,8 @@ export const Notifications = (onCompleteImage: () => void) => {
const [accessToken] = createResource(getAccessToken); const [accessToken] = createResource(getAccessToken);
const dataEventListener = (e: MessageEvent<unknown>) => { const dataEventListener = (e: MessageEvent<unknown>) => {
debugger;
if (typeof e.data !== "string") { if (typeof e.data !== "string") {
console.error("Error type is not string"); console.error("Error type is not string");
return; return;
@ -98,7 +100,7 @@ export const Notifications = (onCompleteImage: () => void) => {
upsertImageProcessing( upsertImageProcessing(
Object.fromEntries( Object.fromEntries(
images.filter(i => i.Status !== 'complete').map((i) => [ images.filter(i => i.Status === 'complete').map((i) => [
i.ID, i.ID,
{ {
Type: "image", Type: "image",

View File

@ -78,17 +78,10 @@ const getBaseAuthorizedRequest = async ({
method, method,
}); });
}; };
const sendImageResponseValidator = strictObject({
ID: pipe(string(), uuid()),
ImageID: pipe(string(), uuid()),
UserID: pipe(string(), uuid()),
Status: string(),
});
export const sendImageFile = async ( export const sendImageFile = async (
imageName: string, imageName: string,
file: File, file: File,
): Promise<InferOutput<typeof sendImageResponseValidator>> => { ): Promise<InferOutput<typeof imageValidator>> => {
const request = await getBaseAuthorizedRequest({ const request = await getBaseAuthorizedRequest({
path: `images/${imageName}`, path: `images/${imageName}`,
body: file, body: file,
@ -98,7 +91,7 @@ export const sendImageFile = async (
request.headers.set("Content-Type", "application/oclet-stream"); request.headers.set("Content-Type", "application/oclet-stream");
const res = await fetch(request).then((res) => res.json()); const res = await fetch(request).then((res) => res.json());
const parsedRes = safeParse(sendImageResponseValidator, res); const parsedRes = safeParse(imageValidator, res);
if (!parsedRes.success) { if (!parsedRes.success) {
console.log(parsedRes.issues) console.log(parsedRes.issues)
@ -146,7 +139,7 @@ export class ImageLimitReached extends Error {
export const sendImage = async ( export const sendImage = async (
imageName: string, imageName: string,
base64Image: string, base64Image: string,
): Promise<InferOutput<typeof sendImageResponseValidator>> => { ): Promise<InferOutput<typeof imageValidator>> => {
const request = await getBaseAuthorizedRequest({ const request = await getBaseAuthorizedRequest({
path: `images/${imageName}`, path: `images/${imageName}`,
body: base64Image, body: base64Image,
@ -162,16 +155,16 @@ export const sendImage = async (
const res = await rawRes.json(); const res = await rawRes.json();
const parsedRes = safeParse(sendImageResponseValidator, res); const parsedRes = safeParse(imageValidator, res);
if (!parsedRes.success) { if (!parsedRes.success) {
console.log(parsedRes.issues) console.log("Parsing issues: ", parsedRes.issues)
throw new Error(JSON.stringify(parsedRes.issues)); throw new Error(JSON.stringify(parsedRes.issues));
} }
return parsedRes.output; return parsedRes.output;
}; };
const userImageValidator = strictObject({ const imageValidator = strictObject({
ID: pipe(string(), uuid()), ID: pipe(string(), uuid()),
CreatedAt: string(), CreatedAt: string(),
UserID: pipe(string(), uuid()), UserID: pipe(string(), uuid()),
@ -181,7 +174,10 @@ const userImageValidator = strictObject({
ImageName: string(), ImageName: string(),
Status: union([literal('not-started'), literal('in-progress'), literal('complete')]), Status: union([literal('not-started'), literal('in-progress'), literal('complete')]),
})
const userImageValidator = strictObject({
...imageValidator.entries,
ImageStacks: pipe(nullable(array( ImageStacks: pipe(nullable(array(
strictObject({ strictObject({
ID: pipe(string(), uuid()), ID: pipe(string(), uuid()),