feat: adding limits
This commit is contained in:
@ -10,6 +10,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"screenmark/screenmark/.gen/haystack/haystack/model"
|
"screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
|
"screenmark/screenmark/limits"
|
||||||
"screenmark/screenmark/middleware"
|
"screenmark/screenmark/middleware"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
|
|
||||||
@ -21,6 +22,7 @@ type ImageHandler struct {
|
|||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
imageModel models.ImageModel
|
imageModel models.ImageModel
|
||||||
userModel models.UserModel
|
userModel models.UserModel
|
||||||
|
limitsManager limits.LimitsManagerMethods
|
||||||
}
|
}
|
||||||
|
|
||||||
type ImagesReturn struct {
|
type ImagesReturn struct {
|
||||||
@ -153,11 +155,11 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) {
|
|||||||
r.Use(middleware.SetJson)
|
r.Use(middleware.SetJson)
|
||||||
|
|
||||||
r.Get("/", h.listImages)
|
r.Get("/", h.listImages)
|
||||||
r.Post("/{name}", h.uploadImage)
|
r.Post("/{name}", middleware.WithLimit(h.logger, h.limitsManager.HasReachedImageLimit, h.uploadImage))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateImageHandler(db *sql.DB) ImageHandler {
|
func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) ImageHandler {
|
||||||
imageModel := models.NewImageModel(db)
|
imageModel := models.NewImageModel(db)
|
||||||
userModel := models.NewUserModel(db)
|
userModel := models.NewUserModel(db)
|
||||||
logger := log.New(os.Stdout).WithPrefix("Images")
|
logger := log.New(os.Stdout).WithPrefix("Images")
|
||||||
@ -166,5 +168,6 @@ func CreateImageHandler(db *sql.DB) ImageHandler {
|
|||||||
logger: logger,
|
logger: logger,
|
||||||
imageModel: imageModel,
|
imageModel: imageModel,
|
||||||
userModel: userModel,
|
userModel: userModel,
|
||||||
|
limitsManager: limitsManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
61
backend/limits/limits.go
Normal file
61
backend/limits/limits.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package limits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
. "screenmark/screenmark/.gen/haystack/haystack/table"
|
||||||
|
|
||||||
|
. "github.com/go-jet/jet/v2/postgres"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
LISTS_LIMIT = 10
|
||||||
|
IMAGE_LIMIT = 50
|
||||||
|
)
|
||||||
|
|
||||||
|
type LimitsManager struct {
|
||||||
|
dbPool *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
type LimitsManagerMethods interface {
|
||||||
|
HasReachedStackLimit(userID uuid.UUID) (bool, error)
|
||||||
|
HasReachedImageLimit(userID uuid.UUID) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type listCount struct {
|
||||||
|
ListCount int `alias:"list_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *LimitsManager) HasReachedStackLimit(userID uuid.UUID) (bool, error) {
|
||||||
|
getStacks := Lists.
|
||||||
|
SELECT(COUNT(Lists.UserID).AS("listCount.ListCount")).
|
||||||
|
WHERE(Lists.UserID.EQ(UUID(userID)))
|
||||||
|
|
||||||
|
var count listCount
|
||||||
|
err := getStacks.Query(m.dbPool, &count)
|
||||||
|
|
||||||
|
return count.ListCount >= LISTS_LIMIT, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageCount struct {
|
||||||
|
ImageCount int `alias:"image_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *LimitsManager) HasReachedImageLimit(userID uuid.UUID) (bool, error) {
|
||||||
|
getStacks := UserImages.
|
||||||
|
SELECT(COUNT(UserImages.UserID).AS("imageCount.ImageCount")).
|
||||||
|
WHERE(UserImages.UserID.EQ(UUID(userID)))
|
||||||
|
|
||||||
|
var count imageCount
|
||||||
|
err := getStacks.Query(m.dbPool, &count)
|
||||||
|
|
||||||
|
return count.ImageCount >= IMAGE_LIMIT, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateLimitsManager(db *sql.DB) *LimitsManager {
|
||||||
|
return &LimitsManager{
|
||||||
|
db,
|
||||||
|
}
|
||||||
|
}
|
36
backend/middleware/limits.go
Normal file
36
backend/middleware/limits.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/log"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WithLimit(logger *log.Logger, getLimit func(userID uuid.UUID) (bool, error), next func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
userID, err := GetUserID(ctx, logger, w)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hasReachedLimit, err := getLimit(userID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to image limit", "err", err)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Limits", "hasReachedLimit", hasReachedLimit)
|
||||||
|
|
||||||
|
if hasReachedLimit {
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(w, r)
|
||||||
|
}
|
||||||
|
}
|
@ -6,6 +6,7 @@ import (
|
|||||||
"screenmark/screenmark/agents/client"
|
"screenmark/screenmark/agents/client"
|
||||||
"screenmark/screenmark/auth"
|
"screenmark/screenmark/auth"
|
||||||
"screenmark/screenmark/images"
|
"screenmark/screenmark/images"
|
||||||
|
"screenmark/screenmark/limits"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"screenmark/screenmark/stacks"
|
"screenmark/screenmark/stacks"
|
||||||
|
|
||||||
@ -27,9 +28,11 @@ func setupRouter(db *sql.DB) chi.Router {
|
|||||||
imageModel := models.NewImageModel(db)
|
imageModel := models.NewImageModel(db)
|
||||||
stackModel := models.NewListModel(db)
|
stackModel := models.NewListModel(db)
|
||||||
|
|
||||||
stackHandler := stacks.CreateStackHandler(db)
|
limitsManager := limits.CreateLimitsManager(db)
|
||||||
|
|
||||||
|
stackHandler := stacks.CreateStackHandler(db, limitsManager)
|
||||||
authHandler := auth.CreateAuthHandler(db)
|
authHandler := auth.CreateAuthHandler(db)
|
||||||
imageHandler := images.CreateImageHandler(db)
|
imageHandler := images.CreateImageHandler(db, limitsManager)
|
||||||
|
|
||||||
notifier := NewNotifier[Notification](10)
|
notifier := NewNotifier[Notification](10)
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
. "screenmark/screenmark/.gen/haystack/haystack/model"
|
. "screenmark/screenmark/.gen/haystack/haystack/model"
|
||||||
|
"screenmark/screenmark/limits"
|
||||||
"screenmark/screenmark/middleware"
|
"screenmark/screenmark/middleware"
|
||||||
"screenmark/screenmark/models"
|
"screenmark/screenmark/models"
|
||||||
"strings"
|
"strings"
|
||||||
@ -17,6 +18,7 @@ import (
|
|||||||
type StackHandler struct {
|
type StackHandler struct {
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
stackModel models.ListModel
|
stackModel models.ListModel
|
||||||
|
limitsManager limits.LimitsManagerMethods
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
|
func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -144,18 +146,19 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
|
|||||||
r.Get("/", h.getAllStacks)
|
r.Get("/", h.getAllStacks)
|
||||||
r.Get("/{listID}", h.getStackItems)
|
r.Get("/{listID}", h.getStackItems)
|
||||||
|
|
||||||
r.Post("/", middleware.WithValidatedPost(h.createStack))
|
r.Post("/", middleware.WithLimit(h.logger, h.limitsManager.HasReachedStackLimit, middleware.WithValidatedPost(h.createStack)))
|
||||||
r.Patch("/{listID}", middleware.WithValidatedPost(h.editStack))
|
r.Patch("/{listID}", middleware.WithValidatedPost(h.editStack))
|
||||||
r.Delete("/{listID}", h.deleteStack)
|
r.Delete("/{listID}", h.deleteStack)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateStackHandler(db *sql.DB) StackHandler {
|
func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) StackHandler {
|
||||||
stackModel := models.NewListModel(db)
|
stackModel := models.NewListModel(db)
|
||||||
logger := log.New(os.Stdout).WithPrefix("Stacks")
|
logger := log.New(os.Stdout).WithPrefix("Stacks")
|
||||||
|
|
||||||
return StackHandler{
|
return StackHandler{
|
||||||
logger,
|
logger,
|
||||||
stackModel,
|
stackModel,
|
||||||
|
limitsManager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user