AI: refactoring JWT tokens to jwt token manager

This commit is contained in:
2025-09-21 14:42:06 +01:00
parent e28d9e5d16
commit f078ac7d0b
8 changed files with 120 additions and 90 deletions

View File

@ -18,6 +18,8 @@ type AuthHandler struct {
user models.UserModel
auth Auth
jwtManager *middleware.JwtManager
}
type loginBody struct {
@ -65,8 +67,8 @@ func (h *AuthHandler) code(body codeBody, w http.ResponseWriter, r *http.Request
return
}
refresh := middleware.CreateRefreshToken(uuid)
access := middleware.CreateAccessToken(uuid)
refresh := h.jwtManager.CreateRefreshToken(uuid)
access := h.jwtManager.CreateAccessToken(uuid)
codeReturn := codeReturn{
Access: access,
@ -87,7 +89,7 @@ func (h *AuthHandler) CreateRoutes(r chi.Router) {
})
}
func CreateAuthHandler(db *sql.DB) AuthHandler {
func CreateAuthHandler(db *sql.DB, jwtManager *middleware.JwtManager) AuthHandler {
userModel := models.NewUserModel(db)
logger := log.New(os.Stdout).WithPrefix("Auth")
@ -99,8 +101,9 @@ func CreateAuthHandler(db *sql.DB) AuthHandler {
auth := CreateAuth(mailer)
return AuthHandler{
logger,
userModel,
auth,
logger: logger,
user: userModel,
auth: auth,
jwtManager: jwtManager,
}
}

View File

@ -28,6 +28,8 @@ type ImageHandler struct {
limitsManager limits.LimitsManagerMethods
processImage func(imageID uuid.UUID)
jwtManager *middleware.JwtManager
}
type ImagesReturn struct {
@ -250,12 +252,12 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) {
h.logger.Info("Mounting image router")
r.Group(func(r chi.Router) {
r.Use(middleware.ProtectedRouteURL)
r.Use(middleware.ProtectedRouteURL(h.jwtManager))
r.Get("/{id}", h.serveImage)
})
r.Group(func(r chi.Router) {
r.Use(middleware.ProtectedRoute)
r.Use(middleware.ProtectedRoute(h.jwtManager))
r.Use(middleware.SetJson)
r.Get("/", h.listImages)
@ -264,7 +266,7 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) {
})
}
func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, processImage func(imageID uuid.UUID)) ImageHandler {
func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, processImage func(imageID uuid.UUID), jwtManager *middleware.JwtManager) ImageHandler {
imageModel := models.NewImageModel(db)
userModel := models.NewUserModel(db)
logger := log.New(os.Stdout).WithPrefix("Images")
@ -275,5 +277,6 @@ func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, p
userModel: userModel,
limitsManager: limitsManager,
processImage: processImage,
jwtManager: jwtManager,
}
}

View File

