Compare commits
2 Commits
e28d9e5d16
...
a3345afbfa
Author | SHA1 | Date | |
---|---|---|---|
a3345afbfa | |||
f078ac7d0b |
@ -18,6 +18,8 @@ type AuthHandler struct {
|
||||
user models.UserModel
|
||||
|
||||
auth Auth
|
||||
|
||||
jwtManager *middleware.JwtManager
|
||||
}
|
||||
|
||||
type loginBody struct {
|
||||
@ -65,8 +67,8 @@ func (h *AuthHandler) code(body codeBody, w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
refresh := middleware.CreateRefreshToken(uuid)
|
||||
access := middleware.CreateAccessToken(uuid)
|
||||
refresh := h.jwtManager.CreateRefreshToken(uuid)
|
||||
access := h.jwtManager.CreateAccessToken(uuid)
|
||||
|
||||
codeReturn := codeReturn{
|
||||
Access: access,
|
||||
@ -87,7 +89,7 @@ func (h *AuthHandler) CreateRoutes(r chi.Router) {
|
||||
})
|
||||
}
|
||||
|
||||
func CreateAuthHandler(db *sql.DB) AuthHandler {
|
||||
func CreateAuthHandler(db *sql.DB, jwtManager *middleware.JwtManager) AuthHandler {
|
||||
userModel := models.NewUserModel(db)
|
||||
logger := log.New(os.Stdout).WithPrefix("Auth")
|
||||
|
||||
@ -99,8 +101,9 @@ func CreateAuthHandler(db *sql.DB) AuthHandler {
|
||||
auth := CreateAuth(mailer)
|
||||
|
||||
return AuthHandler{
|
||||
logger,
|
||||
userModel,
|
||||
auth,
|
||||
logger: logger,
|
||||
user: userModel,
|
||||
auth: auth,
|
||||
jwtManager: jwtManager,
|
||||
}
|
||||
}
|
||||
|
@ -28,6 +28,8 @@ type ImageHandler struct {
|
||||
limitsManager limits.LimitsManagerMethods
|
||||
|
||||
processImage func(imageID uuid.UUID)
|
||||
|
||||
jwtManager *middleware.JwtManager
|
||||
}
|
||||
|
||||
type ImagesReturn struct {
|
||||
@ -250,12 +252,12 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) {
|
||||
h.logger.Info("Mounting image router")
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middleware.ProtectedRouteURL)
|
||||
r.Use(middleware.ProtectedRouteURL(h.jwtManager))
|
||||
r.Get("/{id}", h.serveImage)
|
||||
})
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middleware.ProtectedRoute)
|
||||
r.Use(middleware.ProtectedRoute(h.jwtManager))
|
||||
r.Use(middleware.SetJson)
|
||||
|
||||
r.Get("/", h.listImages)
|
||||
@ -264,7 +266,7 @@ func (h *ImageHandler) CreateRoutes(r chi.Router) {
|
||||
})
|
||||
}
|
||||
|
||||
func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, processImage func(imageID uuid.UUID)) ImageHandler {
|
||||
func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, processImage func(imageID uuid.UUID), jwtManager *middleware.JwtManager) ImageHandler {
|
||||
imageModel := models.NewImageModel(db)
|
||||
userModel := models.NewUserModel(db)
|
||||
logger := log.New(os.Stdout).WithPrefix("Images")
|
||||
@ -275,5 +277,6 @@ func CreateImageHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, p
|
||||
userModel: userModel,
|
||||
limitsManager: limitsManager,
|
||||
processImage: processImage,
|
||||
jwtManager: jwtManager,
|
||||
}
|
||||
}
|
||||
|
@ -62,6 +62,7 @@ type TestContext struct {
|
||||
server *httptest.Server
|
||||
users []TestUser
|
||||
cleanup func()
|
||||
jwtManager *middleware.JwtManager
|
||||
}
|
||||
|
||||
func setupTestDatabase() (*sql.DB, func(), error) {
|
||||
@ -179,12 +180,14 @@ func setupTestContext(t *testing.T) *TestContext {
|
||||
t.Fatalf("Failed to setup test database: %v", err)
|
||||
}
|
||||
|
||||
router := setupRouter(db)
|
||||
jwtManager := middleware.NewJwtManager([]byte("test-jwt-secret"))
|
||||
router := setupRouter(db, jwtManager)
|
||||
server := httptest.NewServer(router)
|
||||
|
||||
tc.db = db
|
||||
tc.router = router
|
||||
tc.server = server
|
||||
tc.jwtManager = jwtManager
|
||||
tc.cleanup = func() {
|
||||
server.Close()
|
||||
cleanup()
|
||||
@ -202,7 +205,7 @@ func (tc *TestContext) createTestUser(email string) TestUser {
|
||||
}
|
||||
|
||||
// Create access token for the user
|
||||
accessToken := middleware.CreateAccessToken(userID)
|
||||
accessToken := tc.jwtManager.CreateAccessToken(userID)
|
||||
|
||||
user := TestUser{
|
||||
ID: userID,
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"screenmark/screenmark/middleware"
|
||||
"screenmark/screenmark/models"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
@ -15,12 +16,19 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
panic("JWT_SECRET environment variable not set")
|
||||
}
|
||||
|
||||
jwtManager := middleware.NewJwtManager([]byte(jwtSecret))
|
||||
|
||||
db, err := models.InitDatabase()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
router := setupRouter(db)
|
||||
router := setupRouter(db, jwtManager)
|
||||
|
||||
port, exists := os.LookupEnv("PORT")
|
||||
if !exists {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
@ -21,10 +22,15 @@ type JwtClaims struct {
|
||||
Expire time.Time
|
||||
}
|
||||
|
||||
// obviously this is very not secure. TODO: extract to env
|
||||
var JWT_SECRET = []byte("very secret")
|
||||
type JwtManager struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func createToken(claims JwtClaims) *jwt.Token {
|
||||
func NewJwtManager(secret []byte) *JwtManager {
|
||||
return &JwtManager{secret: secret}
|
||||
}
|
||||
|
||||
func (jm *JwtManager) createToken(claims JwtClaims) *jwt.Token {
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"UserID": claims.UserID,
|
||||
"Type": claims.Type,
|
||||
@ -32,15 +38,14 @@ func createToken(claims JwtClaims) *jwt.Token {
|
||||
})
|
||||
}
|
||||
|
||||
func CreateRefreshToken(userId uuid.UUID) string {
|
||||
token := createToken(JwtClaims{
|
||||
func (jm *JwtManager) CreateRefreshToken(userId uuid.UUID) string {
|
||||
token := jm.createToken(JwtClaims{
|
||||
UserID: userId.String(),
|
||||
Type: Refresh,
|
||||
Expire: time.Now().Add(time.Hour * 24 * 7),
|
||||
Expire: time.Now().Add(time.Hour * 24 * 30),
|
||||
})
|
||||
|
||||
// TODO: bruh what is this
|
||||
tokenString, err := token.SignedString(JWT_SECRET)
|
||||
tokenString, err := token.SignedString(jm.secret)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -48,15 +53,14 @@ func CreateRefreshToken(userId uuid.UUID) string {
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func CreateAccessToken(userId uuid.UUID) string {
|
||||
token := createToken(JwtClaims{
|
||||
func (jm *JwtManager) CreateAccessToken(userId uuid.UUID) string {
|
||||
token := jm.createToken(JwtClaims{
|
||||
UserID: userId.String(),
|
||||
Type: Access,
|
||||
Expire: time.Now().Add(time.Hour),
|
||||
Expire: time.Now().Add(time.Minute),
|
||||
})
|
||||
|
||||
// TODO: bruh what is this
|
||||
tokenString, err := token.SignedString(JWT_SECRET)
|
||||
tokenString, err := token.SignedString(jm.secret)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -66,24 +70,50 @@ func CreateAccessToken(userId uuid.UUID) string {
|
||||
|
||||
var NotValidToken = errors.New("Not a valid token")
|
||||
|
||||
func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
|
||||
func (jm *JwtManager) GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
|
||||
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
|
||||
return JWT_SECRET, nil
|
||||
return jm.secret, nil
|
||||
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
|
||||
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
// Blah blah, check expiry and stuff
|
||||
// Check if token is valid (including expiry check)
|
||||
if !token.Valid {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
|
||||
// this function is stupid
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||
tokenType, ok := claims["Type"]
|
||||
if !ok || tokenType.(string) != "access" {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
|
||||
// Additional explicit expiry check
|
||||
expireClaim, ok := claims["Expire"]
|
||||
if !ok {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
|
||||
var expireTime time.Time
|
||||
switch exp := expireClaim.(type) {
|
||||
case float64:
|
||||
expireTime = time.Unix(int64(exp), 0)
|
||||
case json.Number:
|
||||
expInt, err := exp.Int64()
|
||||
if err != nil {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
expireTime = time.Unix(expInt, 0)
|
||||
default:
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
|
||||
if time.Now().After(expireTime) {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
|
||||
userId, err := uuid.Parse(claims["UserID"].(string))
|
||||
if err != nil {
|
||||
return uuid.Nil, NotValidToken
|
||||
@ -94,3 +124,7 @@ func GetUserIdFromAccess(accessToken string) (uuid.UUID, error) {
|
||||
return uuid.Nil, NotValidToken
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserIdFromAccess(jm *JwtManager, accessToken string) (uuid.UUID, error) {
|
||||
return jm.GetUserIdFromAccess(accessToken)
|
||||
}
|
||||
|
@ -50,11 +50,12 @@ func GetUserID(ctx context.Context, logger *log.Logger, w http.ResponseWriter) (
|
||||
return userIdUuid, nil
|
||||
}
|
||||
|
||||
func ProtectedRouteURL(next http.Handler) http.Handler {
|
||||
func ProtectedRouteURL(jm *JwtManager) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.URL.Query().Get("token")
|
||||
|
||||
userId, err := GetUserIdFromAccess(token)
|
||||
userId, err := GetUserIdFromAccess(jm, token)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
@ -66,8 +67,10 @@ func ProtectedRouteURL(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, newR)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ProtectedRoute(next http.Handler) http.Handler {
|
||||
func ProtectedRoute(jm *JwtManager) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("Authorization")
|
||||
|
||||
@ -76,7 +79,7 @@ func ProtectedRoute(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
userId, err := GetUserIdFromAccess(token[len("Bearer "):])
|
||||
userId, err := GetUserIdFromAccess(jm, token[len("Bearer "):])
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
@ -88,8 +91,10 @@ func ProtectedRoute(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, newR)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserIdFromUrl(next http.Handler) http.Handler {
|
||||
func GetUserIdFromUrl(jm *JwtManager) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.URL.Query().Get("token")
|
||||
|
||||
@ -98,7 +103,7 @@ func GetUserIdFromUrl(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
userId, err := GetUserIdFromAccess(token)
|
||||
userId, err := GetUserIdFromAccess(jm, token)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
@ -110,6 +115,7 @@ func GetUserIdFromUrl(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, newR)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetPathParamID(logger *log.Logger, param string, w http.ResponseWriter, r *http.Request) (uuid.UUID, error) {
|
||||
pathParam := r.PathValue(param)
|
||||
|
@ -24,7 +24,7 @@ func (client TestAiClient) GetImageInfo(imageName string, imageData []byte) (cli
|
||||
return client.ImageInfo, nil
|
||||
}
|
||||
|
||||
func setupRouter(db *sql.DB) chi.Router {
|
||||
func setupRouter(db *sql.DB, jwtManager *ourmiddleware.JwtManager) chi.Router {
|
||||
imageModel := models.NewImageModel(db)
|
||||
stackModel := models.NewListModel(db)
|
||||
|
||||
@ -33,9 +33,9 @@ func setupRouter(db *sql.DB) chi.Router {
|
||||
processImageLogger := createLogger("Process Image", os.Stdout)
|
||||
processImage := ProcessImage(processImageLogger, db)
|
||||
|
||||
stackHandler := stacks.CreateStackHandler(db, limitsManager)
|
||||
authHandler := auth.CreateAuthHandler(db)
|
||||
imageHandler := images.CreateImageHandler(db, limitsManager, processImage)
|
||||
stackHandler := stacks.CreateStackHandler(db, limitsManager, jwtManager)
|
||||
authHandler := auth.CreateAuthHandler(db, jwtManager)
|
||||
imageHandler := images.CreateImageHandler(db, limitsManager, processImage, jwtManager)
|
||||
|
||||
notifier := NewNotifier[Notification](10)
|
||||
|
||||
@ -62,7 +62,7 @@ func setupRouter(db *sql.DB) chi.Router {
|
||||
r.Route("/images", imageHandler.CreateRoutes)
|
||||
|
||||
r.Route("/notifications", func(r chi.Router) {
|
||||
r.Use(ourmiddleware.GetUserIdFromUrl)
|
||||
r.Use(ourmiddleware.GetUserIdFromUrl(jwtManager))
|
||||
|
||||
r.Get("/", CreateEventsHandler(¬ifier))
|
||||
})
|
||||
|
@ -23,6 +23,8 @@ type StackHandler struct {
|
||||
stackModel models.ListModel
|
||||
|
||||
limitsManager limits.LimitsManagerMethods
|
||||
|
||||
jwtManager *middleware.JwtManager
|
||||
}
|
||||
|
||||
func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) {
|
||||
@ -185,7 +187,7 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
|
||||
h.logger.Info("Mounting stack router")
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middleware.ProtectedRoute)
|
||||
r.Use(middleware.ProtectedRoute(h.jwtManager))
|
||||
r.Use(middleware.SetJson)
|
||||
|
||||
r.Get("/", h.getAllStacks)
|
||||
@ -198,15 +200,16 @@ func (h *StackHandler) CreateRoutes(r chi.Router) {
|
||||
})
|
||||
}
|
||||
|
||||
func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods) StackHandler {
|
||||
func CreateStackHandler(db *sql.DB, limitsManager limits.LimitsManagerMethods, jwtManager *middleware.JwtManager) StackHandler {
|
||||
stackModel := models.NewListModel(db)
|
||||
imageModel := models.NewImageModel(db)
|
||||
logger := log.New(os.Stdout).WithPrefix("Stacks")
|
||||
|
||||
return StackHandler{
|
||||
logger,
|
||||
imageModel,
|
||||
stackModel,
|
||||
limitsManager,
|
||||
logger: logger,
|
||||
imageModel: imageModel,
|
||||
stackModel: stackModel,
|
||||
limitsManager: limitsManager,
|
||||
jwtManager: jwtManager,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user