From d97593d487ca3b2bf50696d5e785d5533893efc9 Mon Sep 17 00:00:00 2001 From: John Costa Date: Sat, 30 Aug 2025 10:45:53 +0100 Subject: [PATCH] feat: adding limits --- backend/images/handler.go | 19 ++++++----- backend/limits/limits.go | 61 ++++++++++++++++++++++++++++++++++++ backend/middleware/limits.go | 36 +++++++++++++++++++++ backend/router.go | 7 +++-- backend/stacks/handler.go | 11 ++++--- 5 files changed, 120 insertions(+), 14 deletions(-) create mode 100644 backend/limits/limits.go create mode 100644 backend/middleware/limits.go diff --git a/backend/images/handler.go b/backend/images/handler.go index a35b289..35fa794 100644 --- a/backend/images/handler.go +++ b/backend/images/handler.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/limits" "screenmark/screenmark/middleware" "screenmark/screenmark/models" @@ -18,9 +19,10 @@ import ( ) type ImageHandler struct { - logger *log.Logger - imageModel models.ImageModel - userModel models.UserModel + logger *log.Logger + imageModel models.ImageModel + userModel models.UserModel + limitsManager limits.LimitsManagerMethods } type ImagesReturn struct { @@ -153,18 +155,19 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) { r.Use(middleware.SetJson) r.Get("/", h.listImages) - r.Post("/{name}", h.uploadImage) + r.Post("/{name}", middleware.WithLimit(h.logger, h.limitsManager.HasReachedImageLimit, h.uploadImage)) }) } -func CreateImageHandler(db *sql.DB) ImageHandler { +func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) ImageHandler { imageModel := models.NewImageModel(db) userModel := models.NewUserModel(db) logger := log.New(os.Stdout).WithPrefix("Images") return ImageHandler{ - logger: logger, - imageModel: imageModel, - userModel: userModel, + logger: logger, + imageModel: imageModel, + userModel: userModel, + limitsManager: limitsManager, } } diff --git a/backend/limits/limits.go b/backend/limits/limits.go new file mode 100644 index 0000000..3190d1e --- /dev/null +++ b/backend/limits/limits.go @@ -0,0 +1,61 @@ +package limits + +import ( + "database/sql" + + . "screenmark/screenmark/.gen/haystack/haystack/table" + + . "github.com/go-jet/jet/v2/postgres" + + "github.com/google/uuid" +) + +const ( + LISTS_LIMIT = 10 + IMAGE_LIMIT = 50 +) + +type LimitsManager struct { + dbPool *sql.DB +} + +type LimitsManagerMethods interface { + HasReachedStackLimit(userID uuid.UUID) (bool, error) + HasReachedImageLimit(userID uuid.UUID) (bool, error) +} + +type listCount struct { + ListCount int `alias:"list_count"` +} + +func (m *LimitsManager) HasReachedStackLimit(userID uuid.UUID) (bool, error) { + getStacks := Lists. + SELECT(COUNT(Lists.UserID).AS("listCount.ListCount")). + WHERE(Lists.UserID.EQ(UUID(userID))) + + var count listCount + err := getStacks.Query(m.dbPool, &count) + + return count.ListCount >= LISTS_LIMIT, err +} + +type imageCount struct { + ImageCount int `alias:"image_count"` +} + +func (m *LimitsManager) HasReachedImageLimit(userID uuid.UUID) (bool, error) { + getStacks := UserImages. + SELECT(COUNT(UserImages.UserID).AS("imageCount.ImageCount")). + WHERE(UserImages.UserID.EQ(UUID(userID))) + + var count imageCount + err := getStacks.Query(m.dbPool, &count) + + return count.ImageCount >= IMAGE_LIMIT, err +} + +func CreateLimitsManager(db *sql.DB) *LimitsManager { + return &LimitsManager{ + db, + } +} diff --git a/backend/middleware/limits.go b/backend/middleware/limits.go new file mode 100644 index 0000000..ce41770 --- /dev/null +++ b/backend/middleware/limits.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "net/http" + + "github.com/charmbracelet/log" + "github.com/google/uuid" +) + +func WithLimit(logger *log.Logger, getLimit func(userID uuid.UUID) (bool, error), next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + userID, err := GetUserID(ctx, logger, w) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + hasReachedLimit, err := getLimit(userID) + if err != nil { + logger.Error("failed to image limit", "err", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + logger.Info("Limits", "hasReachedLimit", hasReachedLimit) + + if hasReachedLimit { + w.WriteHeader(http.StatusTooManyRequests) + return + } + + next(w, r) + } +} diff --git a/backend/router.go b/backend/router.go index 7e233e8..00f1dee 100644 --- a/backend/router.go +++ b/backend/router.go @@ -6,6 +6,7 @@ import ( "screenmark/screenmark/agents/client" "screenmark/screenmark/auth" "screenmark/screenmark/images" + "screenmark/screenmark/limits" "screenmark/screenmark/models" "screenmark/screenmark/stacks" @@ -27,9 +28,11 @@ func setupRouter(db *sql.DB) chi.Router { imageModel := models.NewImageModel(db) stackModel := models.NewListModel(db) - stackHandler := stacks.CreateStackHandler(db) + limitsManager := limits.CreateLimitsManager(db) + + stackHandler := stacks.CreateStackHandler(db, limitsManager) authHandler := auth.CreateAuthHandler(db) - imageHandler := images.CreateImageHandler(db) + imageHandler := images.CreateImageHandler(db, limitsManager) notifier := NewNotifier[Notification](10) diff --git a/backend/stacks/handler.go b/backend/stacks/handler.go index 736db8e..7e350e1 100644 --- a/backend/stacks/handler.go +++ b/backend/stacks/handler.go @@ -6,6 +6,7 @@ import ( "net/http" "os" . "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/limits" "screenmark/screenmark/middleware" "screenmark/screenmark/models" "strings" @@ -15,8 +16,9 @@ import ( ) type StackHandler struct { - logger *log.Logger - stackModel models.ListModel + logger *log.Logger + stackModel models.ListModel + limitsManager limits.LimitsManagerMethods } func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { @@ -144,18 +146,19 @@ func (h *StackHandler) CreateRoutes(r chi.Router) { r.Get("/", h.getAllStacks) r.Get("/{listID}", h.getStackItems) - r.Post("/", middleware.WithValidatedPost(h.createStack)) + r.Post("/", middleware.WithLimit(h.logger, h.limitsManager.HasReachedStackLimit, middleware.WithValidatedPost(h.createStack))) r.Patch("/{listID}", middleware.WithValidatedPost(h.editStack)) r.Delete("/{listID}", h.deleteStack) }) } -func CreateStackHandler(db *sql.DB) StackHandler { +func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) StackHandler { stackModel := models.NewListModel(db) logger := log.New(os.Stdout).WithPrefix("Stacks") return StackHandler{ logger, stackModel, + limitsManager, } }