more refactoring into seperate handlers

This commit is contained in:
2025-08-25 13:16:40 +01:00
parent 10cea769bf
commit a78f766122
9 changed files with 367 additions and 372 deletions

View File

@ -1,4 +1,4 @@
package main
package auth
import (
"errors"

View File

@ -1,4 +1,4 @@
package main
package auth
import (
"testing"

View File

@ -1,4 +1,4 @@
package main
package auth
import (
"fmt"

107
backend/auth/handler.go Normal file
View File

@ -0,0 +1,107 @@
package auth
import (
"database/sql"
"net/http"
"os"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/middleware"
"screenmark/screenmark/models"
"github.com/charmbracelet/log"
"github.com/go-chi/chi/v5"
)
type AuthHandler struct {
logger *log.Logger
user models.UserModel
auth Auth
}
type loginBody struct {
Email string `json:"email"`
}
type codeBody struct {
Email string `json:"email"`
Code string `json:"code"`
}
type codeReturn struct {
Access string `json:"access"`
Refresh string `json:"refresh"`
}
func (h *AuthHandler) login(body loginBody, w http.ResponseWriter, r *http.Request) {
// TODO: validate email
err := h.auth.CreateCode(body.Email)
if err != nil {
middleware.WriteErrorInternal(h.logger, "could not create a code", w)
return
}
w.WriteHeader(http.StatusOK)
}
func (h *AuthHandler) code(body codeBody, w http.ResponseWriter, r *http.Request) {
if err := h.auth.UseCode(body.Email, body.Code); err != nil {
middleware.WriteErrorBadRequest(h.logger, "email or code are incorrect", w)
return
}
// TODO: we should only keep emails around for a little bit.
// Time to first login should be less than 10 minutes.
// So actually, they shouldn't be written to our database.
if exists := h.user.DoesUserExist(r.Context(), body.Email); !exists {
h.user.Save(r.Context(), model.Users{
Email: body.Email,
})
}
uuid, err := h.user.GetUserIdFromEmail(r.Context(), body.Email)
if err != nil {
middleware.WriteErrorBadRequest(h.logger, "failed to get user", w)
return
}
refresh := middleware.CreateRefreshToken(uuid)
access := middleware.CreateAccessToken(uuid)
codeReturn := codeReturn{
Access: access,
Refresh: refresh,
}
middleware.WriteJsonOrError(h.logger, codeReturn, w)
}
func (h *AuthHandler) CreateRoutes(r chi.Router) {
h.logger.Info("Mounting auth router")
r.Group(func(r chi.Router) {
r.Use(middleware.SetJson)
r.Post("/login", middleware.WithValidatedPost(h.login))
r.Post("/code", middleware.WithValidatedPost(h.code))
})
}
func CreateAuthHandler(db *sql.DB) AuthHandler {
userModel := models.NewUserModel(db)
logger := log.New(os.Stdout).WithPrefix("Auth")
mailer, err := CreateMailClient()
if err != nil {
panic(err)
}
auth := CreateAuth(mailer)
return AuthHandler{
logger,
userModel,
auth,
}
}

170
backend/images/handler.go Normal file
View File

@ -0,0 +1,170 @@
package images
import (
"bytes"
"database/sql"
"encoding/base64"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/middleware"
"screenmark/screenmark/models"
"github.com/charmbracelet/log"
"github.com/go-chi/chi/v5"
)
type ImageHandler struct {
logger *log.Logger
imageModel models.ImageModel
userModel models.UserModel
}
type ImagesReturn struct {
UserImages []models.UserImageWithImage `json:"userImages"`
ProcessingImages []models.UserProcessingImage `json:"processingImages"`
Lists []models.ListsWithImages `json:"lists"`
}
func (h *ImageHandler) serveImage(w http.ResponseWriter, r *http.Request) {
imageId, err := middleware.GetPathParamID(h.logger, "id", w, r)
if err != nil {
return
}
image, err := h.imageModel.Get(r.Context(), imageId)
if err != nil {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "Could not get image")
return
}
// TODO: this could be part of the db table
extension := filepath.Ext(image.ImageName)
if len(extension) == 0 {
// Same hack
extension = "png"
}
extension = extension[1:]
w.Header().Add("Content-Type", "image/"+extension)
w.Write(image.Image)
}
func (h *ImageHandler) listImages(w http.ResponseWriter, r *http.Request) {
userId, err := middleware.GetUserID(r.Context(), h.logger, w)
if err != nil {
return
}
images, err := h.userModel.GetUserImages(r.Context(), userId)
if err != nil {
middleware.WriteErrorInternal(h.logger, "could not get user images", w)
return
}
processingImages, err := h.imageModel.GetProcessing(r.Context(), userId)
if err != nil {
middleware.WriteErrorInternal(h.logger, "could not get processing images", w)
return
}
listsWithImages, err := h.userModel.ListWithImages(r.Context(), userId)
if err != nil {
middleware.WriteErrorInternal(h.logger, "could not get lists with images", w)
return
}
imagesReturn := ImagesReturn{
UserImages: images,
ProcessingImages: processingImages,
Lists: listsWithImages,
}
middleware.WriteJsonOrError(h.logger, imagesReturn, w)
}
func (h *ImageHandler) uploadImage(w http.ResponseWriter, r *http.Request) {
imageName := chi.URLParam(r, "name")
if len(imageName) == 0 {
middleware.WriteErrorBadRequest(h.logger, "you need to provide a name in the path", w)
return
}
userId, err := middleware.GetUserID(r.Context(), h.logger, w)
if err != nil {
return
}
contentType := r.Header.Get("Content-Type")
image := make([]byte, 0)
switch contentType {
case "application/base64":
decoder := base64.NewDecoder(base64.StdEncoding, r.Body)
buf := &bytes.Buffer{}
_, err := io.Copy(buf, decoder)
if err != nil {
middleware.WriteErrorBadRequest(h.logger, "base64 decoding failed", w)
return
}
image = buf.Bytes()
case "application/oclet-stream", "image/png":
bodyData, err := io.ReadAll(r.Body)
if err != nil {
middleware.WriteErrorBadRequest(h.logger, "binary data reading failed", w)
return
}
// TODO: check headers
image = bodyData
default:
middleware.WriteErrorBadRequest(h.logger, "unsupported content type, need octet-stream or base64", w)
return
}
userImage, err := h.imageModel.Process(r.Context(), userId, model.Image{
Image: image,
ImageName: imageName,
})
if err != nil {
middleware.WriteErrorInternal(h.logger, "could not save image to DB", w)
return
}
middleware.WriteJsonOrError(h.logger, userImage, w)
}
func (h *ImageHandler) CreateRoutes(r chi.Router) {
h.logger.Info("Mounting image router")
// Public route for serving images (not protected)
r.Get("/image/{id}", h.serveImage)
// Protected routes
r.Group(func(r chi.Router) {
r.Use(middleware.ProtectedRoute)
r.Use(middleware.SetJson)
r.Get("/image", h.listImages)
r.Post("/image/{name}", h.uploadImage)
})
}
func CreateImageHandler(db *sql.DB) ImageHandler {
imageModel := models.NewImageModel(db)
userModel := models.NewUserModel(db)
logger := log.New(os.Stdout).WithPrefix("Images")
return ImageHandler{
logger: logger,
imageModel: imageModel,
userModel: userModel,
}
}

View File

@ -1,16 +1,11 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"path/filepath"
"screenmark/screenmark/.gen/haystack/haystack/model"
"screenmark/screenmark/agents/client"
"screenmark/screenmark/auth"
"screenmark/screenmark/images"
"screenmark/screenmark/models"
"screenmark/screenmark/stacks"
@ -18,7 +13,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
"github.com/joho/godotenv"
)
@ -42,16 +36,10 @@ func main() {
}
imageModel := models.NewImageModel(db)
userModel := models.NewUserModel(db)
stackHandler := stacks.CreateStackHandler(db)
mail, err := CreateMailClient()
if err != nil {
panic(err)
}
auth := CreateAuth(mail)
authHandler := auth.CreateAuthHandler(db)
imageHandler := images.CreateImageHandler(db)
notifier := NewNotifier[Notification](10)
@ -65,200 +53,8 @@ func main() {
r.Use(ourmiddleware.CorsMiddleware)
r.Route("/stacks", stackHandler.CreateRoutes)
// Temporarily not in protect route because we aren't using cookies.
// Therefore they don't get automatically attached to the request.
// So <img src=""> cannot send the tokensend the token
r.Get("/image/{id}", func(w http.ResponseWriter, r *http.Request) {
stringImageId := r.PathValue("id")
// userId := r.Context().Value(USER_ID).(uuid.UUID)
imageId, err := uuid.Parse(stringImageId)
if err != nil {
w.WriteHeader(http.StatusForbidden)
fmt.Fprintf(w, "You cannot read this")
return
}
// if authorized := imageModel.IsUserAuthorized(r.Context(), imageId, userId); !authorized {
// w.WriteHeader(http.StatusForbidden)
// fmt.Fprintf(w, "You cannot read this")
// return
// }
image, err := imageModel.Get(r.Context(), imageId)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "Could not get image")
return
}
// TODO: this could be part of the db table
extension := filepath.Ext(image.ImageName)
if len(extension) == 0 {
// Same hack
extension = "png"
}
extension = extension[1:]
w.Header().Add("Content-Type", "image/"+extension)
w.Write(image.Image)
})
r.Group(func(r chi.Router) {
r.Use(ourmiddleware.ProtectedRoute)
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
next.ServeHTTP(w, r)
})
})
r.Get("/image", func(w http.ResponseWriter, r *http.Request) {
userId := r.Context().Value(ourmiddleware.USER_ID).(uuid.UUID)
if err != nil {
w.WriteHeader(http.StatusForbidden)
fmt.Fprintf(w, "You cannot read this")
return
}
images, err := userModel.GetUserImages(r.Context(), userId)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "Something went wrong")
return
}
processingImages, err := imageModel.GetProcessing(r.Context(), userId)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "Something went wrong")
return
}
listsWithImages, err := userModel.ListWithImages(r.Context(), userId)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "Something went wrong")
return
}
type ImagesReturn struct {
UserImages []models.UserImageWithImage
ProcessingImages []models.UserProcessingImage
Lists []models.ListsWithImages
}
imagesReturn := ImagesReturn{
UserImages: images,
ProcessingImages: processingImages,
Lists: listsWithImages,
}
jsonImages, err := json.Marshal(imagesReturn)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Could not create JSON response for this image")
return
}
w.Write(jsonImages)
})
r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) {
imageName := r.PathValue("name")
userId := r.Context().Value(ourmiddleware.USER_ID).(uuid.UUID)
if len(imageName) == 0 {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "You need to provide a name in the path")
return
}
contentType := r.Header.Get("Content-Type")
fmt.Printf("Content-Type: %s\n", contentType)
// TODO: length checks on body
// TODO: extract this shit out
image := make([]byte, 0)
switch contentType {
case "application/base64":
decoder := base64.NewDecoder(base64.StdEncoding, r.Body)
buf := &bytes.Buffer{}
decodedIamge, err := io.Copy(buf, decoder)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "bruh, base64 aint decoding")
return
}
fmt.Println(string(image))
fmt.Println(decodedIamge)
image = buf.Bytes()
case "application/oclet-stream", "image/png":
bodyData, err := io.ReadAll(r.Body)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "bruh, binary aint binaring")
return
}
// TODO: check headers
image = bodyData
default:
log.Println("bad stuff?")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Bruh, you need oclet stream or base64")
return
}
if err != nil {
log.Println("First case")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Couldnt read the image from the request body")
return
}
userImage, err := imageModel.Process(r.Context(), userId, model.Image{
Image: image,
ImageName: imageName,
Description: "",
})
if err != nil {
log.Println("Second case")
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Could not save image to DB")
return
}
jsonUserImage, err := json.Marshal(userImage)
if err != nil {
log.Println("Third case")
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Could not create JSON response for this image")
return
}
w.WriteHeader(http.StatusCreated)
fmt.Fprint(w, string(jsonUserImage))
w.Header().Add("Content-Type", "application/json")
})
})
r.Route("/auth", authHandler.CreateRoutes)
r.Route("/images", imageHandler.CreateRoutes)
r.Route("/notifications", func(r chi.Router) {
r.Use(ourmiddleware.GetUserIdFromUrl)
@ -266,124 +62,6 @@ func main() {
r.Get("/", CreateEventsHandler(&notifier))
})
r.Post("/login", func(w http.ResponseWriter, r *http.Request) {
type LoginBody struct {
Email string `json:"email"`
}
loginBody := LoginBody{}
err := json.NewDecoder(r.Body).Decode(&loginBody)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Request body was not correct")
return
}
// TODO: validate it's an email
auth.CreateCode(loginBody.Email)
w.WriteHeader(http.StatusOK)
})
type CodeReturn struct {
Access string `json:"access"`
Refresh string `json:"refresh"`
}
r.Post("/code", func(w http.ResponseWriter, r *http.Request) {
type CodeBody struct {
Email string `json:"email"`
Code string `json:"code"`
}
codeBody := CodeBody{}
if err := json.NewDecoder(r.Body).Decode(&codeBody); err != nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Request body was not correct")
return
}
if err := auth.UseCode(codeBody.Email, codeBody.Code); err != nil {
log.Println(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "email or code are incorrect")
return
}
if exists := userModel.DoesUserExist(r.Context(), codeBody.Email); !exists {
userModel.Save(r.Context(), model.Users{
Email: codeBody.Email,
})
}
uuid, err := userModel.GetUserIdFromEmail(r.Context(), codeBody.Email)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "Something went wrong.")
return
}
refresh := ourmiddleware.CreateRefreshToken(uuid)
access := ourmiddleware.CreateAccessToken(uuid)
codeReturn := CodeReturn{
Access: access,
Refresh: refresh,
}
fmt.Println(codeReturn)
json, err := json.Marshal(codeReturn)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "Something went wrong.")
return
}
w.WriteHeader(http.StatusOK)
w.Header().Add("Content-Type", "application/json")
fmt.Fprint(w, string(json))
})
r.Get("/demo-login", func(w http.ResponseWriter, r *http.Request) {
uuid, err := userModel.GetUserIdFromEmail(r.Context(), "demo@email.com")
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "Something went wrong.")
return
}
refresh := ourmiddleware.CreateRefreshToken(uuid)
access := ourmiddleware.CreateAccessToken(uuid)
codeReturn := CodeReturn{
Access: access,
Refresh: refresh,
}
fmt.Println(codeReturn)
json, err := json.Marshal(codeReturn)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "Something went wrong.")
return
}
w.WriteHeader(http.StatusOK)
w.Header().Add("Content-Type", "application/json")
fmt.Fprint(w, string(json))
})
logWriter := DatabaseWriter{
dbPool: db,
}

