From a78f76612230669ca384afdb5a5e2913eb2bb7d6 Mon Sep 17 00:00:00 2001 From: John Costa Date: Mon, 25 Aug 2025 13:16:40 +0100 Subject: [PATCH] more refactoring into seperate handlers --- backend/{ => auth}/auth.go | 2 +- backend/{ => auth}/auth_test.go | 2 +- backend/{ => auth}/email.go | 2 +- backend/auth/handler.go | 107 ++++++++++ backend/images/handler.go | 170 ++++++++++++++++ backend/main.go | 334 +------------------------------- backend/middleware/body.go | 29 +++ backend/middleware/util.go | 48 +++++ backend/stacks/handler.go | 45 +---- 9 files changed, 367 insertions(+), 372 deletions(-) rename backend/{ => auth}/auth.go (98%) rename backend/{ => auth}/auth_test.go (97%) rename backend/{ => auth}/email.go (98%) create mode 100644 backend/auth/handler.go create mode 100644 backend/images/handler.go create mode 100644 backend/middleware/body.go create mode 100644 backend/middleware/util.go diff --git a/backend/auth.go b/backend/auth/auth.go similarity index 98% rename from backend/auth.go rename to backend/auth/auth.go index ec81b7f..f8d1396 100644 --- a/backend/auth.go +++ b/backend/auth/auth.go @@ -1,4 +1,4 @@ -package main +package auth import ( "errors" diff --git a/backend/auth_test.go b/backend/auth/auth_test.go similarity index 97% rename from backend/auth_test.go rename to backend/auth/auth_test.go index 04519fb..cf06612 100644 --- a/backend/auth_test.go +++ b/backend/auth/auth_test.go @@ -1,4 +1,4 @@ -package main +package auth import ( "testing" diff --git a/backend/email.go b/backend/auth/email.go similarity index 98% rename from backend/email.go rename to backend/auth/email.go index fc91505..6b0fe6f 100644 --- a/backend/email.go +++ b/backend/auth/email.go @@ -1,4 +1,4 @@ -package main +package auth import ( "fmt" diff --git a/backend/auth/handler.go b/backend/auth/handler.go new file mode 100644 index 0000000..213fad4 --- /dev/null +++ b/backend/auth/handler.go @@ -0,0 +1,107 @@ +package auth + +import ( + "database/sql" + "net/http" + "os" + "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/middleware" + "screenmark/screenmark/models" + + "github.com/charmbracelet/log" + "github.com/go-chi/chi/v5" +) + +type AuthHandler struct { + logger *log.Logger + + user models.UserModel + + auth Auth +} + +type loginBody struct { + Email string `json:"email"` +} + +type codeBody struct { + Email string `json:"email"` + Code string `json:"code"` +} + +type codeReturn struct { + Access string `json:"access"` + Refresh string `json:"refresh"` +} + +func (h *AuthHandler) login(body loginBody, w http.ResponseWriter, r *http.Request) { + // TODO: validate email + err := h.auth.CreateCode(body.Email) + if err != nil { + middleware.WriteErrorInternal(h.logger, "could not create a code", w) + return + } + + w.WriteHeader(http.StatusOK) +} + +func (h *AuthHandler) code(body codeBody, w http.ResponseWriter, r *http.Request) { + if err := h.auth.UseCode(body.Email, body.Code); err != nil { + middleware.WriteErrorBadRequest(h.logger, "email or code are incorrect", w) + return + } + + // TODO: we should only keep emails around for a little bit. + // Time to first login should be less than 10 minutes. + // So actually, they shouldn't be written to our database. + if exists := h.user.DoesUserExist(r.Context(), body.Email); !exists { + h.user.Save(r.Context(), model.Users{ + Email: body.Email, + }) + } + + uuid, err := h.user.GetUserIdFromEmail(r.Context(), body.Email) + if err != nil { + middleware.WriteErrorBadRequest(h.logger, "failed to get user", w) + return + } + + refresh := middleware.CreateRefreshToken(uuid) + access := middleware.CreateAccessToken(uuid) + + codeReturn := codeReturn{ + Access: access, + Refresh: refresh, + } + + middleware.WriteJsonOrError(h.logger, codeReturn, w) +} + +func (h *AuthHandler) CreateRoutes(r chi.Router) { + h.logger.Info("Mounting auth router") + + r.Group(func(r chi.Router) { + r.Use(middleware.SetJson) + + r.Post("/login", middleware.WithValidatedPost(h.login)) + r.Post("/code", middleware.WithValidatedPost(h.code)) + }) +} + +func CreateAuthHandler(db *sql.DB) AuthHandler { + userModel := models.NewUserModel(db) + logger := log.New(os.Stdout).WithPrefix("Auth") + + mailer, err := CreateMailClient() + if err != nil { + panic(err) + } + + auth := CreateAuth(mailer) + + return AuthHandler{ + logger, + userModel, + auth, + } +} diff --git a/backend/images/handler.go b/backend/images/handler.go new file mode 100644 index 0000000..7a0d038 --- /dev/null +++ b/backend/images/handler.go @@ -0,0 +1,170 @@ +package images + +import ( + "bytes" + "database/sql" + "encoding/base64" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "screenmark/screenmark/.gen/haystack/haystack/model" + "screenmark/screenmark/middleware" + "screenmark/screenmark/models" + + "github.com/charmbracelet/log" + "github.com/go-chi/chi/v5" +) + +type ImageHandler struct { + logger *log.Logger + imageModel models.ImageModel + userModel models.UserModel +} + +type ImagesReturn struct { + UserImages []models.UserImageWithImage `json:"userImages"` + ProcessingImages []models.UserProcessingImage `json:"processingImages"` + Lists []models.ListsWithImages `json:"lists"` +} + +func (h *ImageHandler) serveImage(w http.ResponseWriter, r *http.Request) { + imageId, err := middleware.GetPathParamID(h.logger, "id", w, r) + if err != nil { + return + } + + image, err := h.imageModel.Get(r.Context(), imageId) + if err != nil { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, "Could not get image") + return + } + + // TODO: this could be part of the db table + extension := filepath.Ext(image.ImageName) + if len(extension) == 0 { + // Same hack + extension = "png" + } + extension = extension[1:] + + w.Header().Add("Content-Type", "image/"+extension) + w.Write(image.Image) +} + +func (h *ImageHandler) listImages(w http.ResponseWriter, r *http.Request) { + userId, err := middleware.GetUserID(r.Context(), h.logger, w) + if err != nil { + return + } + + images, err := h.userModel.GetUserImages(r.Context(), userId) + if err != nil { + middleware.WriteErrorInternal(h.logger, "could not get user images", w) + return + } + + processingImages, err := h.imageModel.GetProcessing(r.Context(), userId) + if err != nil { + middleware.WriteErrorInternal(h.logger, "could not get processing images", w) + return + } + + listsWithImages, err := h.userModel.ListWithImages(r.Context(), userId) + if err != nil { + middleware.WriteErrorInternal(h.logger, "could not get lists with images", w) + return + } + + imagesReturn := ImagesReturn{ + UserImages: images, + ProcessingImages: processingImages, + Lists: listsWithImages, + } + + middleware.WriteJsonOrError(h.logger, imagesReturn, w) +} + +func (h *ImageHandler) uploadImage(w http.ResponseWriter, r *http.Request) { + imageName := chi.URLParam(r, "name") + if len(imageName) == 0 { + middleware.WriteErrorBadRequest(h.logger, "you need to provide a name in the path", w) + return + } + + userId, err := middleware.GetUserID(r.Context(), h.logger, w) + if err != nil { + return + } + + contentType := r.Header.Get("Content-Type") + + image := make([]byte, 0) + switch contentType { + case "application/base64": + decoder := base64.NewDecoder(base64.StdEncoding, r.Body) + buf := &bytes.Buffer{} + + _, err := io.Copy(buf, decoder) + if err != nil { + middleware.WriteErrorBadRequest(h.logger, "base64 decoding failed", w) + return + } + + image = buf.Bytes() + case "application/oclet-stream", "image/png": + bodyData, err := io.ReadAll(r.Body) + if err != nil { + middleware.WriteErrorBadRequest(h.logger, "binary data reading failed", w) + return + } + // TODO: check headers + + image = bodyData + default: + middleware.WriteErrorBadRequest(h.logger, "unsupported content type, need octet-stream or base64", w) + return + } + + userImage, err := h.imageModel.Process(r.Context(), userId, model.Image{ + Image: image, + ImageName: imageName, + }) + + if err != nil { + middleware.WriteErrorInternal(h.logger, "could not save image to DB", w) + return + } + + middleware.WriteJsonOrError(h.logger, userImage, w) +} + +func (h *ImageHandler) CreateRoutes(r chi.Router) { + h.logger.Info("Mounting image router") + + // Public route for serving images (not protected) + r.Get("/image/{id}", h.serveImage) + + // Protected routes + r.Group(func(r chi.Router) { + r.Use(middleware.ProtectedRoute) + r.Use(middleware.SetJson) + + r.Get("/image", h.listImages) + r.Post("/image/{name}", h.uploadImage) + }) +} + +func CreateImageHandler(db *sql.DB) ImageHandler { + imageModel := models.NewImageModel(db) + userModel := models.NewUserModel(db) + logger := log.New(os.Stdout).WithPrefix("Images") + + return ImageHandler{ + logger: logger, + imageModel: imageModel, + userModel: userModel, + } +} diff --git a/backend/main.go b/backend/main.go index 9db64e9..bec0b0f 100644 --- a/backend/main.go +++ b/backend/main.go @@ -1,16 +1,11 @@ package main import ( - "bytes" - "encoding/base64" - "encoding/json" - "fmt" - "io" "log" "net/http" - "path/filepath" - "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/agents/client" + "screenmark/screenmark/auth" + "screenmark/screenmark/images" "screenmark/screenmark/models" "screenmark/screenmark/stacks" @@ -18,7 +13,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/google/uuid" "github.com/joho/godotenv" ) @@ -42,16 +36,10 @@ func main() { } imageModel := models.NewImageModel(db) - userModel := models.NewUserModel(db) stackHandler := stacks.CreateStackHandler(db) - - mail, err := CreateMailClient() - if err != nil { - panic(err) - } - - auth := CreateAuth(mail) + authHandler := auth.CreateAuthHandler(db) + imageHandler := images.CreateImageHandler(db) notifier := NewNotifier[Notification](10) @@ -65,200 +53,8 @@ func main() { r.Use(ourmiddleware.CorsMiddleware) r.Route("/stacks", stackHandler.CreateRoutes) - - // Temporarily not in protect route because we aren't using cookies. - // Therefore they don't get automatically attached to the request. - // So cannot send the tokensend the token - r.Get("/image/{id}", func(w http.ResponseWriter, r *http.Request) { - stringImageId := r.PathValue("id") - // userId := r.Context().Value(USER_ID).(uuid.UUID) - - imageId, err := uuid.Parse(stringImageId) - if err != nil { - w.WriteHeader(http.StatusForbidden) - fmt.Fprintf(w, "You cannot read this") - return - } - - // if authorized := imageModel.IsUserAuthorized(r.Context(), imageId, userId); !authorized { - // w.WriteHeader(http.StatusForbidden) - // fmt.Fprintf(w, "You cannot read this") - // return - // } - - image, err := imageModel.Get(r.Context(), imageId) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, "Could not get image") - return - } - - // TODO: this could be part of the db table - extension := filepath.Ext(image.ImageName) - if len(extension) == 0 { - // Same hack - extension = "png" - } - extension = extension[1:] - - w.Header().Add("Content-Type", "image/"+extension) - w.Write(image.Image) - }) - - r.Group(func(r chi.Router) { - r.Use(ourmiddleware.ProtectedRoute) - r.Use(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") - - next.ServeHTTP(w, r) - }) - }) - - r.Get("/image", func(w http.ResponseWriter, r *http.Request) { - userId := r.Context().Value(ourmiddleware.USER_ID).(uuid.UUID) - if err != nil { - w.WriteHeader(http.StatusForbidden) - fmt.Fprintf(w, "You cannot read this") - return - } - - images, err := userModel.GetUserImages(r.Context(), userId) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, "Something went wrong") - return - } - - processingImages, err := imageModel.GetProcessing(r.Context(), userId) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, "Something went wrong") - return - } - - listsWithImages, err := userModel.ListWithImages(r.Context(), userId) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, "Something went wrong") - return - } - - type ImagesReturn struct { - UserImages []models.UserImageWithImage - ProcessingImages []models.UserProcessingImage - Lists []models.ListsWithImages - } - - imagesReturn := ImagesReturn{ - UserImages: images, - ProcessingImages: processingImages, - Lists: listsWithImages, - } - - jsonImages, err := json.Marshal(imagesReturn) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Could not create JSON response for this image") - return - } - - w.Write(jsonImages) - }) - - r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) { - imageName := r.PathValue("name") - userId := r.Context().Value(ourmiddleware.USER_ID).(uuid.UUID) - - if len(imageName) == 0 { - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "You need to provide a name in the path") - return - } - - contentType := r.Header.Get("Content-Type") - - fmt.Printf("Content-Type: %s\n", contentType) - - // TODO: length checks on body - // TODO: extract this shit out - image := make([]byte, 0) - switch contentType { - case "application/base64": - decoder := base64.NewDecoder(base64.StdEncoding, r.Body) - buf := &bytes.Buffer{} - - decodedIamge, err := io.Copy(buf, decoder) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "bruh, base64 aint decoding") - return - } - - fmt.Println(string(image)) - fmt.Println(decodedIamge) - - image = buf.Bytes() - case "application/oclet-stream", "image/png": - bodyData, err := io.ReadAll(r.Body) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "bruh, binary aint binaring") - return - } - // TODO: check headers - - image = bodyData - default: - log.Println("bad stuff?") - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Bruh, you need oclet stream or base64") - return - } - - if err != nil { - log.Println("First case") - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Couldnt read the image from the request body") - return - } - - userImage, err := imageModel.Process(r.Context(), userId, model.Image{ - Image: image, - ImageName: imageName, - Description: "", - }) - if err != nil { - log.Println("Second case") - log.Println(err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Could not save image to DB") - return - } - - jsonUserImage, err := json.Marshal(userImage) - if err != nil { - log.Println("Third case") - log.Println(err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Could not create JSON response for this image") - return - } - - w.WriteHeader(http.StatusCreated) - - fmt.Fprint(w, string(jsonUserImage)) - w.Header().Add("Content-Type", "application/json") - }) - - }) + r.Route("/auth", authHandler.CreateRoutes) + r.Route("/images", imageHandler.CreateRoutes) r.Route("/notifications", func(r chi.Router) { r.Use(ourmiddleware.GetUserIdFromUrl) @@ -266,124 +62,6 @@ func main() { r.Get("/", CreateEventsHandler(¬ifier)) }) - r.Post("/login", func(w http.ResponseWriter, r *http.Request) { - type LoginBody struct { - Email string `json:"email"` - } - - loginBody := LoginBody{} - err := json.NewDecoder(r.Body).Decode(&loginBody) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Request body was not correct") - return - } - - // TODO: validate it's an email - - auth.CreateCode(loginBody.Email) - - w.WriteHeader(http.StatusOK) - }) - - type CodeReturn struct { - Access string `json:"access"` - Refresh string `json:"refresh"` - } - - r.Post("/code", func(w http.ResponseWriter, r *http.Request) { - type CodeBody struct { - Email string `json:"email"` - Code string `json:"code"` - } - - codeBody := CodeBody{} - if err := json.NewDecoder(r.Body).Decode(&codeBody); err != nil { - log.Println(err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Request body was not correct") - return - } - - if err := auth.UseCode(codeBody.Email, codeBody.Code); err != nil { - log.Println(err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "email or code are incorrect") - return - } - - if exists := userModel.DoesUserExist(r.Context(), codeBody.Email); !exists { - userModel.Save(r.Context(), model.Users{ - Email: codeBody.Email, - }) - } - - uuid, err := userModel.GetUserIdFromEmail(r.Context(), codeBody.Email) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "Something went wrong.") - return - } - - refresh := ourmiddleware.CreateRefreshToken(uuid) - access := ourmiddleware.CreateAccessToken(uuid) - - codeReturn := CodeReturn{ - Access: access, - Refresh: refresh, - } - - fmt.Println(codeReturn) - - json, err := json.Marshal(codeReturn) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "Something went wrong.") - return - } - - w.WriteHeader(http.StatusOK) - w.Header().Add("Content-Type", "application/json") - - fmt.Fprint(w, string(json)) - }) - - r.Get("/demo-login", func(w http.ResponseWriter, r *http.Request) { - uuid, err := userModel.GetUserIdFromEmail(r.Context(), "demo@email.com") - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "Something went wrong.") - return - } - - refresh := ourmiddleware.CreateRefreshToken(uuid) - access := ourmiddleware.CreateAccessToken(uuid) - - codeReturn := CodeReturn{ - Access: access, - Refresh: refresh, - } - - fmt.Println(codeReturn) - - json, err := json.Marshal(codeReturn) - if err != nil { - log.Println(err) - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "Something went wrong.") - return - } - - w.WriteHeader(http.StatusOK) - w.Header().Add("Content-Type", "application/json") - - fmt.Fprint(w, string(json)) - }) - logWriter := DatabaseWriter{ dbPool: db, } diff --git a/backend/middleware/body.go b/backend/middleware/body.go new file mode 100644 index 0000000..2e7bf90 --- /dev/null +++ b/backend/middleware/body.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "encoding/json" + "io" + "net/http" +) + +func WithValidatedPost[K any]( + fn func(request K, w http.ResponseWriter, r *http.Request), +) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + request := new(K) + + body, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + err = json.Unmarshal(body, request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + fn(*request, w, r) + } +} diff --git a/backend/middleware/util.go b/backend/middleware/util.go new file mode 100644 index 0000000..9c77f70 --- /dev/null +++ b/backend/middleware/util.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "encoding/json" + "net/http" + + "github.com/charmbracelet/log" +) + +func WriteJsonOrError[K any](logger *log.Logger, object K, w http.ResponseWriter) { + jsonObject, err := json.Marshal(object) + if err != nil { + logger.Warn("could not marshal json object", "err", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + w.Write(jsonObject) + w.WriteHeader(http.StatusOK) +} + +type ErrorObject struct { + Error string `json:"error"` +} + +func writeError(logger *log.Logger, error string, w http.ResponseWriter, code int) { + e := ErrorObject{ + error, + } + + jsonObject, err := json.Marshal(e) + if err != nil { + logger.Warn("could not marshal json object", "err", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + w.Write(jsonObject) + w.WriteHeader(code) +} + +func WriteErrorBadRequest(logger *log.Logger, error string, w http.ResponseWriter) { + writeError(logger, error, w, http.StatusBadRequest) +} + +func WriteErrorInternal(logger *log.Logger, error string, w http.ResponseWriter) { + writeError(logger, error, w, http.StatusInternalServerError) +} diff --git a/backend/stacks/handler.go b/backend/stacks/handler.go index 510c0fd..b95f6dc 100644 --- a/backend/stacks/handler.go +++ b/backend/stacks/handler.go @@ -2,8 +2,6 @@ package stacks import ( "database/sql" - "encoding/json" - "io" "net/http" "os" "screenmark/screenmark/middleware" @@ -13,45 +11,11 @@ import ( "github.com/go-chi/chi/v5" ) -func writeJsonOrError[K any](logger *log.Logger, object K, w http.ResponseWriter) { - jsonObject, err := json.Marshal(object) - if err != nil { - logger.Warn("could not marshal json object", "err", err) - w.WriteHeader(http.StatusBadRequest) - return - } - - w.Write(jsonObject) - w.WriteHeader(http.StatusOK) -} - type StackHandler struct { logger *log.Logger stackModel models.ListModel } -func withValidatedPost[K any]( - fn func(request K, w http.ResponseWriter, r *http.Request), -) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - request := new(K) - - body, err := io.ReadAll(r.Body) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - err = json.Unmarshal(body, request) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - fn(*request, w, r) - } -} - func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -67,7 +31,7 @@ func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { return } - writeJsonOrError(h.logger, lists, w) + middleware.WriteJsonOrError(h.logger, lists, w) } func (h *StackHandler) getStackItems(w http.ResponseWriter, r *http.Request) { @@ -91,7 +55,7 @@ func (h *StackHandler) getStackItems(w http.ResponseWriter, r *http.Request) { return } - writeJsonOrError(h.logger, lists, w) + middleware.WriteJsonOrError(h.logger, lists, w) } type EditStack struct { @@ -136,9 +100,8 @@ func (h *StackHandler) CreateRoutes(r chi.Router) { r.Get("/", h.getAllStacks) r.Get("/{listID}", h.getStackItems) - r.Post("/", withValidatedPost(h.createStack)) - - r.Patch("/{listID}", withValidatedPost(h.editStack)) + r.Post("/", middleware.WithValidatedPost(h.createStack)) + r.Patch("/{listID}", middleware.WithValidatedPost(h.editStack)) }) }