a good start
This commit is contained in:
11
backend/middleware/json.go
Normal file
11
backend/middleware/json.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package middleware
|
||||
|
||||
import "net/http"
|
||||
|
||||
func SetJson(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
96
backend/middleware/jwt.go
Normal file
96
backend/middleware/jwt.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type JwtType string
|
||||
|
||||
const (
|
||||
Access JwtType = "access"
|
||||
Refresh JwtType = "refresh"
|
||||
)
|
||||
|
||||
type JwtClaims struct {
|
||||
UserID string
|
||||
Type JwtType
|
||||
Expire time.Time
|
||||
}
|
||||
|
||||
// obviously this is very not secure. TODO: extract to env
|
||||
var JWT_SECRET = []byte("very secret")
|
||||
|
||||
func createToken(claims JwtClaims) *jwt.Token {
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"UserID": claims.UserID,
|
||||
"Type": claims.Type,
|
||||
"Expire": claims.Expire,
|
||||
})
|
||||
}
|
||||
|
||||
func CreateRefreshToken(userId uuid.UUID) string {
|
||||
token := createToken(JwtClaims{
|
||||
UserID: userId.String(),
|
||||
Type: Refresh,
|
||||
Expire: time.Now().Add(time.Hour * 24 * 7),
|
||||
})
|
||||
|
||||
// TODO: bruh what is this
|
||||
tokenString, err := token.SignedString(JWT_SECRET)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func CreateAccessToken(userId uuid.UUID) string {
|
||||
token := createToken(JwtClaims{
|
||||
UserID: userId.String(),
|
||||
Type: Access,
|
||||
Expire: time.Now().Add(time.Hour),
|
||||
})
|
||||
|
||||
// TODO: bruh what is this
|
||||
tokenString, err := token.SignedString(JWT_SECRET)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
var NotValidToken = errors.New("Not a valid token")
|
||||
|
||||
func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
|
||||
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
|
||||
return JWT_SECRET, nil
|
||||
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
|
||||
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
// Blah blah, check expiry and stuff
|
||||
|
||||
// this function is stupid
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||
tokenType, ok := claims["Type"]
|
||||
if !ok || tokenType.(string) != "access" {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
|
||||
userId, err := uuid.Parse(claims["UserID"].(string))
|
||||
if err != nil {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
|
||||
return userId, nil
|
||||
} else {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
}
|
||||
89
backend/middleware/middleware.go
Normal file
89
backend/middleware/middleware.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CorsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Add("Access-Control-Allow-Headers", "*")
|
||||
|
||||
// Access-Control-Allow-Methods is often needed for preflight OPTIONS requests
|
||||
w.Header().Add("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
|
||||
// The client makes an OPTIONS preflight request before a complex request.
|
||||
// We must handle this and respond with the appropriate headers.
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
const USER_ID = "UserID"
|
||||
|
||||
func GetUserID(ctx context.Context) (uuid.UUID, error) {
|
||||
userId := ctx.Value(USER_ID)
|
||||
|
||||
if userId == nil {
|
||||
return uuid.Nil, errors.New("context does not contain a user id")
|
||||
}
|
||||
|
||||
userIdUuid, ok := userId.(uuid.UUID)
|
||||
if !ok {
|
||||
return uuid.Nil, fmt.Errorf("context user id is not of type uuid, got: %t", userId)
|
||||
}
|
||||
|
||||
return userIdUuid, nil
|
||||
}
|
||||
|
||||
func ProtectedRoute(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("Authorization")
|
||||
if len(token) < len("Bearer ") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userId, err := GetUserIdFromAccess(token[len("Bearer "):])
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
|
||||
|
||||
newR := r.WithContext(contextWithUserId)
|
||||
next.ServeHTTP(w, newR)
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserIdFromUrl(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.URL.Query().Get("token")
|
||||
|
||||
if len(token) == 0 {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userId, err := GetUserIdFromAccess(token)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
|
||||
|
||||
newR := r.WithContext(contextWithUserId)
|
||||
next.ServeHTTP(w, newR)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user