diff --git a/backend/auth/handler.go b/backend/auth/handler.go index 249e205..57680bb 100644 --- a/backend/auth/handler.go +++ b/backend/auth/handler.go @@ -18,6 +18,8 @@ type AuthHandler struct { user models.UserModel auth Auth + + jwtManager *middleware.JwtManager } type loginBody struct { @@ -65,8 +67,8 @@ func (h *AuthHandler) code(body codeBody, w http.ResponseWriter, r *http.Request return } - refresh := middleware.CreateRefreshToken(uuid) - access := middleware.CreateAccessToken(uuid) + refresh := h.jwtManager.CreateRefreshToken(uuid) + access := h.jwtManager.CreateAccessToken(uuid) codeReturn := codeReturn{ Access: access, @@ -87,7 +89,7 @@ func (h *AuthHandler) CreateRoutes(r chi.Router) { }) } -func CreateAuthHandler(db *sql.DB) AuthHandler { +func CreateAuthHandler(db *sql.DB, jwtManager *middleware.JwtManager) AuthHandler { userModel := models.NewUserModel(db) logger := log.New(os.Stdout).WithPrefix("Auth") @@ -99,8 +101,9 @@ func CreateAuthHandler(db *sql.DB) AuthHandler { auth := CreateAuth(mailer) return AuthHandler{ - logger, - userModel, - auth, + logger: logger, + user: userModel, + auth: auth, + jwtManager: jwtManager, } } diff --git a/backend/images/handler.go b/backend/images/handler.go index e9e8a1e..befde10 100644 --- a/backend/images/handler.go +++ b/backend/images/handler.go @@ -28,6 +28,8 @@ type ImageHandler struct { limitsManager limits.LimitsManagerMethods processImage func(imageID uuid.UUID) + + jwtManager *middleware.JwtManager } type ImagesReturn struct { @@ -250,12 +252,12 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) { h.logger.Info("Mounting image router") r.Group(func(r chi.Router) { - r.Use(middleware.ProtectedRouteURL) + r.Use(middleware.ProtectedRouteURL(h.jwtManager)) r.Get("/{id}", h.serveImage) }) r.Group(func(r chi.Router) { - r.Use(middleware.ProtectedRoute) + r.Use(middleware.ProtectedRoute(h.jwtManager)) r.Use(middleware.SetJson) r.Get("/", h.listImages) @@ -264,7 +266,7 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) { }) } -func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, processImage func(imageID uuid.UUID)) ImageHandler { +func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, processImage func(imageID uuid.UUID), jwtManager *middleware.JwtManager) ImageHandler { imageModel := models.NewImageModel(db) userModel := models.NewUserModel(db) logger := log.New(os.Stdout).WithPrefix("Images") @@ -275,5 +277,6 @@ func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, p userModel: userModel, limitsManager: limitsManager, processImage: processImage, + jwtManager: jwtManager, } } diff --git a/backend/integration_test.go b/backend/integration_test.go index b7a26f1..948bfee 100644 --- a/backend/integration_test.go +++ b/backend/integration_test.go @@ -57,11 +57,12 @@ type TestUser struct { } type TestContext struct { - db *sql.DB - router chi.Router - server *httptest.Server - users []TestUser - cleanup func() + db *sql.DB + router chi.Router + server *httptest.Server + users []TestUser + cleanup func() + jwtManager *middleware.JwtManager } func setupTestDatabase() (*sql.DB, func(), error) { @@ -179,12 +180,14 @@ func setupTestContext(t *testing.T) *TestContext { t.Fatalf("Failed to setup test database: %v", err) } - router := setupRouter(db) + jwtManager := middleware.NewJwtManager([]byte("test-jwt-secret")) + router := setupRouter(db, jwtManager) server := httptest.NewServer(router) tc.db = db tc.router = router tc.server = server + tc.jwtManager = jwtManager tc.cleanup = func() { server.Close() cleanup() @@ -202,7 +205,7 @@ func (tc *TestContext) createTestUser(email string) TestUser { } // Create access token for the user - accessToken := middleware.CreateAccessToken(userID) + accessToken := tc.jwtManager.CreateAccessToken(userID) user := TestUser{ ID: userID, diff --git a/backend/main.go b/backend/main.go index 10cdff4..193dc8e 100644 --- a/backend/main.go +++ b/backend/main.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "os" + "screenmark/screenmark/middleware" "screenmark/screenmark/models" "github.com/joho/godotenv" @@ -15,12 +16,19 @@ func main() { panic(err) } + jwtSecret := os.Getenv("JWT_SECRET") + if jwtSecret == "" { + panic("JWT_SECRET environment variable not set") + } + + jwtManager := middleware.NewJwtManager([]byte(jwtSecret)) + db, err := models.InitDatabase() if err != nil { panic(err) } - router := setupRouter(db) + router := setupRouter(db, jwtManager) port, exists := os.LookupEnv("PORT") if !exists { diff --git a/backend/middleware/jwt.go b/backend/middleware/jwt.go index f8865a3..c1957d5 100644 --- a/backend/middleware/jwt.go +++ b/backend/middleware/jwt.go @@ -21,10 +21,15 @@ type JwtClaims struct { Expire time.Time } -// obviously this is very not secure. TODO: extract to env -var JWT_SECRET = []byte("very secret") +type JwtManager struct { + secret []byte +} -func createToken(claims JwtClaims) *jwt.Token { +func NewJwtManager(secret []byte) *JwtManager { + return &JwtManager{secret: secret} +} + +func (jm *JwtManager) createToken(claims JwtClaims) *jwt.Token { return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "UserID": claims.UserID, "Type": claims.Type, @@ -32,15 +37,14 @@ func createToken(claims JwtClaims) *jwt.Token { }) } -func CreateRefreshToken(userId uuid.UUID) string { - token := createToken(JwtClaims{ +func (jm *JwtManager) CreateRefreshToken(userId uuid.UUID) string { + token := jm.createToken(JwtClaims{ UserID: userId.String(), Type: Refresh, - Expire: time.Now().Add(time.Hour * 24 * 7), + Expire: time.Now().Add(time.Hour * 24 * 30), }) - // TODO: bruh what is this - tokenString, err := token.SignedString(JWT_SECRET) + tokenString, err := token.SignedString(jm.secret) if err != nil { panic(err) } @@ -48,15 +52,14 @@ func CreateRefreshToken(userId uuid.UUID) string { return tokenString } -func CreateAccessToken(userId uuid.UUID) string { - token := createToken(JwtClaims{ +func (jm *JwtManager) CreateAccessToken(userId uuid.UUID) string { + token := jm.createToken(JwtClaims{ UserID: userId.String(), Type: Access, - Expire: time.Now().Add(time.Hour), + Expire: time.Now().Add(time.Minute), }) - // TODO: bruh what is this - tokenString, err := token.SignedString(JWT_SECRET) + tokenString, err := token.SignedString(jm.secret) if err != nil { panic(err) } @@ -66,18 +69,15 @@ func CreateAccessToken(userId uuid.UUID) string { var NotValidToken = errors.New("Not a valid token") -func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) { +func (jm *JwtManager) GetUserIdFromAccess(accessToken string) (uuid.UUID, error) { token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) { - return JWT_SECRET, nil + return jm.secret, nil }, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) if err != nil { return uuid.Nil, err } - // Blah blah, check expiry and stuff - - // this function is stupid if claims, ok := token.Claims.(jwt.MapClaims); ok { tokenType, ok := claims["Type"] if !ok || tokenType.(string) != "access" { @@ -94,3 +94,7 @@ func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) { return uuid.Nil, NotValidToken } } + +func GetUserIdFromAccess(jm *JwtManager, accessToken string) (uuid.UUID, error) { + return jm.GetUserIdFromAccess(accessToken) +} diff --git a/backend/middleware/middleware.go b/backend/middleware/middleware.go index 135197b..12609a2 100644 --- a/backend/middleware/middleware.go +++ b/backend/middleware/middleware.go @@ -50,65 +50,71 @@ func GetUserID(ctx context.Context, logger *log.Logger, w http.ResponseWriter) ( return userIdUuid, nil } -func ProtectedRouteURL(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.URL.Query().Get("token") +func ProtectedRouteURL(jm *JwtManager) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("token") - userId, err := GetUserIdFromAccess(token) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - return - } + userId, err := GetUserIdFromAccess(jm, token) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } - contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) + contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) - newR := r.WithContext(contextWithUserId) - next.ServeHTTP(w, newR) - }) + newR := r.WithContext(contextWithUserId) + next.ServeHTTP(w, newR) + }) + } } -func ProtectedRoute(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.Header.Get("Authorization") +func ProtectedRoute(jm *JwtManager) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") - if len(token) < len("Bearer ") { - w.WriteHeader(http.StatusUnauthorized) - return - } + if len(token) < len("Bearer ") { + w.WriteHeader(http.StatusUnauthorized) + return + } - userId, err := GetUserIdFromAccess(token[len("Bearer "):]) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - return - } + userId, err := GetUserIdFromAccess(jm, token[len("Bearer "):]) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } - contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) + contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) - newR := r.WithContext(contextWithUserId) - next.ServeHTTP(w, newR) - }) + newR := r.WithContext(contextWithUserId) + next.ServeHTTP(w, newR) + }) + } } -func GetUserIdFromUrl(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.URL.Query().Get("token") +func GetUserIdFromUrl(jm *JwtManager) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("token") - if len(token) == 0 { - w.WriteHeader(http.StatusUnauthorized) - return - } + if len(token) == 0 { + w.WriteHeader(http.StatusUnauthorized) + return + } - userId, err := GetUserIdFromAccess(token) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - return - } + userId, err := GetUserIdFromAccess(jm, token) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } - contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) + contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) - newR := r.WithContext(contextWithUserId) - next.ServeHTTP(w, newR) - }) + newR := r.WithContext(contextWithUserId) + next.ServeHTTP(w, newR) + }) + } } func GetPathParamID(logger *log.Logger, param string, w http.ResponseWriter, r *http.Request) (uuid.UUID, error) { diff --git a/backend/router.go b/backend/router.go index 5dda1f0..faf3ecf 100644 --- a/backend/router.go +++ b/backend/router.go @@ -24,7 +24,7 @@ func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (cli return client.ImageInfo, nil } -func setupRouter(db *sql.DB) chi.Router { +func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) chi.Router { imageModel := models.NewImageModel(db) stackModel := models.NewListModel(db) @@ -33,9 +33,9 @@ func setupRouter(db *sql.DB) chi.Router { processImageLogger := createLogger("Process Image", os.Stdout) processImage := ProcessImage(processImageLogger, db) - stackHandler := stacks.CreateStackHandler(db, limitsManager) - authHandler := auth.CreateAuthHandler(db) - imageHandler := images.CreateImageHandler(db, limitsManager, processImage) + stackHandler := stacks.CreateStackHandler(db, limitsManager, jwtManager) + authHandler := auth.CreateAuthHandler(db, jwtManager) + imageHandler := images.CreateImageHandler(db, limitsManager, processImage, jwtManager) notifier := NewNotifier[Notification](10) @@ -62,7 +62,7 @@ func setupRouter(db *sql.DB) chi.Router { r.Route("/images", imageHandler.CreateRoutes) r.Route("/notifications", func(r chi.Router) { - r.Use(ourmiddleware.GetUserIdFromUrl) + r.Use(ourmiddleware.GetUserIdFromUrl(jwtManager)) r.Get("/", CreateEventsHandler(¬ifier)) }) diff --git a/backend/stacks/handler.go b/backend/stacks/handler.go index 5bcb41d..850b7de 100644 --- a/backend/stacks/handler.go +++ b/backend/stacks/handler.go @@ -23,6 +23,8 @@ type StackHandler struct { stackModel models.ListModel limitsManager limits.LimitsManagerMethods + + jwtManager *middleware.JwtManager } func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { @@ -185,7 +187,7 @@ func (h *StackHandler) CreateRoutes(r chi.Router) { h.logger.Info("Mounting stack router") r.Group(func(r chi.Router) { - r.Use(middleware.ProtectedRoute) + r.Use(middleware.ProtectedRoute(h.jwtManager)) r.Use(middleware.SetJson) r.Get("/", h.getAllStacks) @@ -198,15 +200,16 @@ func (h *StackHandler) CreateRoutes(r chi.Router) { }) } -func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) StackHandler { +func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, jwtManager *middleware.JwtManager) StackHandler { stackModel := models.NewListModel(db) imageModel := models.NewImageModel(db) logger := log.New(os.Stdout).WithPrefix("Stacks") return StackHandler{ - logger, - imageModel, - stackModel, - limitsManager, + logger: logger, + imageModel: imageModel, + stackModel: stackModel, + limitsManager: limitsManager, + jwtManager: jwtManager, } }