131 lines
2.6 KiB
Go
131 lines
2.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// JwtType labels a token with the purpose it was issued for. It is written
// into the token's "Type" claim by createToken.
type JwtType string

// Token purposes understood by JwtManager.
const (
	// Access marks tokens produced by CreateAccessToken; it is the only type
	// GetUserIdFromAccess accepts.
	Access = JwtType("access")
	// Refresh marks tokens produced by CreateRefreshToken.
	Refresh = JwtType("refresh")
)
|
|
|
|
// JwtClaims is the application-level payload that createToken encodes into a
// token's claim set.
type JwtClaims struct {
	// UserID is the user identifier in string form; GetUserIdFromAccess
	// parses it back with uuid.Parse.
	UserID string
	// Type records what the token may be used for (Access or Refresh).
	Type JwtType
	// Expire is the instant after which GetUserIdFromAccess rejects the token.
	Expire time.Time
}
|
|
|
|
// JwtManager issues and verifies HS256-signed JWTs using a shared secret.
type JwtManager struct {
	// secret is the HMAC key used both to sign and to verify tokens.
	secret []byte
}

// NewJwtManager returns a JwtManager that signs and verifies HS256 tokens
// with secret.
//
// The secret is copied, so later writes to the caller's slice (for example,
// zeroing or reusing a buffer) cannot silently change the key of an existing
// manager.
func NewJwtManager(secret []byte) *JwtManager {
	return &JwtManager{secret: append([]byte(nil), secret...)}
}
|
|
|
|
func (jm *JwtManager) createToken(claims JwtClaims) *jwt.Token {
|
|
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"UserID": claims.UserID,
|
|
"Type": claims.Type,
|
|
"Expire": claims.Expire,
|
|
})
|
|
}
|
|
|
|
func (jm *JwtManager) CreateRefreshToken(userId uuid.UUID) string {
|
|
token := jm.createToken(JwtClaims{
|
|
UserID: userId.String(),
|
|
Type: Refresh,
|
|
Expire: time.Now().Add(time.Hour * 24 * 30),
|
|
})
|
|
|
|
tokenString, err := token.SignedString(jm.secret)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return tokenString
|
|
}
|
|
|
|
func (jm *JwtManager) CreateAccessToken(userId uuid.UUID) string {
|
|
token := jm.createToken(JwtClaims{
|
|
UserID: userId.String(),
|
|
Type: Access,
|
|
Expire: time.Now().Add(time.Minute),
|
|
})
|
|
|
|
tokenString, err := token.SignedString(jm.secret)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return tokenString
|
|
}
|
|
|
|
// NotValidToken is returned by GetUserIdFromAccess when a token that parsed
// and verified is still unacceptable: not marked valid, wrong type, missing or
// past expiry, or a user ID that is not a UUID.
//
// Its name and message do not follow Go's ErrXxx / lowercase-message
// conventions; they are left unchanged because callers may match on them.
var NotValidToken error = errors.New("Not a valid token")
|
|
|
|
func (jm *JwtManager) GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
|
|
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
|
|
return jm.secret, nil
|
|
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
|
|
|
|
if err != nil {
|
|
return uuid.Nil, err
|
|
}
|
|
|
|
// Check if token is valid (including expiry check)
|
|
if !token.Valid {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
|
tokenType, ok := claims["Type"]
|
|
if !ok || tokenType.(string) != "access" {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
// Additional explicit expiry check
|
|
expireClaim, ok := claims["Expire"]
|
|
if !ok {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
var expireTime time.Time
|
|
switch exp := expireClaim.(type) {
|
|
case float64:
|
|
expireTime = time.Unix(int64(exp), 0)
|
|
case json.Number:
|
|
expInt, err := exp.Int64()
|
|
if err != nil {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
expireTime = time.Unix(expInt, 0)
|
|
default:
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
if time.Now().After(expireTime) {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
userId, err := uuid.Parse(claims["UserID"].(string))
|
|
if err != nil {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
return userId, nil
|
|
} else {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
}
|
|
|
|
func GetUserIdFromAccess(jm *JwtManager, accessToken string) (uuid.UUID, error) {
|
|
return jm.GetUserIdFromAccess(accessToken)
|
|
}
|