View File

@ -0,0 +1,29 @@
package middleware
import (
"encoding/json"
"io"
"net/http"
)
func WithValidatedPost[K any](
fn func(request K, w http.ResponseWriter, r *http.Request),
) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
request := new(K)
body, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
err = json.Unmarshal(body, request)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
fn(*request, w, r)
}
}

View File

@ -0,0 +1,48 @@
package middleware
import (
"encoding/json"
"net/http"
"github.com/charmbracelet/log"
)
func WriteJsonOrError[K any](logger *log.Logger, object K, w http.ResponseWriter) {
jsonObject, err := json.Marshal(object)
if err != nil {
logger.Warn("could not marshal json object", "err", err)
w.WriteHeader(http.StatusBadRequest)
return
}
w.Write(jsonObject)
w.WriteHeader(http.StatusOK)
}
type ErrorObject struct {
Error string `json:"error"`
}
func writeError(logger *log.Logger, error string, w http.ResponseWriter, code int) {
e := ErrorObject{
error,
}
jsonObject, err := json.Marshal(e)
if err != nil {
logger.Warn("could not marshal json object", "err", err)
w.WriteHeader(http.StatusBadRequest)
return
}
w.Write(jsonObject)
w.WriteHeader(code)
}
func WriteErrorBadRequest(logger *log.Logger, error string, w http.ResponseWriter) {
writeError(logger, error, w, http.StatusBadRequest)
}
func WriteErrorInternal(logger *log.Logger, error string, w http.ResponseWriter) {
writeError(logger, error, w, http.StatusInternalServerError)
}

