package images

import (
	"bytes"
	"database/sql"
	"encoding/base64"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"

	"screenmark/screenmark/limits"
	"screenmark/screenmark/middleware"
	"screenmark/screenmark/models"

	"github.com/charmbracelet/log"
	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
)

// ImageHandler serves the HTTP endpoints for serving, listing, uploading and
// deleting a user's images. Build one with CreateImageHandler and mount it
// with CreateRoutes.
type ImageHandler struct {
	logger        *log.Logger
	imageModel    models.ImageModel
	userModel     models.UserModel
	limitsManager limits.LimitsManagerMethods
	jwtManager    *middleware.JwtManager
}

// ImagesReturn is the JSON body written by the list endpoint: the user's
// images plus the user's lists with their images.
type ImagesReturn struct {
	UserImages []models.UserImageWithImage `json:"userImages"`
	Lists      []models.ListsWithImages    `json:"lists"`
}

// defaultImageExtension is used to build the Content-Type when a stored image
// name has no usable file extension.
const defaultImageExtension = ".png"

// imageContentType derives an "image/<subtype>" Content-Type from the file
// extension of name, falling back to image/png when there is none.
func imageContentType(name string) string {
	// TODO: this could be part of the db table
	extension := filepath.Ext(name)
	// filepath.Ext returns either "" (no dot) or a string starting with ".".
	// A bare "." (trailing dot) carries no subtype either, so both fall back
	// to the default rather than producing "image/".
	if len(extension) <= 1 {
		extension = defaultImageExtension
	}
	return "image/" + extension[1:]
}

// serveImage writes the raw bytes of the image identified by the "id" path
// parameter, provided it belongs to the authenticated user. An image that does
// not exist and one owned by another user both produce 404, so the endpoint
// does not reveal which IDs exist.
func (h *ImageHandler) serveImage(w http.ResponseWriter, r *http.Request) {
	imageID, err := middleware.GetPathParamID(h.logger, "id", w, r)
	if err != nil {
		return
	}

	ctx := r.Context()

	// GetUserID is handed w; as in listImages and uploadImage, rely on it to
	// have written the error response instead of writing a second status.
	// NOTE(review): confirm against middleware.GetUserID.
	userID, err := middleware.GetUserID(ctx, h.logger, w)
	if err != nil {
		return
	}

	image, exists, err := h.imageModel.Get(ctx, imageID)
	if err != nil {
		w.WriteHeader(http.StatusNotFound)
		fmt.Fprintf(w, "Could not get image")
		return
	}

	// Do not leak that this ID exists. An image without an owner is treated
	// as "not yours" instead of dereferencing a nil pointer.
	if !exists || image.UserID == nil || *image.UserID != userID {
		w.WriteHeader(http.StatusNotFound)
		return
	}

	w.Header().Add("Content-Type", imageContentType(image.ImageName))
	// The status line is already committed at this point, so a failed write
	// can only be logged.
	if _, err := w.Write(image.Image); err != nil {
		h.logger.Warn("could not write image", "error", err)
	}
}

// listImages responds with the authenticated user's images and lists as an
// ImagesReturn JSON document.
func (h *ImageHandler) listImages(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	userID, err := middleware.GetUserID(ctx, h.logger, w)
	if err != nil {
		return
	}

	images, err := h.userModel.GetUserImages(ctx, userID)
	if err != nil {
		middleware.WriteErrorInternal(h.logger, "could not get user images", w)
		return
	}

	listsWithImages, err := h.userModel.ListWithImages(ctx, userID)
	if err != nil {
		middleware.WriteErrorInternal(h.logger, "could not get lists with images", w)
		return
	}

	imagesReturn := ImagesReturn{
		UserImages: images,
		Lists:      listsWithImages,
	}

	middleware.WriteJsonOrError(h.logger, imagesReturn, w)
}

