Haystack/backend/middleware.go

62 lines
1.4 KiB
Go

package main
import (
	"context"
	"net/http"
	"strings"
)
// CorsMiddleware wraps next and attaches permissive CORS response headers to
// every response before delegating to next.
//
// Headers are written with Set rather than Add so that stacking this
// middleware, or running it behind something that already set a CORS header,
// never produces duplicate header values (browsers reject a response with
// more than one Access-Control-Allow-Origin value).
func CorsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		// NOTE(review): "*" is not a valid value for this header (only "true"
		// is), and credentials cannot be combined with a wildcard origin, so
		// browsers treat this as "credentials not allowed". Left unchanged
		// because enabling credentials is a policy decision — confirm intent.
		h.Set("Access-Control-Allow-Credentials", "*")
		// The "*" wildcard does not cover the Authorization header; it must be
		// listed explicitly for cross-origin requests that send bearer tokens.
		h.Set("Access-Control-Allow-Headers", "*, Authorization")
		next.ServeHTTP(w, r)
	})
}
// USER_ID is the request-context key under which ProtectedRoute and
// GetUserIdFromUrl store the user ID returned by GetUserIdFromAccess.
//
// NOTE(review): this is an untyped string constant, so the context key is a
// plain string and could collide with any other package using the same
// literal. A dedicated unexported key type would avoid that, but changing the
// type here may break readers of the value elsewhere — confirm before changing.
const USER_ID = "UserID"
// ProtectedRoute wraps next so that it only runs for requests that carry a
// valid bearer token in the Authorization header.
//
// The header must have the form "Bearer <token>"; the scheme name is matched
// case-insensitively. The token part is handed to GetUserIdFromAccess, and on
// success the returned user ID is stored in the request context under USER_ID
// before next is invoked.
//
// If the header is missing, uses a different scheme, or the token is rejected,
// the response is 401 Unauthorized with an empty body and next is not called.
func ProtectedRoute(next http.Handler) http.Handler {
	const scheme = "Bearer "

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := r.Header.Get("Authorization")

		// Check the scheme itself, not just the length: otherwise any header
		// of sufficient length (e.g. "Basic xTOKEN") would have its first
		// seven bytes stripped and the remainder treated as a bearer token.
		if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		userId, err := GetUserIdFromAccess(header[len(scheme):])
		if err != nil {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		ctx := context.WithValue(r.Context(), USER_ID, userId)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// GetUserIdFromUrl wraps next so that it only runs for requests whose URL
// carries an access token in the "token" query parameter.
//
// The token is handed to GetUserIdFromAccess; on success the returned user ID
// is stored in the request context under USER_ID and next is invoked with the
// derived request. If the parameter is absent or empty, or the token is
// rejected, the response is 401 Unauthorized with an empty body and next is
// not called.
func GetUserIdFromUrl(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		accessToken := r.URL.Query().Get("token")
		if accessToken == "" {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		id, err := GetUserIdFromAccess(accessToken)
		if err != nil {
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		// Hand the resolved user ID to downstream handlers via the context.
		ctx := context.WithValue(r.Context(), USER_ID, id)
		next.ServeHTTP(w, r.WithContext(ctx))
	}

	return http.HandlerFunc(serve)
}