From 3e57d103606819931e1439726dcebff4e8d775ca Mon Sep 17 00:00:00 2001 From: John Costa Date: Tue, 19 Aug 2025 21:27:37 +0100 Subject: [PATCH] a good start --- backend/events.go | 9 ++-- backend/main.go | 28 ++++++----- backend/middleware/json.go | 11 +++++ backend/{ => middleware}/jwt.go | 2 +- backend/{ => middleware}/middleware.go | 32 +++++++++++- backend/stacks/handler.go | 67 ++++++++++++++++++++++++++ 6 files changed, 131 insertions(+), 18 deletions(-) create mode 100644 backend/middleware/json.go rename backend/{ => middleware}/jwt.go (99%) rename backend/{ => middleware}/middleware.go (63%) create mode 100644 backend/stacks/handler.go diff --git a/backend/events.go b/backend/events.go index bda6c0f..5f983d9 100644 --- a/backend/events.go +++ b/backend/events.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "screenmark/screenmark/agents" + "screenmark/screenmark/middleware" "screenmark/screenmark/models" "strconv" "sync" @@ -71,13 +72,15 @@ func ListenNewImageEvents(db *sql.DB, notifier *Notifier[Notification]) { wg.Add(2) go func() { + defer wg.Done() + descriptionAgent.Describe(createLogger("Description 📓", splitWriter), image.Image.ID, image.Image.ImageName, image.Image.Image) - wg.Done() }() go func() { + defer wg.Done() + listAgent.RunAgent(image.UserID, image.ImageID, image.Image.ImageName, image.Image.Image) - wg.Done() }() wg.Wait() @@ -149,7 +152,7 @@ func CreateEventsHandler(notifier *Notifier[Notification]) http.HandlerFunc { userSplitters := make(map[string]*ChannelSplitter[Notification]) return func(w http.ResponseWriter, r *http.Request) { - _userId := r.Context().Value(USER_ID).(uuid.UUID) + _userId := r.Context().Value(middleware.USER_ID).(uuid.UUID) if _userId == uuid.Nil { w.WriteHeader(http.StatusUnauthorized) return diff --git a/backend/main.go b/backend/main.go index 9720ac5..9e1bd4c 100644 --- a/backend/main.go +++ b/backend/main.go @@ -12,6 +12,9 @@ import ( "screenmark/screenmark/.gen/haystack/haystack/model" "screenmark/screenmark/agents/client" "screenmark/screenmark/models" + "screenmark/screenmark/stacks" + + ourmiddleware "screenmark/screenmark/middleware" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -41,6 +44,8 @@ func main() { imageModel := models.NewImageModel(db) userModel := models.NewUserModel(db) + stackHandler := stacks.CreateStackHandler(db) + mail, err := CreateMailClient() if err != nil { panic(err) @@ -56,10 +61,9 @@ func main() { r := chi.NewRouter() r.Use(middleware.Logger) - r.Use(CorsMiddleware) - r.Options("/*", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) + r.Use(ourmiddleware.CorsMiddleware) + + r.Route("/stacks", stackHandler.CreateRoutes) // Temporarily not in protect route because we aren't using cookies. // Therefore they don't get automatically attached to the request. @@ -102,7 +106,7 @@ func main() { }) r.Group(func(r chi.Router) { - r.Use(ProtectedRoute) + r.Use(ourmiddleware.ProtectedRoute) r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") @@ -112,7 +116,7 @@ func main() { }) r.Get("/image", func(w http.ResponseWriter, r *http.Request) { - userId := r.Context().Value(USER_ID).(uuid.UUID) + userId := r.Context().Value(ourmiddleware.USER_ID).(uuid.UUID) if err != nil { w.WriteHeader(http.StatusForbidden) fmt.Fprintf(w, "You cannot read this") @@ -168,7 +172,7 @@ func main() { r.Post("/image/{name}", func(w http.ResponseWriter, r *http.Request) { imageName := r.PathValue("name") - userId := r.Context().Value(USER_ID).(uuid.UUID) + userId := r.Context().Value(ourmiddleware.USER_ID).(uuid.UUID) if len(imageName) == 0 { w.WriteHeader(http.StatusBadRequest) @@ -256,7 +260,7 @@ func main() { }) r.Route("/notifications", func(r chi.Router) { - r.Use(GetUserIdFromUrl) + r.Use(ourmiddleware.GetUserIdFromUrl) r.Get("/", CreateEventsHandler(¬ifier)) }) @@ -322,8 +326,8 @@ func main() { return } - refresh := CreateRefreshToken(uuid) - access := CreateAccessToken(uuid) + refresh := ourmiddleware.CreateRefreshToken(uuid) + access := ourmiddleware.CreateAccessToken(uuid) codeReturn := CodeReturn{ Access: access, @@ -355,8 +359,8 @@ func main() { return } - refresh := CreateRefreshToken(uuid) - access := CreateAccessToken(uuid) + refresh := ourmiddleware.CreateRefreshToken(uuid) + access := ourmiddleware.CreateAccessToken(uuid) codeReturn := CodeReturn{ Access: access, diff --git a/backend/middleware/json.go b/backend/middleware/json.go new file mode 100644 index 0000000..96ad635 --- /dev/null +++ b/backend/middleware/json.go @@ -0,0 +1,11 @@ +package middleware + +import "net/http" + +func SetJson(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + + next.ServeHTTP(w, r) + }) +} diff --git a/backend/jwt.go b/backend/middleware/jwt.go similarity index 99% rename from backend/jwt.go rename to backend/middleware/jwt.go index 9b76eff..f8865a3 100644 --- a/backend/jwt.go +++ b/backend/middleware/jwt.go @@ -1,4 +1,4 @@ -package main +package middleware import ( "errors" diff --git a/backend/middleware.go b/backend/middleware/middleware.go similarity index 63% rename from backend/middleware.go rename to backend/middleware/middleware.go index 9f9e00c..bb8754d 100644 --- a/backend/middleware.go +++ b/backend/middleware/middleware.go @@ -1,22 +1,50 @@ -package main +package middleware import ( "context" + "errors" + "fmt" "net/http" + + "github.com/google/uuid" ) func CorsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Access-Control-Allow-Origin", "*") - w.Header().Add("Access-Control-Allow-Credentials", "*") w.Header().Add("Access-Control-Allow-Headers", "*") + // Access-Control-Allow-Methods is often needed for preflight OPTIONS requests + w.Header().Add("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + + // The client makes an OPTIONS preflight request before a complex request. + // We must handle this and respond with the appropriate headers. + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + next.ServeHTTP(w, r) }) } const USER_ID = "UserID" +func GetUserID(ctx context.Context) (uuid.UUID, error) { + userId := ctx.Value(USER_ID) + + if userId == nil { + return uuid.Nil, errors.New("context does not contain a user id") + } + + userIdUuid, ok := userId.(uuid.UUID) + if !ok { + return uuid.Nil, fmt.Errorf("context user id is not of type uuid, got: %t", userId) + } + + return userIdUuid, nil +} + func ProtectedRoute(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := r.Header.Get("Authorization") diff --git a/backend/stacks/handler.go b/backend/stacks/handler.go new file mode 100644 index 0000000..315d339 --- /dev/null +++ b/backend/stacks/handler.go @@ -0,0 +1,67 @@ +package stacks + +import ( + "database/sql" + "encoding/json" + "net/http" + "os" + "screenmark/screenmark/middleware" + "screenmark/screenmark/models" + + "github.com/charmbracelet/log" + "github.com/go-chi/chi/v5" +) + +type StackHandler struct { + logger *log.Logger + stackModel models.ListModel +} + +func (h *StackHandler) getAllStacks(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + userId, err := middleware.GetUserID(ctx) + if err != nil { + h.logger.Warn("could not get users in get all stacks", "err", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + lists, err := h.stackModel.List(ctx, userId) + if err != nil { + h.logger.Warn("could not get stacks", "err", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + jsonLists, err := json.Marshal(lists) + if err != nil { + h.logger.Warn("could not marshal json lists", "err", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + w.Write(jsonLists) + w.WriteHeader(http.StatusOK) +} + +func (h *StackHandler) CreateRoutes(r chi.Router) { + h.logger.Info("Mounting stack router") + + r.Group(func(r chi.Router) { + r.Use(middleware.ProtectedRoute) + r.Use(middleware.SetJson) + + r.Get("/", h.getAllStacks) + }) +} + +func CreateStackHandler(db *sql.DB) StackHandler { + stackModel := models.NewListModel(db) + logger := log.New(os.Stdout).WithPrefix("Stacks") + + return StackHandler{ + logger, + stackModel, + } +}