diff --git a/backend/jwt.go b/backend/jwt.go index 06b2d5c..9b76eff 100644 --- a/backend/jwt.go +++ b/backend/jwt.go @@ -1,6 +1,7 @@ package main import ( + "errors" "time" "github.com/golang-jwt/jwt/v5" @@ -20,6 +21,9 @@ type JwtClaims struct { Expire time.Time } +// obviously this is very not secure. TODO: extract to env +var JWT_SECRET = []byte("very secret") + func createToken(claims JwtClaims) *jwt.Token { return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "UserID": claims.UserID, @@ -36,7 +40,7 @@ func CreateRefreshToken(userId uuid.UUID) string { }) // TODO: bruh what is this - tokenString, err := token.SignedString([]byte("very secret")) + tokenString, err := token.SignedString(JWT_SECRET) if err != nil { panic(err) } @@ -52,10 +56,41 @@ func CreateAccessToken(userId uuid.UUID) string { }) // TODO: bruh what is this - tokenString, err := token.SignedString([]byte("very secret")) + tokenString, err := token.SignedString(JWT_SECRET) if err != nil { panic(err) } return tokenString } + +var NotValidToken = errors.New("Not a valid token") + +func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) { + token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) { + return JWT_SECRET, nil + }, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()})) + + if err != nil { + return uuid.Nil, err + } + + // Blah blah, check expiry and stuff + + // this function is stupid + if claims, ok := token.Claims.(jwt.MapClaims); ok { + tokenType, ok := claims["Type"] + if !ok || tokenType.(string) != "access" { + return uuid.Nil, NotValidToken + } + + userId, err := uuid.Parse(claims["UserID"].(string)) + if err != nil { + return uuid.Nil, NotValidToken + } + + return userId, nil + } else { + return uuid.Nil, NotValidToken + } +} diff --git a/backend/main.go b/backend/main.go index f8ab597..cedf257 100644 --- a/backend/main.go +++ b/backend/main.go @@ -63,9 +63,19 @@ func main() { }) r.Get("/image", func(w http.ResponseWriter, r *http.Request) { - userId := r.Header.Get("userId") + token := r.Header.Get("Authorization")[7:] - images, err := userModel.ListWithProperties(r.Context(), uuid.MustParse(userId)) + fmt.Println(token) + + userId, err := GetUserIdFromAccess(token) + if err != nil { + log.Println(err) + w.WriteHeader(http.StatusForbidden) + fmt.Fprintf(w, "Get out of here!") + return + } + + images, err := userModel.ListWithProperties(r.Context(), userId) if err != nil { log.Println(err) w.WriteHeader(http.StatusNotFound)