diff --git a/chirp.go b/chirp.go index cb8f11f..fb77cff 100644 --- a/chirp.go +++ b/chirp.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/finchrelia/chirpy-server/internal/auth" "github.com/finchrelia/chirpy-server/internal/database" "github.com/google/uuid" ) @@ -23,16 +24,23 @@ type Chirp struct { func (cfg *apiConfig) chirpsCreate(w http.ResponseWriter, r *http.Request) { type parameters struct { - Content string `json:"body"` - UserID uuid.UUID `json:"user_id"` + Content string `json:"body"` } - type returnVals struct { - Data string `json:"cleaned_body"` + token, err := auth.GetBearerToken(r.Header) + if err != nil { + log.Printf("Error extracting token: %s", err) + w.WriteHeader(401) + return + } + userId, err := auth.ValidateJWT(token, cfg.JWT) + if err != nil { + log.Printf("Invalid JWT: %s", err) + w.WriteHeader(401) + return } - decoder := json.NewDecoder(r.Body) params := parameters{} - err := decoder.Decode(¶ms) + err = decoder.Decode(¶ms) if err != nil { log.Printf("Error decoding parameters: %s", err) w.WriteHeader(500) @@ -56,12 +64,9 @@ func (cfg *apiConfig) chirpsCreate(w http.ResponseWriter, r *http.Request) { w.Write(dat) } else { cleanedData := cleanText(params.Content) - // respBody := returnVals{ - // Data: cleanedData, - // } chirp, err := cfg.DB.CreateChirp(r.Context(), database.CreateChirpParams{ Body: cleanedData, - UserID: params.UserID, + UserID: userId, }) if err != nil { log.Printf("Error creating chirp: %s", err) diff --git a/go.mod b/go.mod index b819bfc..d030d9b 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/finchrelia/chirpy-server go 1.22.5 require ( + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/joho/godotenv v1.5.1 // indirect github.com/lib/pq v1.10.9 // indirect diff --git a/go.sum b/go.sum index 14aad1d..16482d3 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 44e8eb2..d8b16ea 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,6 +1,17 @@ package auth -import "golang.org/x/crypto/bcrypt" +import ( + "crypto/rand" + "encoding/hex" + "errors" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" +) func HashPassword(password string) (string, error) { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 10) @@ -10,3 +21,63 @@ func HashPassword(password string) (string, error) { func CheckPasswordHash(password, hash string) error { return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) } + +func MakeJWT(userID uuid.UUID, tokenSecret string) (string, error) { + newToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + Issuer: "chirpy", + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), + Subject: userID.String(), + }) + token, err := newToken.SignedString([]byte(tokenSecret)) + if err != nil { + return "", err + } + return token, nil +} + +func ValidateJWT(tokenString, tokenSecret string) (uuid.UUID, error) { + + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) { + return []byte(tokenSecret), nil + }) + if err != nil { + return uuid.UUID{}, err + } + if !token.Valid { + return uuid.UUID{}, errors.New("token has expired") + } + + userIDString, err := token.Claims.GetSubject() + if err != nil { + return uuid.UUID{}, err + } + userID, err := uuid.Parse(userIDString) + if err != nil { + return uuid.UUID{}, err + } + + return userID, nil +} + +func GetBearerToken(headers http.Header) (string, error) { + authHeader := headers.Get("Authorization") + if authHeader == "" { + return "", errors.New("authorization header is not set") + } + if !strings.HasPrefix(authHeader, "Bearer ") { + return "", errors.New("incorrect authorization type, must be of type Bearer") + } + bearerToken := strings.TrimPrefix(authHeader, "Bearer ") + return strings.TrimSpace(bearerToken), nil +} + +func MakeRefreshToken() (string, error) { + buffer := make([]byte, 32) + _, err := rand.Read(buffer) + if err != nil { + return "", err + } + hexData := hex.EncodeToString(buffer) + return hexData, nil +} diff --git a/internal/database/models.go b/internal/database/models.go index d77629c..7cc8f3a 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -5,6 +5,7 @@ package database import ( + "database/sql" "time" "github.com/google/uuid" @@ -18,6 +19,15 @@ type Chirp struct { UserID uuid.UUID } +type RefreshToken struct { + Token string + CreatedAt time.Time + UpdatedAt time.Time + UserID uuid.UUID + ExpiresAt sql.NullTime + RevokedAt sql.NullTime +} + type User struct { ID uuid.UUID CreatedAt time.Time diff --git a/internal/database/refresh_token.sql.go b/internal/database/refresh_token.sql.go new file mode 100644 index 0000000..4e38bd5 --- /dev/null +++ b/internal/database/refresh_token.sql.go @@ -0,0 +1,46 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: refresh_token.sql + +package database + +import ( + "context" + "database/sql" + + "github.com/google/uuid" +) + +const createRefreshToken = `-- name: CreateRefreshToken :one +INSERT INTO refresh_tokens (token, created_at, updated_at, user_id, expires_at, revoked_at) +VALUES ( + $1, + NOW(), + NOW(), + $2, + $3, + NULL +) +RETURNING token, created_at, updated_at, user_id, expires_at, revoked_at +` + +type CreateRefreshTokenParams struct { + Token string + UserID uuid.UUID + ExpiresAt sql.NullTime +} + +func (q *Queries) CreateRefreshToken(ctx context.Context, arg CreateRefreshTokenParams) (RefreshToken, error) { + row := q.db.QueryRowContext(ctx, createRefreshToken, arg.Token, arg.UserID, arg.ExpiresAt) + var i RefreshToken + err := row.Scan( + &i.Token, + &i.CreatedAt, + &i.UpdatedAt, + &i.UserID, + &i.ExpiresAt, + &i.RevokedAt, + ) + return i, err +} diff --git a/internal/database/update_token.sql.go b/internal/database/update_token.sql.go new file mode 100644 index 0000000..3a4ac72 --- /dev/null +++ b/internal/database/update_token.sql.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: update_token.sql + +package database + +import ( + "context" +) + +const revokeRefreshToken = `-- name: RevokeRefreshToken :exec +UPDATE refresh_tokens +SET + revoked_at = NOW(), + updated_at = NOW() +WHERE token = $1 +` + +func (q *Queries) RevokeRefreshToken(ctx context.Context, token string) error { + _, err := q.db.ExecContext(ctx, revokeRefreshToken, token) + return err +} diff --git a/internal/database/user_from_token.sql.go b/internal/database/user_from_token.sql.go new file mode 100644 index 0000000..92206ac --- /dev/null +++ b/internal/database/user_from_token.sql.go @@ -0,0 +1,26 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: user_from_token.sql + +package database + +import ( + "context" + + "github.com/google/uuid" +) + +const getUserFromRefreshToken = `-- name: GetUserFromRefreshToken :one +SELECT user_id FROM refresh_tokens +WHERE refresh_tokens.token = $1 +AND refresh_tokens.expires_at > NOW() +AND refresh_tokens.revoked_at IS NULL +` + +func (q *Queries) GetUserFromRefreshToken(ctx context.Context, token string) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, getUserFromRefreshToken, token) + var user_id uuid.UUID + err := row.Scan(&user_id) + return user_id, err +} diff --git a/internal/database/users.sql.go b/internal/database/users.sql.go index e5b90eb..e2693f1 100644 --- a/internal/database/users.sql.go +++ b/internal/database/users.sql.go @@ -16,13 +16,18 @@ VALUES ( NOW(), NOW(), $1, - $2 + $2 ) RETURNING id, created_at, updated_at, email, hashed_password ` -func (q *Queries) CreateUser(ctx context.Context, email string, hashed_password string) (User, error) { - row := q.db.QueryRowContext(ctx, createUser, email, hashed_password) +type CreateUserParams struct { + Email string + HashedPassword string +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { + row := q.db.QueryRowContext(ctx, createUser, arg.Email, arg.HashedPassword) var i User err := row.Scan( &i.ID, diff --git a/login.go b/login.go index c459d41..4801327 100644 --- a/login.go +++ b/login.go @@ -1,11 +1,15 @@ package main import ( + "database/sql" "encoding/json" "log" "net/http" + "time" "github.com/finchrelia/chirpy-server/internal/auth" + "github.com/finchrelia/chirpy-server/internal/database" + "github.com/google/uuid" ) func (cfg *apiConfig) Login(w http.ResponseWriter, r *http.Request) { @@ -22,9 +26,10 @@ func (cfg *apiConfig) Login(w http.ResponseWriter, r *http.Request) { w.WriteHeader(401) return } + loggedUser, err := cfg.DB.GetUserByEmail(r.Context(), p.Email) if err != nil { - log.Printf("Error retrieving user: %s", err) + log.Printf("Error retrieving user: %v", err) } err = auth.CheckPasswordHash(p.Password, loggedUser.HashedPassword) @@ -34,11 +39,46 @@ func (cfg *apiConfig) Login(w http.ResponseWriter, r *http.Request) { return } - data, err := json.Marshal(User{ - ID: loggedUser.ID, - CreatedAt: loggedUser.CreatedAt, - UpdatedAt: loggedUser.UpdatedAt, - Email: loggedUser.Email, + newJwt, err := auth.MakeJWT(loggedUser.ID, cfg.JWT) + if err != nil { + log.Printf("Error creating JWT: %s", newJwt) + w.WriteHeader(500) + return + } + newRefreshToken, err := auth.MakeRefreshToken() + if err != nil { + log.Printf("Error creating refresh token: %v", err) + w.WriteHeader(500) + return + } + + refreshTokenParams := database.CreateRefreshTokenParams{ + Token: newRefreshToken, + UserID: loggedUser.ID, + ExpiresAt: sql.NullTime{Time: time.Now().AddDate(0, 0, 60), Valid: true}, + } + _, err = cfg.DB.CreateRefreshToken(r.Context(), refreshTokenParams) + if err != nil { + log.Printf("Error adding refresh token to db: %s", err) + w.WriteHeader(500) + return + } + type loginResponse struct { + ID uuid.UUID `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Email string `json:"email"` + AccessToken string `json:"token"` + RefreshToken string `json:"refresh_token"` + } + + data, err := json.Marshal(loginResponse{ + ID: loggedUser.ID, + CreatedAt: loggedUser.CreatedAt, + UpdatedAt: loggedUser.UpdatedAt, + Email: loggedUser.Email, + AccessToken: newJwt, + RefreshToken: newRefreshToken, }) if err != nil { log.Printf("Error marshalling JSON: %s", err) diff --git a/main.go b/main.go index 93a1225..39e9d52 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ type apiConfig struct { fileserverHits atomic.Int32 DB *database.Queries Platform string + JWT string } func main() { @@ -28,6 +29,10 @@ func main() { if platform == "" { log.Fatalf("Empty PLATFORM env var!") } + jwtSecret := os.Getenv("JWT_SECRET") + if jwtSecret == "" { + log.Fatalf("Empty JWT_SECRET env var!") + } db, err := sql.Open("postgres", dbURL) if err != nil { log.Fatalf("Unable to connect to db: %s", err) @@ -36,6 +41,7 @@ func main() { fileserverHits: atomic.Int32{}, DB: database.New(db), Platform: platform, + JWT: jwtSecret, } mux := http.NewServeMux() fsHandler := apiCfg.middlewareMetricsInc(http.StripPrefix("/app", http.FileServer(http.Dir(".")))) @@ -52,6 +58,8 @@ func main() { mux.HandleFunc("POST /api/users", apiCfg.createUsers) mux.HandleFunc("GET /api/chirps/{chirpID}", apiCfg.getChirp) mux.HandleFunc("POST /api/login", apiCfg.Login) + mux.HandleFunc("POST /api/refresh", apiCfg.RefreshToken) + mux.HandleFunc("POST /api/revoke", apiCfg.RevokeToken) server := &http.Server{ Addr: ":8080", diff --git a/sql/queries/refresh_token.sql b/sql/queries/refresh_token.sql new file mode 100644 index 0000000..ca40189 --- /dev/null +++ b/sql/queries/refresh_token.sql @@ -0,0 +1,11 @@ +-- name: CreateRefreshToken :one +INSERT INTO refresh_tokens (token, created_at, updated_at, user_id, expires_at, revoked_at) +VALUES ( + $1, + NOW(), + NOW(), + $2, + $3, + NULL +) +RETURNING *; \ No newline at end of file diff --git a/sql/queries/update_token.sql b/sql/queries/update_token.sql new file mode 100644 index 0000000..6000e9d --- /dev/null +++ b/sql/queries/update_token.sql @@ -0,0 +1,6 @@ +-- name: RevokeRefreshToken :exec +UPDATE refresh_tokens +SET + revoked_at = NOW(), + updated_at = NOW() +WHERE token = $1; \ No newline at end of file diff --git a/sql/queries/user_from_token.sql b/sql/queries/user_from_token.sql new file mode 100644 index 0000000..6d7f8d8 --- /dev/null +++ b/sql/queries/user_from_token.sql @@ -0,0 +1,5 @@ +-- name: GetUserFromRefreshToken :one +SELECT user_id FROM refresh_tokens +WHERE refresh_tokens.token = $1 +AND refresh_tokens.expires_at > NOW() +AND refresh_tokens.revoked_at IS NULL; \ No newline at end of file diff --git a/sql/queries/users.sql b/sql/queries/users.sql index b520299..1f63c63 100644 --- a/sql/queries/users.sql +++ b/sql/queries/users.sql @@ -1,10 +1,11 @@ -- name: CreateUser :one -INSERT INTO users (id, created_at, updated_at, email) +INSERT INTO users (id, created_at, updated_at, email, hashed_password) VALUES ( gen_random_uuid(), NOW(), NOW(), - $1 + $1, + $2 ) RETURNING *; diff --git a/sql/schema/004_refresh_token.sql b/sql/schema/004_refresh_token.sql new file mode 100644 index 0000000..93212c9 --- /dev/null +++ b/sql/schema/004_refresh_token.sql @@ -0,0 +1,12 @@ +-- +goose Up +CREATE TABLE refresh_tokens ( + token TEXT PRIMARY KEY, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + expires_at TIMESTAMP, + revoked_at TIMESTAMP +); + +-- +goose Down +DROP TABLE refresh_tokens; diff --git a/token.go b/token.go new file mode 100644 index 0000000..5613e01 --- /dev/null +++ b/token.go @@ -0,0 +1,62 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" + + "github.com/finchrelia/chirpy-server/internal/auth" +) + +func (cfg *apiConfig) RefreshToken(w http.ResponseWriter, r *http.Request) { + token, err := auth.GetBearerToken(r.Header) + if err != nil { + log.Printf("Error extracting token: %s", err) + w.WriteHeader(401) + return + } + + dbUser, err := cfg.DB.GetUserFromRefreshToken(r.Context(), token) + if err != nil { + log.Printf("Error getting user: %v", err) + w.WriteHeader(401) + return + } + newToken, err := auth.MakeJWT(dbUser, cfg.JWT) + if err != nil { + log.Printf("Error creating new JWT: %v", err) + w.WriteHeader(500) + return + } + type tokenResponse struct { + AccessToken string `json:"token"` + } + + data, err := json.Marshal(tokenResponse{ + AccessToken: newToken, + }) + if err != nil { + log.Printf("Error marshalling JSON: %s", err) + w.WriteHeader(500) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write(data) +} + +func (cfg *apiConfig) RevokeToken(w http.ResponseWriter, r *http.Request) { + token, err := auth.GetBearerToken(r.Header) + if err != nil { + log.Printf("Error extracting token: %s", err) + w.WriteHeader(401) + return + } + err = cfg.DB.RevokeRefreshToken(r.Context(), token) + if err != nil { + log.Printf("Error revoking token in database: %v", err) + w.WriteHeader(500) + return + } + w.WriteHeader(204) +} diff --git a/users.go b/users.go index 4f1e7ee..0c30cc8 100644 --- a/users.go +++ b/users.go @@ -7,6 +7,7 @@ import ( "time" "github.com/finchrelia/chirpy-server/internal/auth" + "github.com/finchrelia/chirpy-server/internal/database" "github.com/google/uuid" ) @@ -38,7 +39,10 @@ func (cfg *apiConfig) createUsers(w http.ResponseWriter, r *http.Request) { if err != nil { log.Printf("Error hashing password: %s", err) } - newDBUser, err := cfg.DB.CreateUser(r.Context(), params.Email, hashedPassword) + newDBUser, err := cfg.DB.CreateUser(r.Context(), database.CreateUserParams{ + Email: params.Email, + HashedPassword: hashedPassword, + }) if err != nil { log.Printf("Error creating user %s: %s", params.Email, err) w.WriteHeader(500)