diff --git a/config.default.yaml b/config.default.yaml index 13b0f5e..962257f 100644 --- a/config.default.yaml +++ b/config.default.yaml @@ -1,4 +1,5 @@ server: + application: nixcn-cms address: :8000 external_url: https://example.com debug_mode: false @@ -22,5 +23,6 @@ secrets: jwt_secret: something turnstile_secret: something ttl: - magic_link_ttl: 15000 - jwt_ttl: 86400000 + magic_link_ttl: 10m + jwt_ttl: 15s + refresh_ttl: 48h diff --git a/config/types.go b/config/types.go index 8cd6829..4cec893 100644 --- a/config/types.go +++ b/config/types.go @@ -10,10 +10,11 @@ type config struct { } type server struct { - Address string `yaml:"address"` - DebugMode string `yaml:"debug_mode"` - FileLogger string `yaml:"file_logger"` - JwtSecret string `yaml:"jwt_secret"` + Application string `yaml:"application"` + Address string `yaml:"address"` + DebugMode string `yaml:"debug_mode"` + FileLogger string `yaml:"file_logger"` + JwtSecret string `yaml:"jwt_secret"` } type database struct { @@ -38,11 +39,12 @@ type email struct { } type secrets struct { - jwt_secret string `yaml:"jwt_secret"` - turnstile_secret string `yaml:"turnstile_secret"` + JwtSecret string `yaml:"jwt_secret"` + TurnstileSecret string `yaml:"turnstile_secret"` } type ttl struct { - magic_link_ttl string `yaml:"magic_link_ttl"` - jwt_ttl string `yaml:"jwt_ttl"` + MagicLinkTTL string `yaml:"magic_link_ttl"` + JwtTTL string `yaml:"jwt_ttl"` + RefreshTTL string `yaml:"refresh_ttl"` } diff --git a/data/data.go b/data/data.go index 22d32be..783bb34 100644 --- a/data/data.go +++ b/data/data.go @@ -9,7 +9,7 @@ import ( ) var Database *drivers.DBClient -var Redis *redis.UniversalClient +var Redis redis.UniversalClient func Init() { // Init database diff --git a/data/drivers/redis.go b/data/drivers/redis.go index 7e403ad..1997dcf 100644 --- a/data/drivers/redis.go +++ b/data/drivers/redis.go @@ -6,7 +6,7 @@ import ( "github.com/redis/go-redis/v9" ) -func Redis(dsn RedisDSN) (*redis.UniversalClient, error) { +func Redis(dsn RedisDSN) (redis.UniversalClient, error) { // Connect to Redis rdb := redis.NewUniversalClient(&redis.UniversalOptions{ Addrs: dsn.Hosts, @@ -18,5 +18,5 @@ func Redis(dsn RedisDSN) (*redis.UniversalClient, error) { ctx := context.Background() // Ping Redis _, err := rdb.Ping(ctx).Result() - return &rdb, err + return rdb, err } diff --git a/internal/crypto/jwt/jwt.go b/internal/crypto/jwt/jwt.go deleted file mode 100644 index 3df7505..0000000 --- a/internal/crypto/jwt/jwt.go +++ /dev/null @@ -1,77 +0,0 @@ -package jwt - -import ( - "net/http" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" - "github.com/spf13/viper" -) - -type Claims struct { - UserID uuid.UUID `json:"user_id"` - jwt.RegisteredClaims -} - -func JWTAuth() gin.HandlerFunc { - var JwtSecret = []byte(viper.GetString("secrets.jwt_secret")) - return func(c *gin.Context) { - auth := c.GetHeader("Authorization") - if auth == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "missing Authorization header", - }) - return - } - - parts := strings.SplitN(auth, " ", 2) - if len(parts) != 2 || parts[0] != "Bearer" { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "invalid Authorization header format", - }) - return - } - - tokenStr := parts[1] - - token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) { - return JwtSecret, nil - }) - - if err != nil || !token.Valid { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "invalid or expired token", - }) - return - } - - claims, ok := token.Claims.(*Claims) - if !ok { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "invalid token claims", - }) - return - } - - c.Set("user_id", claims.UserID) - c.Next() - } -} - -func GenerateToken(userID uuid.UUID, application string) (string, error) { - var JwtSecret = []byte(viper.GetString("secrets.jwt_secret")) - claims := Claims{ - UserID: userID, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(viper.GetDuration("ttl.jwt_ttl"))), - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: application, - }, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(JwtSecret) -} diff --git a/internal/crypto/jwt/jwt_test.go b/internal/crypto/jwt/jwt_test.go deleted file mode 100644 index 9f5ebc8..0000000 --- a/internal/crypto/jwt/jwt_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package jwt - -import ( - "net/http" - "net/http/httptest" - "nixcn-cms/config" - "testing" - "time" - - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" - "github.com/spf13/viper" -) - -func init() { - config.Init() -} - -func generateTestToken(userID uuid.UUID, expire time.Duration) string { - var JwtSecret = []byte(viper.GetString("server.jwt_secret")) - claims := Claims{ - UserID: userID, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(expire)), - }, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, _ := token.SignedString(JwtSecret) - return tokenStr -} -func TestJWTAuth_MissingToken(t *testing.T) { - gin.SetMode(gin.TestMode) - - r := gin.New() - r.Use(JWTAuth()) - r.GET("/test", func(c *gin.Context) { - c.JSON(200, gin.H{"ok": true}) - }) - - req := httptest.NewRequest(http.MethodGet, "/test", nil) - w := httptest.NewRecorder() - - r.ServeHTTP(w, req) - - if w.Code != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", w.Code) - } -} -func TestJWTAuth_InvalidToken(t *testing.T) { - gin.SetMode(gin.TestMode) - - r := gin.New() - r.Use(JWTAuth()) - r.GET("/test", func(c *gin.Context) { - c.JSON(200, gin.H{"ok": true}) - }) - - req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.Header.Set("Authorization", "Bearer invalid.token.here") - w := httptest.NewRecorder() - - r.ServeHTTP(w, req) - - if w.Code != http.StatusUnauthorized { - t.Fatalf("expected 401, got %d", w.Code) - } -} -func TestJWTAuth_ValidToken(t *testing.T) { - gin.SetMode(gin.TestMode) - - r := gin.New() - r.Use(JWTAuth()) - r.GET("/test", func(c *gin.Context) { - userID := c.GetUint("user_id") - c.JSON(200, gin.H{ - "user_id": userID, - }) - }) - - uuid, _ := uuid.NewUUID() - token := generateTestToken(uuid, time.Hour) - - req := httptest.NewRequest(http.MethodGet, "/test", nil) - req.Header.Set("Authorization", "Bearer "+token) - w := httptest.NewRecorder() - - r.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", w.Code) - } -} diff --git a/internal/cryptography/token.go b/internal/cryptography/token.go new file mode 100644 index 0000000..5986982 --- /dev/null +++ b/internal/cryptography/token.go @@ -0,0 +1,145 @@ +package cryptography + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "nixcn-cms/data" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + "github.com/spf13/viper" +) + +type Token struct { + UserID uuid.UUID + Application string +} + +type JwtClaims struct { + UserID uuid.UUID `json:"user_id"` + jwt.RegisteredClaims +} + +// Generate jwt clames +func (self *Token) NewClaims() JwtClaims { + return JwtClaims{ + UserID: self.UserID, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(viper.GetDuration("ttl.jwt_ttl"))), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: self.Application, + }, + } +} + +// Generate access token +func (self *Token) GenerateAccessToken() (string, error) { + claims := self.NewClaims() + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + secret := viper.GetString("secrets.jwt_secret") + return token.SignedString(secret) +} + +// Generate refresh token +func (self *Token) GenerateRefreshToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// Issue both access and refresh token +func (self *Token) IssueTokens() (string, string, error) { + // Gen atk + access, err := self.GenerateAccessToken() + if err != nil { + return "", "", err + } + + // Gen rtk + refresh, err := self.GenerateRefreshToken() + if err != nil { + return "", "", err + } + + // Store to redis + ctx := context.Background() + ttl := viper.GetDuration("ttl.refresh_ttl") + + // refresh -> user + if err := data.Redis.Set( + ctx, + "refresh:"+refresh, + self.UserID.String(), + ttl, + ).Err(); err != nil { + return "", "", err + } + + // user -> refresh tokens + userSetKey := "user:" + self.UserID.String() + ":refresh_tokens" + + if err := data.Redis.SAdd( + ctx, + userSetKey, + refresh, + ).Err(); err != nil { + return "", "", err + } + + // set user ttl >= all refresh token + _ = data.Redis.Expire(ctx, userSetKey, ttl).Err() + + return access, refresh, nil +} + +// Refresh access token +func (self *Token) RefreshAccessToken(refreshToken string) (string, error) { + // Read rtk:userid from redis + ctx := context.Background() + key := "refresh:" + refreshToken + + userIDStr, err := data.Redis.Get(ctx, key).Result() + if err != nil { + if err == redis.Nil { + return "", errors.New("invalid refresh token") + } + return "", err + } + + userID, err := uuid.Parse(userIDStr) + if err != nil { + return "", err + } + + self.UserID = userID + + // Generate access token + return self.GenerateAccessToken() +} + +func (self *Token) RevokeRefreshToken(refreshToken string) error { + ctx := context.Background() + + key := "refresh:" + refreshToken + + userIDStr, err := data.Redis.Get(ctx, key).Result() + if err != nil { + return nil + } + + userSetKey := "user:" + userIDStr + ":refresh_tokens" + + // Delete rtk from redis + pipe := data.Redis.TxPipeline() + pipe.Del(ctx, key) // rtk:userid index + pipe.SRem(ctx, userSetKey, refreshToken) // userid:rtk index + _, err = pipe.Exec(ctx) + + return err +} diff --git a/middleware/jwt.go b/middleware/jwt.go new file mode 100644 index 0000000..ca6351e --- /dev/null +++ b/middleware/jwt.go @@ -0,0 +1,57 @@ +package middleware + +import ( + "net/http" + "strings" + + "nixcn-cms/internal/cryptography" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/spf13/viper" +) + +func JWTAuth() gin.HandlerFunc { + jwtSecret := []byte(viper.GetString("secrets.jwt_secret")) + + return func(c *gin.Context) { + auth := c.GetHeader("Authorization") + if auth == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "missing Authorization header", + }) + return + } + + // Split header to 2 + parts := strings.SplitN(auth, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "invalid Authorization header format", + }) + return + } + + tokenStr := parts[1] + + // Verify access token + claims := &cryptography.JwtClaims{} + token, err := jwt.ParseWithClaims( + tokenStr, + claims, + func(token *jwt.Token) (any, error) { + return jwtSecret, nil + }, + ) + + if err != nil || !token.Valid { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "invalid or expired token", + }) + return + } + + c.Set("user_id", claims.UserID) + c.Next() + } +} diff --git a/service/auth/handler.go b/service/auth/handler.go index c1421e2..0dff602 100644 --- a/service/auth/handler.go +++ b/service/auth/handler.go @@ -5,4 +5,5 @@ import "github.com/gin-gonic/gin" func Handler(r *gin.RouterGroup) { r.POST("/magic", RequestMagicLink) r.GET("/magic/verify", VerifyMagicLink) + r.POST("/refresh", Refresh) } diff --git a/service/auth/magic.go b/service/auth/magic.go index 5e12b5d..61f37b7 100644 --- a/service/auth/magic.go +++ b/service/auth/magic.go @@ -3,7 +3,7 @@ package auth import ( "net/http" "nixcn-cms/data" - "nixcn-cms/internal/crypto/jwt" + "nixcn-cms/internal/cryptography" "nixcn-cms/pkgs/email" "nixcn-cms/pkgs/magiclink" "nixcn-cms/pkgs/turnstile" @@ -61,14 +61,14 @@ func RequestMagicLink(c *gin.Context) { func VerifyMagicLink(c *gin.Context) { // Get token from url - token := c.Query("token") - if token == "" { + magicToken := c.Query("token") + if magicToken == "" { c.JSON(400, gin.H{"error": "missing token"}) return } // Verify email token - email, ok := magiclink.VerifyMagicToken(token) + email, ok := magiclink.VerifyMagicToken(magicToken) if !ok { c.JSON(401, gin.H{"error": "invalid or expired token"}) return @@ -80,10 +80,19 @@ func VerifyMagicLink(c *gin.Context) { if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"status": "user not found"}) } - jwtToken, _ := jwt.GenerateToken(userInfo.UserId, "application") + JwtTool := cryptography.Token{ + UserID: userInfo.UserId, + Application: viper.GetString("server.application"), + } + accessToken, refreshToken, err := JwtTool.IssueTokens() + if err != nil { + c.JSON(500, gin.H{ + "status": "error generating tokens", + }) + } c.JSON(200, gin.H{ - "jwt_token": jwtToken, - "email": email, + "access_token": accessToken, + "refresh_token": refreshToken, }) } diff --git a/service/auth/refresh.go b/service/auth/refresh.go new file mode 100644 index 0000000..1505b20 --- /dev/null +++ b/service/auth/refresh.go @@ -0,0 +1,34 @@ +package auth + +import ( + "net/http" + "nixcn-cms/internal/cryptography" + + "github.com/gin-gonic/gin" + "github.com/spf13/viper" +) + +func Refresh(c *gin.Context) { + var req struct { + RefreshToken string `json:"refresh_token"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) + return + } + + JwtTool := cryptography.Token{ + Application: viper.GetString("server.application"), + } + + access, err := JwtTool.RefreshAccessToken(req.RefreshToken) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid refresh token"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "access_token": access, + }) +} diff --git a/service/checkin/handler.go b/service/checkin/handler.go index 8e6c48d..31afa40 100644 --- a/service/checkin/handler.go +++ b/service/checkin/handler.go @@ -1,12 +1,12 @@ package checkin import ( - "nixcn-cms/internal/crypto/jwt" + "nixcn-cms/middleware" "github.com/gin-gonic/gin" ) func Handler(r *gin.RouterGroup) { - r.Use(jwt.JWTAuth()) + r.Use(middleware.JWTAuth()) r.POST("", Checkin) } diff --git a/service/user/handler.go b/service/user/handler.go index a315d98..8c9e333 100644 --- a/service/user/handler.go +++ b/service/user/handler.go @@ -1,12 +1,12 @@ package user import ( - "nixcn-cms/internal/crypto/jwt" + "nixcn-cms/middleware" "github.com/gin-gonic/gin" ) func Handler(r *gin.RouterGroup) { - r.Use(jwt.JWTAuth()) + r.Use(middleware.JWTAuth()) r.GET("/info", UserInfo) }