2 Commits

Author SHA1 Message Date
a3345afbfa AI: checking the actual expire 2025-09-21 14:51:02 +01:00
f078ac7d0b AI: refactoring JWT tokens to jwt token manager 2025-09-21 14:42:06 +01:00
8 changed files with 149 additions and 89 deletions

View File

@ -18,6 +18,8 @@ type AuthHandler struct {
user models.UserModel user models.UserModel
auth Auth auth Auth
jwtManager *middleware.JwtManager
} }
type loginBody struct { type loginBody struct {
@ -65,8 +67,8 @@ func (h *AuthHandler) code(body codeBody, w http.ResponseWriter, r *http.Request
return return
} }
refresh := middleware.CreateRefreshToken(uuid) refresh := h.jwtManager.CreateRefreshToken(uuid)
access := middleware.CreateAccessToken(uuid) access := h.jwtManager.CreateAccessToken(uuid)
codeReturn := codeReturn{ codeReturn := codeReturn{
Access: access, Access: access,
@ -87,7 +89,7 @@ func (h *AuthHandler) CreateRoutes(r chi.Router) {
}) })
} }
func CreateAuthHandler(db *sql.DB) AuthHandler { func CreateAuthHandler(db *sql.DB, jwtManager *middleware.JwtManager) AuthHandler {
userModel := models.NewUserModel(db) userModel := models.NewUserModel(db)
logger := log.New(os.Stdout).WithPrefix("Auth") logger := log.New(os.Stdout).WithPrefix("Auth")
@ -99,8 +101,9 @@ func CreateAuthHandler(db *sql.DB) AuthHandler {
auth := CreateAuth(mailer) auth := CreateAuth(mailer)
return AuthHandler{ return AuthHandler{
logger, logger: logger,
userModel, user: userModel,
auth, auth: auth,
jwtManager: jwtManager,
} }
} }

View File

@ -28,6 +28,8 @@ type ImageHandler struct {
limitsManager limits.LimitsManagerMethods limitsManager limits.LimitsManagerMethods
processImage func(imageID uuid.UUID) processImage func(imageID uuid.UUID)
jwtManager *middleware.JwtManager
} }
type ImagesReturn struct { type ImagesReturn struct {
@ -250,12 +252,12 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) {
h.logger.Info("Mounting image router") h.logger.Info("Mounting image router")
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(middleware.ProtectedRouteURL) r.Use(middleware.ProtectedRouteURL(h.jwtManager))
r.Get("/{id}", h.serveImage) r.Get("/{id}", h.serveImage)
}) })
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(middleware.ProtectedRoute) r.Use(middleware.ProtectedRoute(h.jwtManager))
r.Use(middleware.SetJson) r.Use(middleware.SetJson)
r.Get("/", h.listImages) r.Get("/", h.listImages)
@ -264,7 +266,7 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) {
}) })
} }
func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, processImage func(imageID uuid.UUID)) ImageHandler { func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, processImage func(imageID uuid.UUID), jwtManager *middleware.JwtManager) ImageHandler {
imageModel := models.NewImageModel(db) imageModel := models.NewImageModel(db)
userModel := models.NewUserModel(db) userModel := models.NewUserModel(db)
logger := log.New(os.Stdout).WithPrefix("Images") logger := log.New(os.Stdout).WithPrefix("Images")
@ -275,5 +277,6 @@ func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, p
userModel: userModel, userModel: userModel,
limitsManager: limitsManager, limitsManager: limitsManager,
processImage: processImage, processImage: processImage,
jwtManager: jwtManager,
} }
} }

View File

