From b0684492fa149da69d422eb8d3a87e99deca201a Mon Sep 17 00:00:00 2001 From: Asai Neko Date: Mon, 5 Jan 2026 21:59:37 +0800 Subject: [PATCH] Change authcode using redis, authtoken use client secret to sign jwt Signed-off-by: Asai Neko --- pkgs/authcode/authcode.go | 65 +++++++----- pkgs/authtoken/authtoken.go | 193 +++++++++++++++++++++++++----------- service/auth/magic.go | 2 +- service/auth/redirect.go | 4 +- service/auth/token.go | 6 +- 5 files changed, 179 insertions(+), 91 deletions(-) diff --git a/pkgs/authcode/authcode.go b/pkgs/authcode/authcode.go index 56576fe..b815c44 100644 --- a/pkgs/authcode/authcode.go +++ b/pkgs/authcode/authcode.go @@ -1,53 +1,68 @@ package authcode import ( + "context" "crypto/rand" "encoding/base64" - "sync" - "time" + "nixcn-cms/data" "github.com/spf13/viper" ) type Token struct { - Email string - ExpiresAt time.Time + ClientId string + Email string } -var ( - store = sync.Map{} -) +func NewAuthCode(clientId string, email string) (string, error) { + ctx := context.Background() -// Generate magic token -func NewAuthCode(email string) (string, error) { + // generate random code b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", err } code := base64.RawURLEncoding.EncodeToString(b) + key := "auth_code:" + code - store.Store(code, Token{ - Email: email, - ExpiresAt: time.Now().Add(viper.GetDuration("ttl.auth_code_ttl")), - }) + ttl := viper.GetDuration("ttl.auth_code_ttl") + + // store auth code metadata in Redis + if err := data.Redis.HSet( + ctx, + key, + map[string]any{ + "client_id": clientId, + "email": email, + }, + ).Err(); err != nil { + return "", err + } + + // set expiration (one-time auth code) + if err := data.Redis.Expire(ctx, key, ttl).Err(); err != nil { + return "", err + } return code, nil } -// Verify magic token -func VerifyAuthCode(code string) (string, bool) { - val, ok := store.Load(code) - if !ok { - return "", false +func VerifyAuthCode(code string) (*Token, bool) { + ctx := context.Background() + key := "auth_code:" + code + + // Read auth code payload + dataMap, err := data.Redis.HGetAll(ctx, key).Result() + if err != nil || len(dataMap) == 0 { + return nil, false } - t := val.(Token) - if time.Now().After(t.ExpiresAt) { - store.Delete(code) - return "", false - } + // Delete auth code immediately (one-time use) + _ = data.Redis.Del(ctx, key).Err() - store.Delete(code) - return t.Email, true + return &Token{ + ClientId: dataMap["client_id"], + Email: dataMap["email"], + }, true } diff --git a/pkgs/authtoken/authtoken.go b/pkgs/authtoken/authtoken.go index 2fbdde4..ee06931 100644 --- a/pkgs/authtoken/authtoken.go +++ b/pkgs/authtoken/authtoken.go @@ -20,14 +20,16 @@ type Token struct { } type JwtClaims struct { - UserID uuid.UUID `json:"user_id"` + ClientId string `json:"client_id"` + UserID uuid.UUID `json:"user_id"` jwt.RegisteredClaims } // Generate jwt clames -func (self *Token) NewClaims(userId uuid.UUID) JwtClaims { +func (self *Token) NewClaims(clientId string, userId uuid.UUID) JwtClaims { return JwtClaims{ - UserID: userId, + ClientId: clientId, + UserID: userId, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(viper.GetDuration("ttl.access_ttl"))), IssuedAt: jwt.NewNumericDate(time.Now()), @@ -37,10 +39,20 @@ func (self *Token) NewClaims(userId uuid.UUID) JwtClaims { } // Generate access token -func (self *Token) GenerateAccessToken(userId uuid.UUID) (string, error) { - claims := self.NewClaims(userId) +func (self *Token) GenerateAccessToken(clientId string, userId uuid.UUID) (string, error) { + claims := self.NewClaims(clientId, userId) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - secret := viper.GetString("secrets.jwt_secret") + + clientData, err := new(data.Client).GetClientByClientId(clientId) + if err != nil { + return "", fmt.Errorf("error getting client data: %v", err) + } + + secret, err := clientData.GetDecryptedSecret() + if err != nil { + return "", fmt.Errorf("error getting decrypted secret: %v", err) + } + signedToken, err := token.SignedString([]byte(secret)) if err != nil { return "", fmt.Errorf("error signing token: %v", err) @@ -58,59 +70,73 @@ func (self *Token) GenerateRefreshToken() (string, error) { } // Issue both access and refresh token -func (self *Token) IssueTokens(userId uuid.UUID) (string, string, error) { - // Gen atk - access, err := self.GenerateAccessToken(userId) +func (self *Token) IssueTokens(clientId string, userId uuid.UUID) (string, string, error) { + // access token + access, err := self.GenerateAccessToken(clientId, userId) if err != nil { return "", "", err } - // Gen rtk + // refresh token refresh, err := self.GenerateRefreshToken() if err != nil { return "", "", err } - // Store to redis ctx := context.Background() ttl := viper.GetDuration("ttl.refresh_ttl") - // refresh -> user - if err := data.Redis.Set( + refreshKey := "refresh:" + refresh + + // refresh -> user + client + if err := data.Redis.HSet( ctx, - "refresh:"+refresh, - userId.String(), - ttl, + refreshKey, + map[string]any{ + "user_id": userId.String(), + "client_id": clientId, + }, ).Err(); err != nil { return "", "", err } + if err := data.Redis.Expire(ctx, refreshKey, ttl).Err(); err != nil { + return "", "", err + } + // user -> refresh tokens userSetKey := "user:" + userId.String() + ":refresh_tokens" - if err := data.Redis.SAdd( - ctx, - userSetKey, - refresh, - ).Err(); err != nil { + if err := data.Redis.SAdd(ctx, userSetKey, refresh).Err(); err != nil { return "", "", err } - // set user ttl >= all refresh token _ = data.Redis.Expire(ctx, userSetKey, ttl).Err() + // client -> refresh tokens + clientSetKey := "client:" + clientId + ":refresh_tokens" + _ = data.Redis.SAdd(ctx, clientSetKey, refresh).Err() + _ = data.Redis.Expire(ctx, clientSetKey, ttl).Err() + return access, refresh, nil } // Refresh access token func (self *Token) RefreshAccessToken(refreshToken string) (string, error) { - // Read rtk:userid from redis ctx := context.Background() key := "refresh:" + refreshToken - userIdStr, err := data.Redis.Get(ctx, key).Result() - if err != nil { - return "", err + // read refresh token bind data + dataMap, err := data.Redis.HGetAll(ctx, key).Result() + if err != nil || len(dataMap) == 0 { + return "", errors.New("invalid refresh token") + } + + userIdStr := dataMap["user_id"] + clientId := dataMap["client_id"] + + if userIdStr == "" || clientId == "" { + return "", errors.New("refresh token corrupted") } userId, err := uuid.Parse(userIdStr) @@ -118,75 +144,105 @@ func (self *Token) RefreshAccessToken(refreshToken string) (string, error) { return "", err } - // Generate access token - return self.GenerateAccessToken(userId) + // Generate new access token + return self.GenerateAccessToken(clientId, userId) } func (self *Token) RenewRefreshToken(refreshToken string) (string, error) { ctx := context.Background() ttl := viper.GetDuration("ttl.refresh_ttl") - key := "refresh:" + refreshToken - userIdStr, err := data.Redis.Get(ctx, key).Result() + oldKey := "refresh:" + refreshToken + + // read old refresh token bind data + dataMap, err := data.Redis.HGetAll(ctx, oldKey).Result() + if err != nil || len(dataMap) == 0 { + return "", errors.New("invalid refresh token") + } + + userIdStr := dataMap["user_id"] + clientId := dataMap["client_id"] + + if userIdStr == "" || clientId == "" { + return "", errors.New("refresh token corrupted") + } + + // generate new refresh token + newRefresh, err := self.GenerateRefreshToken() if err != nil { return "", err } - refresh, err := self.GenerateRefreshToken() - if err != nil { + // revoke old refresh token + if err := self.RevokeRefreshToken(refreshToken); err != nil { return "", err } - err = self.RevokeRefreshToken(refreshToken) - if err != nil { - return "", err - } + newKey := "refresh:" + newRefresh - // refresh -> user - if err := data.Redis.Set( + // refresh -> user + client + if err := data.Redis.HSet( ctx, - "refresh:"+refresh, - userIdStr, - ttl, + newKey, + map[string]any{ + "user_id": userIdStr, + "client_id": clientId, + }, ).Err(); err != nil { return "", err } + if err := data.Redis.Expire(ctx, newKey, ttl).Err(); err != nil { + return "", err + } + // user -> refresh tokens userSetKey := "user:" + userIdStr + ":refresh_tokens" - - if err := data.Redis.SAdd( - ctx, - userSetKey, - refresh, - ).Err(); err != nil { + if err := data.Redis.SAdd(ctx, userSetKey, newRefresh).Err(); err != nil { return "", err } - - // set user ttl >= all refresh token _ = data.Redis.Expire(ctx, userSetKey, ttl).Err() - return refresh, nil + // client -> refresh tokens + clientSetKey := "client:" + clientId + ":refresh_tokens" + _ = data.Redis.SAdd(ctx, clientSetKey, newRefresh).Err() + _ = data.Redis.Expire(ctx, clientSetKey, ttl).Err() + + return newRefresh, nil } func (self *Token) RevokeRefreshToken(refreshToken string) error { ctx := context.Background() - key := "refresh:" + refreshToken + refreshKey := "refresh:" + refreshToken - userIDStr, err := data.Redis.Get(ctx, key).Result() - if err != nil { + // read refresh token metadata (user_id, client_id) + dataMap, err := data.Redis.HGetAll(ctx, refreshKey).Result() + if err != nil || len(dataMap) == 0 { + // Token already revoked or not found return nil } - userSetKey := "user:" + userIDStr + ":refresh_tokens" + userID := dataMap["user_id"] + clientID := dataMap["client_id"] - // Delete rtk from redis + // build index keys + userSetKey := "user:" + userID + ":refresh_tokens" + clientSetKey := "client:" + clientID + ":refresh_tokens" + + // remove refresh token and all related indexes atomically pipe := data.Redis.TxPipeline() - pipe.Del(ctx, key) // rtk:userid index - pipe.SRem(ctx, userSetKey, refreshToken) // userid:rtk index - _, err = pipe.Exec(ctx) + // remove main refresh token record + pipe.Del(ctx, refreshKey) + + // remove refresh token from user's active refresh token set + pipe.SRem(ctx, userSetKey, refreshToken) + + // remove refresh token from client's active refresh token set + pipe.SRem(ctx, clientSetKey, refreshToken) + + _, err = pipe.Exec(ctx) return err } @@ -195,7 +251,6 @@ func (self *Token) HeaderVerify(header string) (string, error) { return "", nil } - jwtSecret := []byte(viper.GetString("secrets.jwt_secret")) // Split header to 2 parts := strings.SplitN(header, " ", 2) if len(parts) != 2 || parts[0] != "Bearer" { @@ -210,7 +265,25 @@ func (self *Token) HeaderVerify(header string) (string, error) { tokenStr, claims, func(token *jwt.Token) (any, error) { - return jwtSecret, nil + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.New("unexpected signing method") + } + + if claims.ClientId == "" { + return nil, errors.New("client_id missing in token") + } + + clientData, err := new(data.Client).GetClientByClientId(claims.ClientId) + if err != nil { + return nil, fmt.Errorf("error getting client data: %v", err) + } + + secret, err := clientData.GetDecryptedSecret() + if err != nil { + return nil, fmt.Errorf("error getting decrypted secret: %v", err) + } + + return secret, nil }, ) diff --git a/service/auth/magic.go b/service/auth/magic.go index d60d625..634d74a 100644 --- a/service/auth/magic.go +++ b/service/auth/magic.go @@ -33,7 +33,7 @@ func Magic(c *gin.Context) { return } - code, err := authcode.NewAuthCode(req.Email) + code, err := authcode.NewAuthCode(req.ClientId, req.Email) if err != nil { c.JSON(500, gin.H{"status": "code gen failed"}) } diff --git a/service/auth/redirect.go b/service/auth/redirect.go index 2962d0a..1e36a74 100644 --- a/service/auth/redirect.go +++ b/service/auth/redirect.go @@ -50,7 +50,7 @@ func Redirect(c *gin.Context) { return } - code, err := authcode.NewAuthCode(user.Email) + code, err := authcode.NewAuthCode(clientId, user.Email) if err != nil { c.JSON(500, gin.H{"status": "code gen failed"}) return @@ -109,7 +109,7 @@ func Redirect(c *gin.Context) { return } - newCode, err := authcode.NewAuthCode(email) + newCode, err := authcode.NewAuthCode(clientId, email) if err != nil { c.JSON(500, gin.H{"status": "internal server error"}) return diff --git a/service/auth/token.go b/service/auth/token.go index a829fd8..2af75ac 100644 --- a/service/auth/token.go +++ b/service/auth/token.go @@ -22,14 +22,14 @@ func Token(c *gin.Context) { return } - email, ok := authcode.VerifyAuthCode(req.Code) + authCode, ok := authcode.VerifyAuthCode(req.Code) if !ok { c.JSON(403, gin.H{"status": "invalid or expired token"}) return } userData := new(data.User) - user, err := userData.GetByEmail(email) + user, err := userData.GetByEmail(authCode.Email) if err != nil { c.JSON(500, gin.H{"status": "internal server error"}) return @@ -39,7 +39,7 @@ func Token(c *gin.Context) { JwtTool := authtoken.Token{ Application: viper.GetString("server.application"), } - accessToken, refreshToken, err := JwtTool.IssueTokens(user.UserId) + accessToken, refreshToken, err := JwtTool.IssueTokens(authCode.ClientId, user.UserId) if err != nil { c.JSON(500, gin.H{"status": "error generating tokens"}) return