@ -62,6 +62,7 @@ type TestContext struct {
server *httptest.Server
users []TestUser
cleanup func()
jwtManager *middleware.JwtManager
}
func setupTestDatabase() (*sql.DB, func(), error) {
@ -179,12 +180,14 @@ func setupTestContext(t *testing.T) *TestContext {
t.Fatalf("Failed to setup test database: %v", err)
}
router := setupRouter(db)
jwtManager := middleware.NewJwtManager([]byte("test-jwt-secret"))
router := setupRouter(db, jwtManager)
server := httptest.NewServer(router)
tc.db = db
tc.router = router
tc.server = server
tc.jwtManager = jwtManager
tc.cleanup = func() {
server.Close()
cleanup()
@ -202,7 +205,7 @@ func (tc *TestContext) createTestUser(email string) TestUser {
}
// Create access token for the user
accessToken := middleware.CreateAccessToken(userID)
accessToken := tc.jwtManager.CreateAccessToken(userID)
user := TestUser{
ID: userID,

View File

@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"os"
"screenmark/screenmark/middleware"
"screenmark/screenmark/models"
"github.com/joho/godotenv"
@ -15,12 +16,19 @@ func main() {
panic(err)
}
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
panic("JWT_SECRET environment variable not set")
}
jwtManager := middleware.NewJwtManager([]byte(jwtSecret))
db, err := models.InitDatabase()
if err != nil {
panic(err)
}
router := setupRouter(db)
router := setupRouter(db, jwtManager)
port, exists := os.LookupEnv("PORT")
if !exists {

View File

@ -21,10 +21,15 @@ type JwtClaims struct {
Expire time.Time
}
// obviously this is very not secure. TODO: extract to env
var JWT_SECRET = []byte("very secret")
type JwtManager struct {
secret []byte
}
func createToken(claims JwtClaims) *jwt.Token {
func NewJwtManager(secret []byte) *JwtManager {
return &JwtManager{secret: secret}
}
func (jm *JwtManager) createToken(claims JwtClaims) *jwt.Token {
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"UserID": claims.UserID,
"Type": claims.Type,
@ -32,15 +37,14 @@ func createToken(claims JwtClaims) *jwt.Token {
})
}
func CreateRefreshToken(userId uuid.UUID) string {
token := createToken(JwtClaims{
func (jm *JwtManager) CreateRefreshToken(userId uuid.UUID) string {
token := jm.createToken(JwtClaims{
UserID: userId.String(),
Type: Refresh,
Expire: time.Now().Add(time.Hour * 24 * 7),
Expire: time.Now().Add(time.Hour * 24 * 30),
})
// TODO: bruh what is this
tokenString, err := token.SignedString(JWT_SECRET)
tokenString, err := token.SignedString(jm.secret)
if err != nil {
panic(err)
}
@ -48,15 +52,14 @@ func CreateRefreshToken(userId uuid.UUID) string {
return tokenString
}
func CreateAccessToken(userId uuid.UUID) string {
token := createToken(JwtClaims{
func (jm *JwtManager) CreateAccessToken(userId uuid.UUID) string {
token := jm.createToken(JwtClaims{
UserID: userId.String(),
Type: Access,
Expire: time.Now().Add(time.Hour),
Expire: time.Now().Add(time.Minute),
})
// TODO: bruh what is this
tokenString, err := token.SignedString(JWT_SECRET)
tokenString, err := token.SignedString(jm.secret)
if err != nil {
panic(err)
}
@ -66,18 +69,15 @@ func CreateAccessToken(userId uuid.UUID) string {
var NotValidToken = errors.New("Not a valid token")
func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
func (jm *JwtManager) GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
return JWT_SECRET, nil
return jm.secret, nil
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
if err != nil {
return uuid.Nil, err
}
// Blah blah, check expiry and stuff
// this function is stupid
if claims, ok := token.Claims.(jwt.MapClaims); ok {
tokenType, ok := claims["Type"]
if !ok || tokenType.(string) != "access" {
@ -94,3 +94,7 @@ func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
return uuid.Nil, NotValidToken
}
}
func GetUserIdFromAccess(jm *JwtManager, accessToken string) (uuid.UUID, error) {
return jm.GetUserIdFromAccess(accessToken)
}

View File

@ -50,11 +50,12 @@ func GetUserID(ctx context.Context, logger *log.Logger, w http.ResponseWriter) (
return userIdUuid, nil
}
func ProtectedRouteURL(next http.Handler) http.Handler {
func ProtectedRouteURL(jm *JwtManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
userId, err := GetUserIdFromAccess(token)
userId, err := GetUserIdFromAccess(jm, token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
@ -66,8 +67,10 @@ func ProtectedRouteURL(next http.Handler) http.Handler {
next.ServeHTTP(w, newR)
})
}
}
func ProtectedRoute(next http.Handler) http.Handler {
func ProtectedRoute(jm *JwtManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
@ -76,7 +79,7 @@ func ProtectedRoute(next http.Handler) http.Handler {
return
}
userId, err := GetUserIdFromAccess(token[len("Bearer "):])
userId, err := GetUserIdFromAccess(jm, token[len("Bearer "):])
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
@ -88,8 +91,10 @@ func ProtectedRoute(next http.Handler) http.Handler {
next.ServeHTTP(w, newR)
})
}
}
func GetUserIdFromUrl(next http.Handler) http.Handler {
func GetUserIdFromUrl(jm *JwtManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
@ -98,7 +103,7 @@ func GetUserIdFromUrl(next http.Handler) http.Handler {
return
}
userId, err := GetUserIdFromAccess(token)
userId, err := GetUserIdFromAccess(jm, token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
@ -110,6 +115,7 @@ func GetUserIdFromUrl(next http.Handler) http.Handler {
next.ServeHTTP(w, newR)
})
}
}
func GetPathParamID(logger *log.Logger, param string, w http.ResponseWriter, r *http.Request) (uuid.UUID, error) {
pathParam := r.PathValue(param)

View File

@ -24,7 +24,7 @@ func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (cli
return client.ImageInfo, nil
}
func setupRouter(db *sql.DB) chi.Router {
func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) chi.Router {
imageModel := models.NewImageModel(db)
stackModel := models.NewListModel(db)
@ -33,9 +33,9 @@ func setupRouter(db *sql.DB) chi.Router {
processImageLogger := createLogger("Process Image", os.Stdout)
processImage := ProcessImage(processImageLogger, db)
stackHandler := stacks.CreateStackHandler(db, limitsManager)
authHandler := auth.CreateAuthHandler(db)
imageHandler := images.CreateImageHandler(db, limitsManager, processImage)
stackHandler := stacks.CreateStackHandler(db, limitsManager, jwtManager)
authHandler := auth.CreateAuthHandler(db, jwtManager)
imageHandler := images.CreateImageHandler(db, limitsManager, processImage, jwtManager)
notifier := NewNotifier[Notification](10)
@ -62,7 +62,7 @@ func setupRouter(db *sql.DB) chi.Router {
r.Route("/images", imageHandler.CreateRoutes)
r.Route("/notifications", func(r chi.Router) {
r.Use(ourmiddleware.GetUserIdFromUrl)
r.Use(ourmiddleware.GetUserIdFromUrl(jwtManager))
r.Get("/", CreateEventsHandler(&notifier))
})

View File

@ -23,6 +23,8 @@ type StackHandler struct {
stackModel models.ListModel
limitsManager limits.LimitsManagerMethods
jwtManager *middleware.JwtManager
}
func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
@ -185,7 +187,7 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
h.logger.Info("Mounting stack router")
r.Group(func(r chi.Router) {
r.Use(middleware.ProtectedRoute)
r.Use(middleware.ProtectedRoute(h.jwtManager))
r.Use(middleware.SetJson)
r.Get("/", h.getAllStacks)
@ -198,15 +200,16 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
})
}
func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) StackHandler {
func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, jwtManager *middleware.JwtManager) StackHandler {
stackModel := models.NewListModel(db)
imageModel := models.NewImageModel(db)
logger := log.New(os.Stdout).WithPrefix("Stacks")
return StackHandler{
logger,
imageModel,
stackModel,
limitsManager,
logger: logger,
imageModel: imageModel,
stackModel: stackModel,
limitsManager: limitsManager,
jwtManager: jwtManager,
}
}