diff --git a/backend/events.go b/backend/events.go index 520df97..b4b20b7 100644 --- a/backend/events.go +++ b/backend/events.go @@ -3,6 +3,8 @@ package main import ( "context" "database/sql" + "fmt" + "net/http" "os" "screenmark/screenmark/agents" "screenmark/screenmark/models" @@ -94,14 +96,67 @@ func ListenProcessingImageStatus(db *sql.DB, notifier *Notifier[string]) { for { select { case data := <-listener.Notify: - stringUuid := data.Extra[0:36] + imageStringUuid := data.Extra[0:36] status := data.Extra[36:] - logger.Info("Update", "id", stringUuid, "status", status) + imageUuid, err := uuid.Parse(imageStringUuid) + if err != nil { + logger.Error(err) + continue + } - if err := notifier.SendAndCreate(stringUuid, status); err != nil { + userId, err := models.GetUserId(db, context.Background(), imageUuid) + if err != nil { + logger.Error("GetUserID failed", "err", err) + continue + } + + logger.Info("Update", "id", imageStringUuid, "status", status) + + if err := notifier.SendAndCreate(userId.String(), status); err != nil { logger.Error(err) } } } } + +func CreateEventsHandler(notifier *Notifier[string]) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userId := r.Context().Value(USER_ID).(uuid.UUID) + if userId == uuid.Nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.(http.Flusher).Flush() + + notifier.AddKey(userId.String()) + defer notifier.RemoveKey(userId.String()) + + if err := notifier.Create(userId.String()); err != nil { + // TODO: this could be better. + // EG: The user could attempt to create many connections + // and they just get a 500, with no explanation. + w.WriteHeader(http.StatusInternalServerError) + return + } + + listener := notifier.Listeners[userId.String()] + + for { + select { + case <-r.Context().Done(): + fmt.Fprint(w, "event: close\ndata: Connection closed\n\n") + w.(http.Flusher).Flush() + return + case msg := <-listener: + fmt.Printf("Sending msg %s\n", msg) + fmt.Fprintf(w, "event: data\ndata: %s\n\n", msg) + w.(http.Flusher).Flush() + } + } + } +} diff --git a/backend/main.go b/backend/main.go index 87f50c6..6c5da7d 100644 --- a/backend/main.go +++ b/backend/main.go @@ -267,41 +267,10 @@ func main() { }) - r.Get("/image-events/{id}", func(w http.ResponseWriter, r *http.Request) { - // TODO: authentication :) - // - // id := r.PathValue("id") - // - // // TODO: get the current status of the image and send it across. - // ctx, cancel := context.WithCancel(r.Context()) - // - // w.Header().Set("Content-Type", "text/event-stream") - // w.Header().Set("Cache-Control", "no-cache") - // w.Header().Set("Connection", "keep-alive") - // w.(http.Flusher).Flush() - // - // for { - // select { - // case <-ctx.Done(): - // fmt.Fprint(w, "event: close\ndata: Connection closed\n\n") - // w.(http.Flusher).Flush() - // cancel() - // return - // case data := <-imageNotifier: - // if data == "" { - // cancel() - // continue - // } - // - // fmt.Printf("Status received: %s\n", data) - // fmt.Fprintf(w, "event: data\ndata: %s\n\n", data) - // w.(http.Flusher).Flush() - // - // if data == "complete" { - // cancel() - // } - // } - // } + r.Route("/notifications", func(r chi.Router) { + r.Use(GetUserIdFromUrl) + + r.Get("/", CreateEventsHandler(¬ifier)) }) r.Post("/login", func(w http.ResponseWriter, r *http.Request) { diff --git a/backend/middleware.go b/backend/middleware.go index 84227e2..9f9e00c 100644 --- a/backend/middleware.go +++ b/backend/middleware.go @@ -37,3 +37,25 @@ func ProtectedRoute(next http.Handler) http.Handler { next.ServeHTTP(w, newR) }) } + +func GetUserIdFromUrl(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("token") + + if len(token) == 0 { + w.WriteHeader(http.StatusUnauthorized) + return + } + + userId, err := GetUserIdFromAccess(token) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) + + newR := r.WithContext(contextWithUserId) + next.ServeHTTP(w, newR) + }) +} diff --git a/backend/models/image.go b/backend/models/image.go index ae0697d..8ca72bb 100644 --- a/backend/models/image.go +++ b/backend/models/image.go @@ -173,6 +173,17 @@ func (m ImageModel) IsUserAuthorized(ctx context.Context, imageId uuid.UUID, use return err != nil && userImage.UserID.String() == userId.String() } +func GetUserId(dbPool *sql.DB, ctx context.Context, imageId uuid.UUID) (uuid.UUID, error) { + getUserIdStmt := UserImagesToProcess. + SELECT(UserImagesToProcess.UserID). + WHERE(UserImagesToProcess.ID.EQ(UUID(imageId))) + + userImage := model.UserImagesToProcess{} + err := getUserIdStmt.QueryContext(ctx, dbPool, &userImage) + + return userImage.UserID, err +} + func NewImageModel(db *sql.DB) ImageModel { return ImageModel{dbPool: db} }