From f5e65524aa3fbbd326f3aab7bc08b32c4d38d8d1 Mon Sep 17 00:00:00 2001 From: John Costa Date: Tue, 19 Aug 2025 21:49:48 +0100 Subject: [PATCH] improving by extracting common userID method --- backend/stacks/handler.go | 41 +++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/backend/stacks/handler.go b/backend/stacks/handler.go index 6a345ea..29ed0a6 100644 --- a/backend/stacks/handler.go +++ b/backend/stacks/handler.go @@ -30,17 +30,27 @@ type StackHandler struct { stackModel models.ListModel } -func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { +func (h *StackHandler) withUserID( + fn func(userID uuid.UUID, w http.ResponseWriter, r *http.Request), +) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + userID, err := middleware.GetUserID(ctx) + if err != nil { + h.logger.Warn("could not get users in get all stacks", "err", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + fn(userID, w, r) + } +} + +func (h *StackHandler) getAllStacks(userID uuid.UUID, w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userId, err := middleware.GetUserID(ctx) - if err != nil { - h.logger.Warn("could not get users in get all stacks", "err", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - - lists, err := h.stackModel.List(ctx, userId) + lists, err := h.stackModel.List(ctx, userID) if err != nil { h.logger.Warn("could not get stacks", "err", err) w.WriteHeader(http.StatusBadRequest) @@ -50,16 +60,9 @@ func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { writeJsonOrError(h.logger, lists, w) } -func (h *StackHandler) getStackItems(w http.ResponseWriter, r *http.Request) { +func (h *StackHandler) getStackItems(userID uuid.UUID, w http.ResponseWriter, r *http.Request) { ctx := r.Context() - _, err := middleware.GetUserID(ctx) - if err != nil { - h.logger.Warn("could not get users in get all stacks", "err", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - listID := r.PathValue("listID") if len(listID) == 0 { h.logger.Warn("listID is not present in path") @@ -93,8 +96,8 @@ func (h *StackHandler) CreateRoutes(r chi.Router) { r.Use(middleware.ProtectedRoute) r.Use(middleware.SetJson) - r.Get("/", h.getAllStacks) - r.Get("/{listID}", h.getStackItems) + r.Get("/", h.withUserID(h.getAllStacks)) + r.Get("/{listID}", h.withUserID(h.getStackItems)) }) }