137 lines
2.9 KiB
Go
137 lines
2.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type (
	// JwtType identifies what a token issued by JwtManager may be used for.
	// It is written to, and checked against, the token's "Type" claim.
	JwtType string

	// JwtClaims is the set of values JwtManager encodes into a token.
	JwtClaims struct {
		UserID string    // the user's UUID in string form
		Type   JwtType   // Access or Refresh
		Expiry time.Time // encoded as the "exp" claim (Unix seconds)
	}
)

// Token kinds recognised by JwtManager.
const (
	// Access marks a short-lived token (CreateAccessToken issues it with a
	// one-minute lifetime).
	Access JwtType = "access"
	// Refresh marks a long-lived token (CreateRefreshToken issues it with a
	// 30-day lifetime).
	Refresh JwtType = "refresh"
)
|
|
|
|
// JwtManager issues and verifies HS256-signed JWTs using a single shared
// secret.
type JwtManager struct {
	secret []byte
}

// NewJwtManager returns a JwtManager that signs and verifies tokens with
// secret.
//
// The secret is copied, so later changes to the caller's slice (for example
// zeroing it after use) do not alter the key the manager uses. The secret is
// not validated here; callers should supply a high-entropy key.
func NewJwtManager(secret []byte) *JwtManager {
	// Defensive copy: a retained alias would let external writes silently
	// change the signing/verification key.
	return &JwtManager{secret: append([]byte(nil), secret...)}
}
|
|
|
|
func (jm *JwtManager) createToken(claims JwtClaims) *jwt.Token {
|
|
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
"UserID": claims.UserID,
|
|
"Type": claims.Type,
|
|
"exp": claims.Expiry.Unix(),
|
|
})
|
|
}
|
|
|
|
func (jm *JwtManager) CreateRefreshToken(userId uuid.UUID) string {
|
|
token := jm.createToken(JwtClaims{
|
|
UserID: userId.String(),
|
|
Type: Refresh,
|
|
Expiry: time.Now().Add(time.Hour * 24 * 30),
|
|
})
|
|
|
|
tokenString, err := token.SignedString(jm.secret)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return tokenString
|
|
}
|
|
|
|
func (jm *JwtManager) CreateAccessToken(userId uuid.UUID) string {
|
|
token := jm.createToken(JwtClaims{
|
|
UserID: userId.String(),
|
|
Type: Access,
|
|
Expiry: time.Now().Add(time.Minute),
|
|
})
|
|
|
|
tokenString, err := token.SignedString(jm.secret)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return tokenString
|
|
}
|
|
|
|
var (
	// ErrNotValidToken is returned by the JwtManager token-reading methods
	// when a token is rejected for a reason other than an error reported by
	// the jwt parser. Examples: it is not of the expected kind, or it lacks a
	// well-formed user ID. The message text is kept unchanged for
	// compatibility with existing callers.
	ErrNotValidToken = errors.New("Not a valid token")

	// NotValidToken is the original name of ErrNotValidToken, kept so
	// existing callers continue to compile. Both names hold the same error
	// value, so == and errors.Is treat them as equal.
	//
	// Deprecated: use ErrNotValidToken.
	NotValidToken = ErrNotValidToken
)
|
|
|
|
func (jm *JwtManager) GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
|
|
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
|
|
return jm.secret, nil
|
|
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
|
|
|
|
if err != nil {
|
|
return uuid.Nil, err
|
|
}
|
|
|
|
// Check if token is valid (JWT library validates exp claim automatically)
|
|
if !token.Valid {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
|
tokenType, ok := claims["Type"]
|
|
if !ok || tokenType.(string) != "access" {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
userId, err := uuid.Parse(claims["UserID"].(string))
|
|
if err != nil {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
return userId, nil
|
|
} else {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
}
|
|
|
|
func (jm *JwtManager) GetUserIdFromRefresh(refreshToken string) (uuid.UUID, error) {
|
|
token, err := jwt.Parse(refreshToken, func(token *jwt.Token) (any, error) {
|
|
return jm.secret, nil
|
|
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
|
|
|
|
if err != nil {
|
|
return uuid.Nil, err
|
|
}
|
|
|
|
// Check if token is valid (JWT library validates exp claim automatically)
|
|
if !token.Valid {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
|
tokenType, ok := claims["Type"]
|
|
if !ok || tokenType.(string) != "refresh" {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
userId, err := uuid.Parse(claims["UserID"].(string))
|
|
if err != nil {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
|
|
return userId, nil
|
|
} else {
|
|
return uuid.Nil, NotValidToken
|
|
}
|
|
}
|
|
|
|
func GetUserIdFromAccess(jm *JwtManager, accessToken string) (uuid.UUID, error) {
|
|
return jm.GetUserIdFromAccess(accessToken)
|
|
}
|