@ -57,11 +57,12 @@ type TestUser struct {
} }
type TestContext struct { type TestContext struct {
db *sql.DB db *sql.DB
router chi.Router router chi.Router
server *httptest.Server server *httptest.Server
users []TestUser users []TestUser
cleanup func() cleanup func()
jwtManager *middleware.JwtManager
} }
func setupTestDatabase() (*sql.DB, func(), error) { func setupTestDatabase() (*sql.DB, func(), error) {
@ -179,12 +180,14 @@ func setupTestContext(t *testing.T) *TestContext {
t.Fatalf("Failed to setup test database: %v", err) t.Fatalf("Failed to setup test database: %v", err)
} }
router := setupRouter(db) jwtManager := middleware.NewJwtManager([]byte("test-jwt-secret"))
router := setupRouter(db, jwtManager)
server := httptest.NewServer(router) server := httptest.NewServer(router)
tc.db = db tc.db = db
tc.router = router tc.router = router
tc.server = server tc.server = server
tc.jwtManager = jwtManager
tc.cleanup = func() { tc.cleanup = func() {
server.Close() server.Close()
cleanup() cleanup()
@ -202,7 +205,7 @@ func (tc *TestContext) createTestUser(email string) TestUser {
} }
// Create access token for the user // Create access token for the user
accessToken := middleware.CreateAccessToken(userID) accessToken := tc.jwtManager.CreateAccessToken(userID)
user := TestUser{ user := TestUser{
ID: userID, ID: userID,

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"screenmark/screenmark/middleware"
"screenmark/screenmark/models" "screenmark/screenmark/models"
"github.com/joho/godotenv" "github.com/joho/godotenv"
@ -15,12 +16,19 @@ func main() {
panic(err) panic(err)
} }
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
panic("JWT_SECRET environment variable not set")
}
jwtManager := middleware.NewJwtManager([]byte(jwtSecret))
db, err := models.InitDatabase() db, err := models.InitDatabase()
if err != nil { if err != nil {
panic(err) panic(err)
} }
router := setupRouter(db) router := setupRouter(db, jwtManager)
port, exists := os.LookupEnv("PORT") port, exists := os.LookupEnv("PORT")
if !exists { if !exists {

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"encoding/json"
"errors" "errors"
"time" "time"
@ -21,10 +22,15 @@ type JwtClaims struct {
Expire time.Time Expire time.Time
} }
// obviously this is very not secure. TODO: extract to env type JwtManager struct {
var JWT_SECRET = []byte("very secret") secret []byte
}
func createToken(claims JwtClaims) *jwt.Token { func NewJwtManager(secret []byte) *JwtManager {
return &JwtManager{secret: secret}
}
func (jm *JwtManager) createToken(claims JwtClaims) *jwt.Token {
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"UserID": claims.UserID, "UserID": claims.UserID,
"Type": claims.Type, "Type": claims.Type,
@ -32,15 +38,14 @@ func createToken(claims JwtClaims) *jwt.Token {
}) })
} }
func CreateRefreshToken(userId uuid.UUID) string { func (jm *JwtManager) CreateRefreshToken(userId uuid.UUID) string {
token := createToken(JwtClaims{ token := jm.createToken(JwtClaims{
UserID: userId.String(), UserID: userId.String(),
Type: Refresh, Type: Refresh,
Expire: time.Now().Add(time.Hour * 24 * 7), Expire: time.Now().Add(time.Hour * 24 * 30),
}) })
// TODO: bruh what is this tokenString, err := token.SignedString(jm.secret)
tokenString, err := token.SignedString(JWT_SECRET)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -48,15 +53,14 @@ func CreateRefreshToken(userId uuid.UUID) string {
return tokenString return tokenString
} }
func CreateAccessToken(userId uuid.UUID) string { func (jm *JwtManager) CreateAccessToken(userId uuid.UUID) string {
token := createToken(JwtClaims{ token := jm.createToken(JwtClaims{
UserID: userId.String(), UserID: userId.String(),
Type: Access, Type: Access,
Expire: time.Now().Add(time.Hour), Expire: time.Now().Add(time.Minute),
}) })
// TODO: bruh what is this tokenString, err := token.SignedString(jm.secret)
tokenString, err := token.SignedString(JWT_SECRET)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -66,24 +70,50 @@ func CreateAccessToken(userId uuid.UUID) string {
var NotValidToken = errors.New("Not a valid token") var NotValidToken = errors.New("Not a valid token")
func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) { func (jm *JwtManager) GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) { token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
return JWT_SECRET, nil return jm.secret, nil
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) }, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
if err != nil { if err != nil {
return uuid.Nil, err return uuid.Nil, err
} }
// Blah blah, check expiry and stuff // Check if token is valid (including expiry check)
if !token.Valid {
return uuid.Nil, NotValidToken
}
// this function is stupid
if claims, ok := token.Claims.(jwt.MapClaims); ok { if claims, ok := token.Claims.(jwt.MapClaims); ok {
tokenType, ok := claims["Type"] tokenType, ok := claims["Type"]
if !ok || tokenType.(string) != "access" { if !ok || tokenType.(string) != "access" {
return uuid.Nil, NotValidToken return uuid.Nil, NotValidToken
} }
// Additional explicit expiry check
expireClaim, ok := claims["Expire"]
if !ok {
return uuid.Nil, NotValidToken
}
var expireTime time.Time
switch exp := expireClaim.(type) {
case float64:
expireTime = time.Unix(int64(exp), 0)
case json.Number:
expInt, err := exp.Int64()
if err != nil {
return uuid.Nil, NotValidToken
}
expireTime = time.Unix(expInt, 0)
default:
return uuid.Nil, NotValidToken
}
if time.Now().After(expireTime) {
return uuid.Nil, NotValidToken
}
userId, err := uuid.Parse(claims["UserID"].(string)) userId, err := uuid.Parse(claims["UserID"].(string))
if err != nil { if err != nil {
return uuid.Nil, NotValidToken return uuid.Nil, NotValidToken
@ -94,3 +124,7 @@ func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
return uuid.Nil, NotValidToken return uuid.Nil, NotValidToken
} }
} }
func GetUserIdFromAccess(jm *JwtManager, accessToken string) (uuid.UUID, error) {
return jm.GetUserIdFromAccess(accessToken)
}

View File

@ -50,65 +50,71 @@ func GetUserID(ctx context.Context, logger *log.Logger, w http.ResponseWriter) (
return userIdUuid, nil return userIdUuid, nil
} }
func ProtectedRouteURL(next http.Handler) http.Handler { func ProtectedRouteURL(jm *JwtManager) func(http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return func(next http.Handler) http.Handler {
token := r.URL.Query().Get("token") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
userId, err := GetUserIdFromAccess(token) userId, err := GetUserIdFromAccess(jm, token)
if err != nil { if err != nil {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId) newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR) next.ServeHTTP(w, newR)
}) })
}
} }
func ProtectedRoute(next http.Handler) http.Handler { func ProtectedRoute(jm *JwtManager) func(http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return func(next http.Handler) http.Handler {
token := r.Header.Get("Authorization") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if len(token) < len("Bearer ") { if len(token) < len("Bearer ") {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
userId, err := GetUserIdFromAccess(token[len("Bearer "):]) userId, err := GetUserIdFromAccess(jm, token[len("Bearer "):])
if err != nil { if err != nil {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId) newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR) next.ServeHTTP(w, newR)
}) })
}
} }
func GetUserIdFromUrl(next http.Handler) http.Handler { func GetUserIdFromUrl(jm *JwtManager) func(http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return func(next http.Handler) http.Handler {
token := r.URL.Query().Get("token") return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if len(token) == 0 { if len(token) == 0 {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
userId, err := GetUserIdFromAccess(token) userId, err := GetUserIdFromAccess(jm, token)
if err != nil { if err != nil {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId) contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId) newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR) next.ServeHTTP(w, newR)
}) })
}
} }
func GetPathParamID(logger *log.Logger, param string, w http.ResponseWriter, r *http.Request) (uuid.UUID, error) { func GetPathParamID(logger *log.Logger, param string, w http.ResponseWriter, r *http.Request) (uuid.UUID, error) {

View File

@ -24,7 +24,7 @@ func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (cli
return client.ImageInfo, nil return client.ImageInfo, nil
} }
func setupRouter(db *sql.DB) chi.Router { func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) chi.Router {
imageModel := models.NewImageModel(db) imageModel := models.NewImageModel(db)
stackModel := models.NewListModel(db) stackModel := models.NewListModel(db)
@ -33,9 +33,9 @@ func setupRouter(db *sql.DB) chi.Router {
processImageLogger := createLogger("Process Image", os.Stdout) processImageLogger := createLogger("Process Image", os.Stdout)
processImage := ProcessImage(processImageLogger, db) processImage := ProcessImage(processImageLogger, db)
stackHandler := stacks.CreateStackHandler(db, limitsManager) stackHandler := stacks.CreateStackHandler(db, limitsManager, jwtManager)
authHandler := auth.CreateAuthHandler(db) authHandler := auth.CreateAuthHandler(db, jwtManager)
imageHandler := images.CreateImageHandler(db, limitsManager, processImage) imageHandler := images.CreateImageHandler(db, limitsManager, processImage, jwtManager)
notifier := NewNotifier[Notification](10) notifier := NewNotifier[Notification](10)
@ -62,7 +62,7 @@ func setupRouter(db *sql.DB) chi.Router {
r.Route("/images", imageHandler.CreateRoutes) r.Route("/images", imageHandler.CreateRoutes)
r.Route("/notifications", func(r chi.Router) { r.Route("/notifications", func(r chi.Router) {
r.Use(ourmiddleware.GetUserIdFromUrl) r.Use(ourmiddleware.GetUserIdFromUrl(jwtManager))
r.Get("/", CreateEventsHandler(&notifier)) r.Get("/", CreateEventsHandler(&notifier))
}) })

View File

@ -23,6 +23,8 @@ type StackHandler struct {
stackModel models.ListModel stackModel models.ListModel
limitsManager limits.LimitsManagerMethods limitsManager limits.LimitsManagerMethods
jwtManager *middleware.JwtManager
} }
func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
@ -185,7 +187,7 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
h.logger.Info("Mounting stack router") h.logger.Info("Mounting stack router")
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(middleware.ProtectedRoute) r.Use(middleware.ProtectedRoute(h.jwtManager))
r.Use(middleware.SetJson) r.Use(middleware.SetJson)
r.Get("/", h.getAllStacks) r.Get("/", h.getAllStacks)
@ -198,15 +200,16 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
}) })
} }
func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) StackHandler { func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, jwtManager *middleware.JwtManager) StackHandler {
stackModel := models.NewListModel(db) stackModel := models.NewListModel(db)
imageModel := models.NewImageModel(db) imageModel := models.NewImageModel(db)
logger := log.New(os.Stdout).WithPrefix("Stacks") logger := log.New(os.Stdout).WithPrefix("Stacks")
return StackHandler{ return StackHandler{
logger, logger: logger,
imageModel, imageModel: imageModel,
stackModel, stackModel: stackModel,
limitsManager, limitsManager: limitsManager,
jwtManager: jwtManager,
} }
} }