From d314942c08d878ad6ae99224655b5f617e826774 Mon Sep 17 00:00:00 2001 From: Asai Neko Date: Tue, 23 Dec 2025 18:11:31 +0800 Subject: [PATCH] Add jwt crypto module, support unit test for config module Signed-off-by: Asai Neko --- .env.development | 1 + config/config.go | 22 ++++---- config/default.go | 1 + config/env.go | 2 +- config/types.go | 1 + data/user.go | 7 +++ go.mod | 1 + go.sum | 2 + internal/crypto/jwt/jwt.go | 77 ++++++++++++++++++++++++++ internal/crypto/jwt/jwt_test.go | 95 +++++++++++++++++++++++++++++++++ justfile | 2 + service/check/checkin.go | 5 ++ service/check/handler.go | 7 ++- 13 files changed, 212 insertions(+), 11 deletions(-) create mode 100644 internal/crypto/jwt/jwt.go create mode 100644 internal/crypto/jwt/jwt_test.go create mode 100644 service/check/checkin.go diff --git a/.env.development b/.env.development index f14e924..5904d14 100644 --- a/.env.development +++ b/.env.development @@ -1,6 +1,7 @@ SERVER_ADDRESS=:8000 SERVER_DEBUG_MODE=true SERVER_FILE_LOGGER=false +SERVER_JWT_SECRET=test DATABASE_TYPE=postgres DATABASE_HOST=127.0.0.1 DATABASE_NAME=postgres diff --git a/config/config.go b/config/config.go index e0b0c93..41d52a7 100644 --- a/config/config.go +++ b/config/config.go @@ -2,17 +2,29 @@ package config import ( "log" + "os" "github.com/spf13/viper" ) func Init() { + // Set config path by env + confPath := os.Getenv("CONFIG_PATH") + if confPath == "" { + confPath = "config.yaml" + } + // Read global config - viper.SetConfigFile("config.yaml") + viper.SetConfigFile(confPath) viper.SetDefault("Server", serverDef) viper.SetDefault("Database", databaseDef) conf := &config{} if err := viper.ReadInConfig(); err != nil { + // Dont generate config when using dev mode + if os.Getenv("GO_ENV") == "test" || os.Getenv("CONFIG_PATH") != "" { + log.Fatalf("[Config] failed to read config %s: %v", confPath, err) + } + log.Println("Can't read config, trying to modify!") if err := viper.WriteConfig(); err != nil { log.Fatal("[Config] Error writing config: ", err) @@ -24,17 +36,9 @@ func Init() { } func Get(key string) any { - viper.SetConfigFile("config.yaml") - if err := viper.ReadInConfig(); err != nil { - log.Fatal("[Config] Error reading config: ", err) - } return viper.Get(key) } func Set(key string, value any) { - viper.SetConfigFile("config.yaml") - if err := viper.ReadInConfig(); err != nil { - log.Fatal("[Config] Error reading config: ", err) - } viper.Set(key, value) } diff --git a/config/default.go b/config/default.go index 8fee6c8..a2cb58d 100644 --- a/config/default.go +++ b/config/default.go @@ -4,6 +4,7 @@ var serverDef = server{ Address: ":8000", DebugMode: false, FileLogger: false, + JwtSecret: "something", } var databaseDef = database{ diff --git a/config/env.go b/config/env.go index 890a948..3e43efa 100644 --- a/config/env.go +++ b/config/env.go @@ -54,7 +54,7 @@ func SetEnvConf(key string, sub string) { func EnvInit() { var dict = map[string][]string{ - "server": {"address", "debug_mode", "file_logger"}, + "server": {"address", "debug_mode", "file_logger", "jwt_secret"}, "database": {"type", "host", "name", "username", "password"}, } for key, value := range dict { diff --git a/config/types.go b/config/types.go index 274db3e..197eabe 100644 --- a/config/types.go +++ b/config/types.go @@ -9,6 +9,7 @@ type server struct { Address string `yaml:"address"` DebugMode bool `yaml:"debug_mode"` FileLogger bool `yaml:"file_logger"` + JwtSecret string `yaml:"jwt_secret"` } type database struct { diff --git a/data/user.go b/data/user.go index 7efe964..2bb3852 100644 --- a/data/user.go +++ b/data/user.go @@ -21,6 +21,13 @@ func (self *User) GetByEmail(email string) error { return nil } +func (self *User) GetByUserId(userId string) error { + if err := Database.Where("user_id = ?", userId).First(&self).Error; err != nil { + return err + } + return nil +} + func (self *User) SetCheckinState(email string, state bool) error { if err := Database.Where("email = ?", email).First(&self).Error; err != nil { return err diff --git a/go.mod b/go.mod index 3dffd96..4249f4d 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-yaml v1.19.1 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect diff --git a/go.sum b/go.sum index 10aa00e..e2fb676 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.19.1 h1:3rG3+v8pkhRqoQ/88NYNMHYVGYztCOCIZ7UQhu7H+NE= github.com/goccy/go-yaml v1.19.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/internal/crypto/jwt/jwt.go b/internal/crypto/jwt/jwt.go new file mode 100644 index 0000000..53e2a05 --- /dev/null +++ b/internal/crypto/jwt/jwt.go @@ -0,0 +1,77 @@ +package jwt + +import ( + "net/http" + "nixcn-cms/config" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +type Claims struct { + UserID uuid.UUID `json:"user_id"` + jwt.RegisteredClaims +} + +func JWTAuth() gin.HandlerFunc { + var JwtSecret = []byte(config.Get("server.jwt_secret").(string)) + return func(c *gin.Context) { + auth := c.GetHeader("Authorization") + if auth == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "missing Authorization header", + }) + return + } + + parts := strings.SplitN(auth, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "invalid Authorization header format", + }) + return + } + + tokenStr := parts[1] + + token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) { + return JwtSecret, nil + }) + + if err != nil || !token.Valid { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "invalid or expired token", + }) + return + } + + claims, ok := token.Claims.(*Claims) + if !ok { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "invalid token claims", + }) + return + } + + c.Set("user_id", claims.UserID) + c.Next() + } +} + +func GenerateToken(userID uuid.UUID, application string) (string, error) { + var JwtSecret = []byte(config.Get("server.jwt_secret").(string)) + claims := Claims{ + UserID: userID, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: application, + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(JwtSecret) +} diff --git a/internal/crypto/jwt/jwt_test.go b/internal/crypto/jwt/jwt_test.go new file mode 100644 index 0000000..a169e4e --- /dev/null +++ b/internal/crypto/jwt/jwt_test.go @@ -0,0 +1,95 @@ +package jwt + +import ( + "net/http" + "net/http/httptest" + "nixcn-cms/config" + "os" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +func init() { + os.Setenv("GO_ENV", "test") + config.Init() +} + +func generateTestToken(userID uuid.UUID, expire time.Duration) string { + var JwtSecret = []byte(config.Get("server.jwt_secret").(string)) + claims := Claims{ + UserID: userID, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expire)), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, _ := token.SignedString(JwtSecret) + return tokenStr +} +func TestJWTAuth_MissingToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(JWTAuth()) + r.GET("/test", func(c *gin.Context) { + c.JSON(200, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} +func TestJWTAuth_InvalidToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(JWTAuth()) + r.GET("/test", func(c *gin.Context) { + c.JSON(200, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer invalid.token.here") + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} +func TestJWTAuth_ValidToken(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(JWTAuth()) + r.GET("/test", func(c *gin.Context) { + userID := c.GetUint("user_id") + c.JSON(200, gin.H{ + "user_id": userID, + }) + }) + + uuid, _ := uuid.NewUUID() + token := generateTestToken(uuid, time.Hour) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} diff --git a/justfile b/justfile index 91a4efe..a5307dc 100644 --- a/justfile +++ b/justfile @@ -17,3 +17,5 @@ build: run: cd {{ output_dir }} && {{ exec_path }}{{ if os() == "windows" { ".exe" } else { "" } }} +test: + cd {{output_dir}} && CONFIG_PATH={{output_dir}}/config.yaml GO_ENV=test go test -C .. ./... diff --git a/service/check/checkin.go b/service/check/checkin.go new file mode 100644 index 0000000..14a3884 --- /dev/null +++ b/service/check/checkin.go @@ -0,0 +1,5 @@ +package check + +func Checkin() { + +} diff --git a/service/check/handler.go b/service/check/handler.go index 445c9d2..768e739 100644 --- a/service/check/handler.go +++ b/service/check/handler.go @@ -1,8 +1,13 @@ package check -import "github.com/gin-gonic/gin" +import ( + "nixcn-cms/internal/crypto/jwt" + + "github.com/gin-gonic/gin" +) func Handler(r *gin.RouterGroup) { + r.Use(jwt.JWTAuth()) r.GET("/test", func(ctx *gin.Context) { ctx.JSON(200, gin.H{"Test": "Test"}) })