feat: adding limits

This commit is contained in:
2025-08-30 10:45:53 +01:00
parent de96f12b55
commit d97593d487
5 changed files with 120 additions and 14 deletions

View File

@ -10,6 +10,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/limits"
"screenmark/screenmark/middleware" "screenmark/screenmark/middleware"
"screenmark/screenmark/models" "screenmark/screenmark/models"
@ -18,9 +19,10 @@ import (
) )
type ImageHandler struct { type ImageHandler struct {
logger *log.Logger logger *log.Logger
imageModel models.ImageModel imageModel models.ImageModel
userModel models.UserModel userModel models.UserModel
limitsManager limits.LimitsManagerMethods
} }
type ImagesReturn struct { type ImagesReturn struct {
@ -153,18 +155,19 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) {
r.Use(middleware.SetJson) r.Use(middleware.SetJson)
r.Get("/", h.listImages) r.Get("/", h.listImages)
r.Post("/{name}", h.uploadImage) r.Post("/{name}", middleware.WithLimit(h.logger, h.limitsManager.HasReachedImageLimit, h.uploadImage))
}) })
} }
func CreateImageHandler(db *sql.DB) ImageHandler { func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) ImageHandler {
imageModel := models.NewImageModel(db) imageModel := models.NewImageModel(db)
userModel := models.NewUserModel(db) userModel := models.NewUserModel(db)
logger := log.New(os.Stdout).WithPrefix("Images") logger := log.New(os.Stdout).WithPrefix("Images")
return ImageHandler{ return ImageHandler{
logger: logger, logger: logger,
imageModel: imageModel, imageModel: imageModel,
userModel: userModel, userModel: userModel,
limitsManager: limitsManager,
} }
} }

61
backend/limits/limits.go Normal file
View File

@ -0,0 +1,61 @@
package limits
import (
"database/sql"
. "screenmark/screenmark/.gen/haystack/haystack/table"
. "github.com/go-jet/jet/v2/postgres"
"github.com/google/uuid"
)
const (
LISTS_LIMIT = 10
IMAGE_LIMIT = 50
)
type LimitsManager struct {
dbPool *sql.DB
}
type LimitsManagerMethods interface {
HasReachedStackLimit(userID uuid.UUID) (bool, error)
HasReachedImageLimit(userID uuid.UUID) (bool, error)
}
type listCount struct {
ListCount int `alias:"list_count"`
}
func (m *LimitsManager) HasReachedStackLimit(userID uuid.UUID) (bool, error) {
getStacks := Lists.
SELECT(COUNT(Lists.UserID).AS("listCount.ListCount")).
WHERE(Lists.UserID.EQ(UUID(userID)))
var count listCount
err := getStacks.Query(m.dbPool, &count)
return count.ListCount >= LISTS_LIMIT, err
}
type imageCount struct {
ImageCount int `alias:"image_count"`
}
func (m *LimitsManager) HasReachedImageLimit(userID uuid.UUID) (bool, error) {
getStacks := UserImages.
SELECT(COUNT(UserImages.UserID).AS("imageCount.ImageCount")).
WHERE(UserImages.UserID.EQ(UUID(userID)))
var count imageCount
err := getStacks.Query(m.dbPool, &count)
return count.ImageCount >= IMAGE_LIMIT, err
}
func CreateLimitsManager(db *sql.DB) *LimitsManager {
return &LimitsManager{
db,
}
}

View File

@ -0,0 +1,36 @@
package middleware
import (
"net/http"
"github.com/charmbracelet/log"
"github.com/google/uuid"
)
func WithLimit(logger *log.Logger, getLimit func(userID uuid.UUID) (bool, error), next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userID, err := GetUserID(ctx, logger, w)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
hasReachedLimit, err := getLimit(userID)
if err != nil {
logger.Error("failed to image limit", "err", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
logger.Info("Limits", "hasReachedLimit", hasReachedLimit)
if hasReachedLimit {
w.WriteHeader(http.StatusTooManyRequests)
return
}
next(w, r)
}
}

View File

@ -6,6 +6,7 @@ import (
"screenmark/screenmark/agents/client" "screenmark/screenmark/agents/client"
"screenmark/screenmark/auth" "screenmark/screenmark/auth"
"screenmark/screenmark/images" "screenmark/screenmark/images"
"screenmark/screenmark/limits"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"screenmark/screenmark/stacks" "screenmark/screenmark/stacks"
@ -27,9 +28,11 @@ func setupRouter(db *sql.DB) chi.Router {
imageModel := models.NewImageModel(db) imageModel := models.NewImageModel(db)
stackModel := models.NewListModel(db) stackModel := models.NewListModel(db)
stackHandler := stacks.CreateStackHandler(db) limitsManager := limits.CreateLimitsManager(db)
stackHandler := stacks.CreateStackHandler(db, limitsManager)
authHandler := auth.CreateAuthHandler(db) authHandler := auth.CreateAuthHandler(db)
imageHandler := images.CreateImageHandler(db) imageHandler := images.CreateImageHandler(db, limitsManager)
notifier := NewNotifier[Notification](10) notifier := NewNotifier[Notification](10)

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"os" "os"
. "screenmark/screenmark/.gen/haystack/haystack/model" . "screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/limits"
"screenmark/screenmark/middleware" "screenmark/screenmark/middleware"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"strings" "strings"
@ -15,8 +16,9 @@ import (
) )
type StackHandler struct { type StackHandler struct {
logger *log.Logger logger *log.Logger
stackModel models.ListModel stackModel models.ListModel
limitsManager limits.LimitsManagerMethods
} }
func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
@ -144,18 +146,19 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
r.Get("/", h.getAllStacks) r.Get("/", h.getAllStacks)
r.Get("/{listID}", h.getStackItems) r.Get("/{listID}", h.getStackItems)
r.Post("/", middleware.WithValidatedPost(h.createStack)) r.Post("/", middleware.WithLimit(h.logger, h.limitsManager.HasReachedStackLimit, middleware.WithValidatedPost(h.createStack)))
r.Patch("/{listID}", middleware.WithValidatedPost(h.editStack)) r.Patch("/{listID}", middleware.WithValidatedPost(h.editStack))
r.Delete("/{listID}", h.deleteStack) r.Delete("/{listID}", h.deleteStack)
}) })
} }
func CreateStackHandler(db *sql.DB) StackHandler { func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) StackHandler {
stackModel := models.NewListModel(db) stackModel := models.NewListModel(db)
logger := log.New(os.Stdout).WithPrefix("Stacks") logger := log.New(os.Stdout).WithPrefix("Stacks")
return StackHandler{ return StackHandler{
logger, logger,
stackModel, stackModel,
limitsManager,
} }
} }