Haystack/backend/middleware/middleware.go
2025-08-19 21:27:37 +01:00

90 lines
2.1 KiB
Go

package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/google/uuid"
)
// CorsMiddleware wraps next with permissive CORS response headers.
//
// Every response allows any origin, the common HTTP methods, and any request
// header. Preflight (OPTIONS) requests are answered directly with 200 OK and
// are never forwarded to next.
func CorsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Set (not Add) so that a value written by an outer layer is replaced
		// rather than duplicated; browsers reject a repeated
		// Access-Control-Allow-Origin header.
		h := w.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		// The "*" wildcard does not cover the Authorization header, which must
		// be listed explicitly for browsers to send it on cross-origin requests.
		h.Set("Access-Control-Allow-Headers", "*, Authorization")
		// Access-Control-Allow-Methods is needed for preflight OPTIONS requests.
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")

		// Browsers send an OPTIONS preflight before a non-simple request.
		// Answer it here with the headers above and skip the wrapped handler.
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		next.ServeHTTP(w, r)
	})
}
// USER_ID is the context key under which ProtectedRoute and GetUserIdFromUrl
// store the user ID, and under which GetUserID looks it up.
const USER_ID = "UserID"
// GetUserID returns the user ID stored in ctx under the USER_ID key.
//
// It returns uuid.Nil and a non-nil error when ctx holds no value under that
// key, or when the stored value is not a uuid.UUID.
func GetUserID(ctx context.Context) (uuid.UUID, error) {
	userId := ctx.Value(USER_ID)
	if userId == nil {
		return uuid.Nil, errors.New("context does not contain a user id")
	}

	userIdUuid, ok := userId.(uuid.UUID)
	if !ok {
		// %T reports the dynamic type of the stored value. The lowercase %t
		// is the boolean verb and would render as "%!t(...)" here.
		return uuid.Nil, fmt.Errorf("context user id is not of type uuid, got: %T", userId)
	}

	return userIdUuid, nil
}
// ProtectedRoute wraps next so that it only runs for requests whose
// Authorization header carries a bearer token accepted by GetUserIdFromAccess.
//
// The header must have the form "Bearer <token>"; the scheme is matched
// case-insensitively. The user ID returned by GetUserIdFromAccess is stored in
// the request context under USER_ID before next is called. If the header is
// missing, uses a different scheme, or the token is rejected, the handler
// replies with an empty 401 Unauthorized and next is not called.
func ProtectedRoute(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		token, ok := bearerTokenFromHeader(r.Header.Get("Authorization"))
		if !ok {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		userId, err := GetUserIdFromAccess(token)
		if err != nil {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		contextWithUserId := context.WithValue(r.Context(), USER_ID, userId)
		next.ServeHTTP(w, r.WithContext(contextWithUserId))
	})
}

// bearerTokenFromHeader returns the part of an Authorization header value that
// follows the "Bearer " scheme prefix. ok is false when the value does not
// start with that prefix (compared ASCII case-insensitively).
func bearerTokenFromHeader(header string) (token string, ok bool) {
	const scheme = "bearer "
	if len(header) < len(scheme) {
		return "", false
	}
	for i := 0; i < len(scheme); i++ {
		c := header[i]
		if 'A' <= c && c <= 'Z' {
			c += 'a' - 'A' // fold ASCII upper case to lower case
		}
		if c != scheme[i] {
			return "", false
		}
	}
	return header[len(scheme):], true
}
// GetUserIdFromUrl wraps next so that it only runs for requests that carry an
// access token in the "token" URL query parameter.
//
// The token is passed to GetUserIdFromAccess, and the user ID it returns is
// stored in the request context under USER_ID before next is called. When the
// parameter is absent or empty, or GetUserIdFromAccess returns an error, the
// handler replies with an empty 401 Unauthorized and next is not called.
func GetUserIdFromUrl(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		accessToken := r.URL.Query().Get("token")
		if accessToken == "" {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		id, err := GetUserIdFromAccess(accessToken)
		if err != nil {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		// Hand the resolved ID to downstream handlers through the context.
		ctx := context.WithValue(r.Context(), USER_ID, id)
		next.ServeHTTP(w, r.WithContext(ctx))
	}

	return http.HandlerFunc(handle)
}