View File

@ -2,8 +2,6 @@ package stacks
import (
"database/sql"
"encoding/json"
"io"
"net/http"
"os"
"screenmark/screenmark/middleware"
@ -13,45 +11,11 @@ import (
"github.com/go-chi/chi/v5"
)
func writeJsonOrError[K any](logger *log.Logger, object K, w http.ResponseWriter) {
jsonObject, err := json.Marshal(object)
if err != nil {
logger.Warn("could not marshal json object", "err", err)
w.WriteHeader(http.StatusBadRequest)
return
}
w.Write(jsonObject)
w.WriteHeader(http.StatusOK)
}
type StackHandler struct {
logger *log.Logger
stackModel models.ListModel
}
func withValidatedPost[K any](
fn func(request K, w http.ResponseWriter, r *http.Request),
) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
request := new(K)
body, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
err = json.Unmarshal(body, request)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
fn(*request, w, r)
}
}
func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@ -67,7 +31,7 @@ func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
return
}
writeJsonOrError(h.logger, lists, w)
middleware.WriteJsonOrError(h.logger, lists, w)
}
func (h *StackHandler) getStackItems(w http.ResponseWriter, r *http.Request) {
@ -91,7 +55,7 @@ func (h *StackHandler) getStackItems(w http.ResponseWriter, r *http.Request) {
return
}
writeJsonOrError(h.logger, lists, w)
middleware.WriteJsonOrError(h.logger, lists, w)
}
type EditStack struct {
@ -136,9 +100,8 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
r.Get("/", h.getAllStacks)
r.Get("/{listID}", h.getStackItems)
r.Post("/", withValidatedPost(h.createStack))
r.Patch("/{listID}", withValidatedPost(h.editStack))
r.Post("/", middleware.WithValidatedPost(h.createStack))
r.Patch("/{listID}", middleware.WithValidatedPost(h.editStack))
})
}