AI: refactoring JWT tokens to jwt token manager

This commit is contained in:
2025-09-21 14:42:06 +01:00
parent e28d9e5d16
commit f078ac7d0b
8 changed files with 120 additions and 90 deletions

View File

@@ -21,10 +21,15 @@ type JwtClaims struct {
Expire time.Time
}
// obviously this is very not secure. TODO: extract to env
var JWT_SECRET = []byte("very secret")
type JwtManager struct {
secret []byte
}
func createToken(claims JwtClaims) *jwt.Token {
func NewJwtManager(secret []byte) *JwtManager {
return &JwtManager{secret: secret}
}
func (jm *JwtManager) createToken(claims JwtClaims) *jwt.Token {
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"UserID": claims.UserID,
"Type": claims.Type,
@@ -32,15 +37,14 @@ func createToken(claims JwtClaims) *jwt.Token {
})
}
func CreateRefreshToken(userId uuid.UUID) string {
token := createToken(JwtClaims{
func (jm *JwtManager) CreateRefreshToken(userId uuid.UUID) string {
token := jm.createToken(JwtClaims{
UserID: userId.String(),
Type: Refresh,
Expire: time.Now().Add(time.Hour * 24 * 7),
Expire: time.Now().Add(time.Hour * 24 * 30),
})
// TODO: bruh what is this
tokenString, err := token.SignedString(JWT_SECRET)
tokenString, err := token.SignedString(jm.secret)
if err != nil {
panic(err)
}
@@ -48,15 +52,14 @@ func CreateRefreshToken(userId uuid.UUID) string {
return tokenString
}
func CreateAccessToken(userId uuid.UUID) string {
token := createToken(JwtClaims{
func (jm *JwtManager) CreateAccessToken(userId uuid.UUID) string {
token := jm.createToken(JwtClaims{
UserID: userId.String(),
Type: Access,
Expire: time.Now().Add(time.Hour),
Expire: time.Now().Add(time.Minute),
})
// TODO: bruh what is this
tokenString, err := token.SignedString(JWT_SECRET)
tokenString, err := token.SignedString(jm.secret)
if err != nil {
panic(err)
}
@@ -66,18 +69,15 @@ func CreateAccessToken(userId uuid.UUID) string {
var NotValidToken = errors.New("Not a valid token")
func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
func (jm *JwtManager) GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
return JWT_SECRET, nil
return jm.secret, nil
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
if err != nil {
return uuid.Nil, err
}
// Blah blah, check expiry and stuff
// this function is stupid
if claims, ok := token.Claims.(jwt.MapClaims); ok {
tokenType, ok := claims["Type"]
if !ok || tokenType.(string) != "access" {
@@ -94,3 +94,7 @@ func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
return uuid.Nil, NotValidToken
}
}
func GetUserIdFromAccess(jm *JwtManager, accessToken string) (uuid.UUID, error) {
return jm.GetUserIdFromAccess(accessToken)
}

View File

@@ -50,65 +50,71 @@ func GetUserID(ctx context.Context, logger *log.Logger, w http.ResponseWriter) (
return userIdUuid, nil
}
func ProtectedRouteURL(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
func ProtectedRouteURL(jm *JwtManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
userId, err := GetUserIdFromAccess(token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(jm, token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
}
}
func ProtectedRoute(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
func ProtectedRoute(jm *JwtManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Authorization")
if len(token) < len("Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
if len(token) < len("Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(token[len("Bearer "):])
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(jm, token[len("Bearer "):])
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
}
}
func GetUserIdFromUrl(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
func GetUserIdFromUrl(jm *JwtManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if len(token) == 0 {
w.WriteHeader(http.StatusUnauthorized)
return
}
if len(token) == 0 {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
userId, err := GetUserIdFromAccess(jm, token)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
newR := r.WithContext(contextWithUserId)
next.ServeHTTP(w, newR)
})
}
}
func GetPathParamID(logger *log.Logger, param string, w http.ResponseWriter, r *http.Request) (uuid.UUID, error) {