// uploadImage saves the request body as a new image, named by the "name" path
// parameter, for the authenticated user. The body may be base64 text
// (Content-Type: application/base64) or raw bytes (application/octet-stream
// or image/png); any other Content-Type is rejected with 400.
func (h *ImageHandler) uploadImage(w http.ResponseWriter, r *http.Request) {
	imageName := chi.URLParam(r, "name")
	if len(imageName) == 0 {
		middleware.WriteErrorBadRequest(h.logger, "you need to provide a name in the path", w)
		return
	}

	ctx := r.Context()

	userID, err := middleware.GetUserID(ctx, h.logger, w)
	if err != nil {
		return
	}

	var image []byte

	switch r.Header.Get("Content-Type") {
	case "application/base64":
		decoder := base64.NewDecoder(base64.StdEncoding, r.Body)
		buf := &bytes.Buffer{}
		if _, err := io.Copy(buf, decoder); err != nil {
			middleware.WriteErrorBadRequest(h.logger, "base64 decoding failed", w)
			return
		}
		image = buf.Bytes()
	case "application/octet-stream", "application/oclet-stream", "image/png":
		// "application/oclet-stream" is a historical misspelling of the
		// standard type; it stays accepted so existing clients keep working.
		bodyData, err := io.ReadAll(r.Body)
		if err != nil {
			middleware.WriteErrorBadRequest(h.logger, "binary data reading failed", w)
			return
		}
		// TODO: check headers
		image = bodyData
	default:
		middleware.WriteErrorBadRequest(h.logger, "unsupported content type, need octet-stream or base64", w)
		return
	}

	if err := h.imageModel.Save(ctx, imageName, image, userID); err != nil {
		middleware.WriteErrorInternal(h.logger, "could not save image to DB", w)
		return
	}

	w.WriteHeader(http.StatusOK)
}

// deleteImage removes the image identified by the "image-id" path parameter
// on behalf of the authenticated user. It responds 400 for a malformed ID or a
// failed delete, 404 when imageModel.Delete reports nothing was deleted for
// this user, and 200 on success.
func (h *ImageHandler) deleteImage(w http.ResponseWriter, r *http.Request) {
	stringImageID := chi.URLParam(r, "image-id")
	imageID, err := uuid.Parse(stringImageID)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	ctx := r.Context()

	// As in serveImage: GetUserID is handed w, so do not write a second status.
	// NOTE(review): confirm against middleware.GetUserID.
	userID, err := middleware.GetUserID(ctx, h.logger, w)
	if err != nil {
		return
	}

	exists, err := h.imageModel.Delete(ctx, imageID, userID)
	if err != nil {
		h.logger.Warn("cannot delete image", "error", err)
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	// Don't leak if the image exists or not
	if !exists {
		w.WriteHeader(http.StatusNotFound)
		return
	}

	w.WriteHeader(http.StatusOK)
}

// CreateRoutes mounts the image endpoints on r. GET /{id} is guarded by
// middleware.ProtectedRouteURL. The remaining routes are guarded by
// middleware.ProtectedRoute and wrapped with middleware.SetJson. Uploads are
// additionally gated by the limits manager's image limit.
func (h *ImageHandler) CreateRoutes(r chi.Router) {
	h.logger.Info("Mounting image router")

	r.Group(func(r chi.Router) {
		r.Use(middleware.ProtectedRouteURL(h.jwtManager))

		r.Get("/{id}", h.serveImage)
	})

	r.Group(func(r chi.Router) {
		r.Use(middleware.ProtectedRoute(h.jwtManager))
		r.Use(middleware.SetJson)

		r.Get("/", h.listImages)
		r.Post("/{name}", middleware.WithLimit(h.logger, h.limitsManager.HasReachedImageLimit, h.uploadImage))
		r.Delete("/{image-id}", h.deleteImage)
	})
}

// CreateImageHandler builds an ImageHandler backed by db. It logs to stdout
// with the "Images" prefix.
func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, jwtManager *middleware.JwtManager) ImageHandler {
	imageModel := models.NewImageModel(db)
	userModel := models.NewUserModel(db)
	logger := log.New(os.Stdout).WithPrefix("Images")

	return ImageHandler{
		logger:        logger,
		imageModel:    imageModel,
		userModel:     userModel,
		limitsManager: limitsManager,
		jwtManager:    jwtManager,
	}
}