From a86addc8b209b9615b4c22bbe3e34a631867eb65 Mon Sep 17 00:00:00 2001 From: John Costa Date: Thu, 10 Apr 2025 15:35:35 +0100 Subject: [PATCH] feat(jwt): adding access and refresh token generation --- backend/.gen/haystack/haystack/model/users.go | 3 +- backend/.gen/haystack/haystack/table/users.go | 11 ++-- backend/auth.go | 13 ++++ backend/auth_test.go | 2 +- backend/email.go | 16 ++++- backend/go.mod | 1 + backend/go.sum | 2 + backend/jwt.go | 61 +++++++++++++++++++ backend/main.go | 56 +++++++++++++++++ backend/models/user.go | 9 +++ backend/schema.sql | 5 +- 11 files changed, 170 insertions(+), 9 deletions(-) create mode 100644 backend/jwt.go diff --git a/backend/.gen/haystack/haystack/model/users.go b/backend/.gen/haystack/haystack/model/users.go index a35be7f..2f417d8 100644 --- a/backend/.gen/haystack/haystack/model/users.go +++ b/backend/.gen/haystack/haystack/model/users.go @@ -12,5 +12,6 @@ import ( ) type Users struct { - ID uuid.UUID `sql:"primary_key"` + ID uuid.UUID `sql:"primary_key"` + Email string } diff --git a/backend/.gen/haystack/haystack/table/users.go b/backend/.gen/haystack/haystack/table/users.go index 73d5fd7..412b46f 100644 --- a/backend/.gen/haystack/haystack/table/users.go +++ b/backend/.gen/haystack/haystack/table/users.go @@ -17,7 +17,8 @@ type usersTable struct { postgres.Table // Columns - ID postgres.ColumnString + ID postgres.ColumnString + Email postgres.ColumnString AllColumns postgres.ColumnList MutableColumns postgres.ColumnList @@ -59,15 +60,17 @@ func newUsersTable(schemaName, tableName, alias string) *UsersTable { func newUsersTableImpl(schemaName, tableName, alias string) usersTable { var ( IDColumn = postgres.StringColumn("id") - allColumns = postgres.ColumnList{IDColumn} - mutableColumns = postgres.ColumnList{} + EmailColumn = postgres.StringColumn("email") + allColumns = postgres.ColumnList{IDColumn, EmailColumn} + mutableColumns = postgres.ColumnList{EmailColumn} ) return usersTable{ Table: postgres.NewTable(schemaName, tableName, alias, allColumns...), //Columns - ID: IDColumn, + ID: IDColumn, + Email: EmailColumn, AllColumns: allColumns, MutableColumns: mutableColumns, diff --git a/backend/auth.go b/backend/auth.go index dbaa14e..b3c372d 100644 --- a/backend/auth.go +++ b/backend/auth.go @@ -1,6 +1,8 @@ package main import ( + "errors" + "fmt" "math/rand" "time" ) @@ -42,6 +44,7 @@ func (a *Auth) CreateCode(email string) error { } func (a *Auth) IsCodeValid(email string, code string) bool { + fmt.Println(a.codes) existingCode, exists := a.codes[email] if !exists { return false @@ -50,6 +53,16 @@ func (a *Auth) IsCodeValid(email string, code string) bool { return existingCode.Valid.After(time.Now()) && existingCode.Code == code } +func (a *Auth) UseCode(email string, code string) error { + if valid := a.IsCodeValid(email, code); !valid { + fmt.Println("returning error?") + return errors.New("This code is invalid.") + } + + delete(a.codes, email) + return nil +} + func CreateAuth(mailer Mailer) Auth { return Auth{ codes: make(map[string]Code), diff --git a/backend/auth_test.go b/backend/auth_test.go index bf5e436..04519fb 100644 --- a/backend/auth_test.go +++ b/backend/auth_test.go @@ -18,7 +18,7 @@ var testMailer = TestMail{} func TestCreateCode(t *testing.T) { require := require.New(t) - auth := NewAuth(testMailer) + auth := CreateAuth(testMailer) err := auth.CreateCode("test") require.NoError(err) diff --git a/backend/email.go b/backend/email.go index 45c67c1..03c6fb9 100644 --- a/backend/email.go +++ b/backend/email.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "os" "github.com/wneessen/go-mail" @@ -10,6 +11,8 @@ type MailClient struct { client *mail.Client } +type TestMailClient struct{} + type Mailer interface { SendCode(to string, code string) error } @@ -39,7 +42,18 @@ func (m MailClient) SendCode(to string, code string) error { return m.client.DialAndSend(msg) } -func CreateMailClient() (MailClient, error) { +func (m TestMailClient) SendCode(to string, code string) error { + fmt.Printf("Email: %s | Code %s\n", to, code) + + return nil +} + +func CreateMailClient() (Mailer, error) { + mode := os.Getenv("MODE") + if mode == "DEV" { + return TestMailClient{}, nil + } + client, err := mail.NewClient( "smtp.mailbox.org", mail.WithSMTPAuth(mail.SMTPAuthPlain), diff --git a/backend/go.mod b/backend/go.mod index 1a88d32..97f50b8 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -6,6 +6,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-chi/chi/v5 v5.2.1 // indirect github.com/go-jet/jet/v2 v2.12.0 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/joho/godotenv v1.5.1 // indirect github.com/lib/pq v1.10.9 // indirect diff --git a/backend/go.sum b/backend/go.sum index e2ec7d2..59f6368 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -4,6 +4,8 @@ github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-jet/jet/v2 v2.12.0 h1:z2JfvBAZgsfxlQz6NXBYdZTXc7ep3jhbszTLtETv1JE= github.com/go-jet/jet/v2 v2.12.0/go.mod h1:ufQVRQeI1mbcO5R8uCEVcVf3Foej9kReBdwDx7YMWUM= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/backend/jwt.go b/backend/jwt.go new file mode 100644 index 0000000..06b2d5c --- /dev/null +++ b/backend/jwt.go @@ -0,0 +1,61 @@ +package main + +import ( + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +type JwtType string + +const ( + Access JwtType = "access" + Refresh JwtType = "refresh" +) + +type JwtClaims struct { + UserID string + Type JwtType + Expire time.Time +} + +func createToken(claims JwtClaims) *jwt.Token { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "UserID": claims.UserID, + "Type": claims.Type, + "Expire": claims.Expire, + }) +} + +func CreateRefreshToken(userId uuid.UUID) string { + token := createToken(JwtClaims{ + UserID: userId.String(), + Type: Refresh, + Expire: time.Now().Add(time.Hour * 24 * 7), + }) + + // TODO: bruh what is this + tokenString, err := token.SignedString([]byte("very secret")) + if err != nil { + panic(err) + } + + return tokenString +} + +func CreateAccessToken(userId uuid.UUID) string { + token := createToken(JwtClaims{ + UserID: userId.String(), + Type: Access, + Expire: time.Now().Add(time.Hour), + }) + + // TODO: bruh what is this + tokenString, err := token.SignedString([]byte("very secret")) + if err != nil { + panic(err) + } + + return tokenString +} diff --git a/backend/main.go b/backend/main.go index 9bf50bf..f8ab597 100644 --- a/backend/main.go +++ b/backend/main.go @@ -246,6 +246,62 @@ func main() { w.WriteHeader(http.StatusOK) }) + r.Post("/code", func(w http.ResponseWriter, r *http.Request) { + type CodeBody struct { + Email string `json:"email"` + Code string `json:"code"` + } + + type CodeReturn struct { + Access string `json:"access"` + Refresh string `json:"refresh"` + } + + codeBody := CodeBody{} + if err := json.NewDecoder(r.Body).Decode(&codeBody); err != nil { + log.Println(err) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "Request body was not correct") + return + } + + if err := auth.UseCode(codeBody.Email, codeBody.Code); err != nil { + log.Println(err) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "email or code are incorrect") + return + } + + uuid, err := userModel.GetUserIdFromEmail(r.Context(), codeBody.Email) + if err != nil { + log.Println(err) + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Something went wrong.") + return + } + + refresh := CreateRefreshToken(uuid) + access := CreateAccessToken(uuid) + + codeReturn := CodeReturn{ + Access: access, + Refresh: refresh, + } + + json, err := json.Marshal(codeReturn) + if err != nil { + log.Println(err) + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Something went wrong.") + return + } + + w.WriteHeader(http.StatusOK) + w.Header().Add("Content-Type", "application/json") + + fmt.Fprint(w, string(json)) + }) + log.Println("Listening and serving on port 3040.") if err := http.ListenAndServe(":3040", r); err != nil { log.Println(err) diff --git a/backend/models/user.go b/backend/models/user.go index 9c4a72a..ff72d24 100644 --- a/backend/models/user.go +++ b/backend/models/user.go @@ -104,6 +104,15 @@ func (m UserModel) ListWithProperties(ctx context.Context, userId uuid.UUID) ([] return images, err } +func (m UserModel) GetUserIdFromEmail(ctx context.Context, email string) (uuid.UUID, error) { + getUserIdStmt := Users.SELECT(Users.ID).WHERE(Users.Email.EQ(String(email))) + + user := model.Users{} + err := getUserIdStmt.QueryContext(ctx, m.dbPool, &user) + + return user.ID, err +} + func NewUserModel(db *sql.DB) UserModel { return UserModel{dbPool: db} } diff --git a/backend/schema.sql b/backend/schema.sql index d711395..1f21291 100644 --- a/backend/schema.sql +++ b/backend/schema.sql @@ -5,7 +5,8 @@ CREATE SCHEMA haystack; /* -----| Schema tables |----- */ CREATE TABLE haystack.users ( - id uuid PRIMARY KEY DEFAULT gen_random_uuid() + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + email TEXT NOT NULL ); CREATE TABLE haystack.image ( @@ -164,7 +165,7 @@ EXECUTE PROCEDURE notify_new_image(); /* -----| Test Data |----- */ -- Insert a user -INSERT INTO haystack.users (id) VALUES ('1db09f34-b155-4bf2-b606-dda25365fc89'); +INSERT INTO haystack.users (id, email) VALUES ('1db09f34-b155-4bf2-b606-dda25365fc89', 'me@email.com'); -- Insert images INSERT INTO haystack.image (id, image_name, image) VALUES