Files
nixcn-cms/pkgs/authtoken/authtoken.go
Asai Neko b0684492fa
Some checks failed
Build Backend (NixCN CMS) TeamCity build failed
Build Frontend (NixCN CMS) TeamCity build finished
Change authcode to use Redis; authtoken now signs JWTs with the client secret
Signed-off-by: Asai Neko <sugar@sne.moe>
2026-01-05 21:59:37 +08:00

296 lines
7.0 KiB
Go

package authtoken
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"nixcn-cms/data"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/spf13/viper"
)
// Token groups the access/refresh token operations of this package.
// Application is written into the issuer ("iss") claim of every access
// token built by NewClaims.
type Token struct {
	Application string
}
// JwtClaims is the claim set carried by access tokens.
//
// ClientId names the client the token was issued to; HeaderVerify uses it to
// look up the secret the token must be verified with. UserID is the user the
// token represents. The embedded RegisteredClaims hold the standard JWT
// claims (NewClaims sets exp, iat and iss).
type JwtClaims struct {
	ClientId string    `json:"client_id"`
	UserID   uuid.UUID `json:"user_id"`
	jwt.RegisteredClaims
}
// NewClaims builds the claim set for an access token issued to clientId on
// behalf of userId. The token is issued now, expires after the duration
// configured at "ttl.access_ttl", and names the Token's Application as issuer.
//
// NOTE(review): if "ttl.access_ttl" is unset, viper yields a zero duration and
// the token is already expired when issued — confirm the config always sets it.
func (self *Token) NewClaims(clientId string, userId uuid.UUID) JwtClaims {
	// Read the clock once so that exp - iat is exactly the configured TTL.
	now := time.Now()
	return JwtClaims{
		ClientId: clientId,
		UserID:   userId,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(viper.GetDuration("ttl.access_ttl"))),
			IssuedAt:  jwt.NewNumericDate(now),
			Issuer:    self.Application,
		},
	}
}
// GenerateAccessToken returns an HS256-signed JWT for userId, issued to the
// client identified by clientId. The signing key is that client's decrypted
// secret, so verification (HeaderVerify) must resolve the same client.
//
// Errors from the client lookup, secret decryption and signing are wrapped
// with %w so callers can still inspect the underlying cause.
func (self *Token) GenerateAccessToken(clientId string, userId uuid.UUID) (string, error) {
	clientData, err := new(data.Client).GetClientByClientId(clientId)
	if err != nil {
		return "", fmt.Errorf("error getting client data: %w", err)
	}
	secret, err := clientData.GetDecryptedSecret()
	if err != nil {
		return "", fmt.Errorf("error getting decrypted secret: %w", err)
	}
	// An empty HMAC key would make the signature trivially forgeable.
	if len(secret) == 0 {
		return "", errors.New("client secret is empty")
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, self.NewClaims(clientId, userId))
	signedToken, err := token.SignedString([]byte(secret))
	if err != nil {
		return "", fmt.Errorf("error signing token: %w", err)
	}
	return signedToken, nil
}
// GenerateRefreshToken returns a new opaque refresh token: 32 bytes read from
// crypto/rand, encoded as padded URL-safe base64 (44 characters). Any error
// from the random source is returned unchanged.
func (self *Token) GenerateRefreshToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(raw[:])
	return encoded, nil
}
// IssueTokens issues a fresh (access, refresh) token pair for userId on behalf
// of clientId.
//
// The refresh token is stored in Redis as a hash at "refresh:<token>" holding
// user_id and client_id. It is also added to the index sets
// "user:<userId>:refresh_tokens" and "client:<clientId>:refresh_tokens", which
// RevokeRefreshToken keeps in sync. All three keys get the "ttl.refresh_ttl" TTL.
//
// Any failure to generate or store a token is returned. No token is handed
// out that is not fully recorded.
func (self *Token) IssueTokens(clientId string, userId uuid.UUID) (string, string, error) {
	access, err := self.GenerateAccessToken(clientId, userId)
	if err != nil {
		return "", "", err
	}
	refresh, err := self.GenerateRefreshToken()
	if err != nil {
		return "", "", err
	}

	ctx := context.Background()
	ttl := viper.GetDuration("ttl.refresh_ttl")
	refreshKey := "refresh:" + refresh
	userSetKey := "user:" + userId.String() + ":refresh_tokens"
	clientSetKey := "client:" + clientId + ":refresh_tokens"

	// Queue the record, its TTL and both index entries in one MULTI/EXEC
	// transaction rather than six independent round trips. This avoids a
	// partial failure leaving a refresh record with no expiry, or a live token
	// absent from the index sets.
	pipe := data.Redis.TxPipeline()
	pipe.HSet(ctx, refreshKey, map[string]any{
		"user_id":   userId.String(),
		"client_id": clientId,
	})
	pipe.Expire(ctx, refreshKey, ttl)
	pipe.SAdd(ctx, userSetKey, refresh)
	pipe.Expire(ctx, userSetKey, ttl)
	pipe.SAdd(ctx, clientSetKey, refresh)
	pipe.Expire(ctx, clientSetKey, ttl)
	if _, err := pipe.Exec(ctx); err != nil {
		return "", "", fmt.Errorf("storing refresh token: %w", err)
	}
	return access, refresh, nil
}
// RefreshAccessToken mints a new access token for the user and client bound to
// refreshToken. The refresh token itself is not modified; rotation is done by
// RenewRefreshToken.
//
// Errors:
//   - "invalid refresh token": the token is empty, unknown or expired.
//   - "refresh token corrupted": the stored record lacks user_id/client_id or
//     holds an unparsable user_id.
//   - A wrapped Redis error when the record cannot be read.
func (self *Token) RefreshAccessToken(refreshToken string) (string, error) {
	if refreshToken == "" {
		return "", errors.New("invalid refresh token")
	}
	ctx := context.Background()
	key := "refresh:" + refreshToken

	dataMap, err := data.Redis.HGetAll(ctx, key).Result()
	if err != nil {
		// A storage failure is not a bad token; don't report it as one.
		return "", fmt.Errorf("reading refresh token: %w", err)
	}
	if len(dataMap) == 0 {
		return "", errors.New("invalid refresh token")
	}

	userIdStr, clientId := dataMap["user_id"], dataMap["client_id"]
	if userIdStr == "" || clientId == "" {
		return "", errors.New("refresh token corrupted")
	}
	userId, err := uuid.Parse(userIdStr)
	if err != nil {
		return "", fmt.Errorf("refresh token corrupted: %w", err)
	}
	return self.GenerateAccessToken(clientId, userId)
}
// RenewRefreshToken rotates refreshToken. It consumes the old token and
// returns a new one bound to the same user and client, with a fresh
// "ttl.refresh_ttl" expiry and matching entries in the per-user and
// per-client index sets.
//
// Consumption is race-free. The old record is deleted and the DEL count is
// checked, so when two requests renew the same token concurrently only one
// gets a replacement; the other receives "invalid refresh token".
//
// NOTE: the old token is consumed before the new one is stored. If storing
// fails, the caller ends up with no valid refresh token. This ordering is the
// same as before this change.
func (self *Token) RenewRefreshToken(refreshToken string) (string, error) {
	ctx := context.Background()
	ttl := viper.GetDuration("ttl.refresh_ttl")
	oldKey := "refresh:" + refreshToken

	dataMap, err := data.Redis.HGetAll(ctx, oldKey).Result()
	if err != nil {
		return "", fmt.Errorf("reading refresh token: %w", err)
	}
	if len(dataMap) == 0 {
		return "", errors.New("invalid refresh token")
	}
	userIdStr := dataMap["user_id"]
	clientId := dataMap["client_id"]
	if userIdStr == "" || clientId == "" {
		return "", errors.New("refresh token corrupted")
	}
	userSetKey := "user:" + userIdStr + ":refresh_tokens"
	clientSetKey := "client:" + clientId + ":refresh_tokens"

	newRefresh, err := self.GenerateRefreshToken()
	if err != nil {
		return "", err
	}

	// Consume the old token. Only the caller whose DEL actually removed the key
	// may issue a replacement; a concurrent renewal (or an expiry in between)
	// observes 0 and is refused.
	consume := data.Redis.TxPipeline()
	deleted := consume.Del(ctx, oldKey)
	consume.SRem(ctx, userSetKey, refreshToken)
	consume.SRem(ctx, clientSetKey, refreshToken)
	if _, err := consume.Exec(ctx); err != nil {
		return "", fmt.Errorf("revoking old refresh token: %w", err)
	}
	if deleted.Val() == 0 {
		return "", errors.New("invalid refresh token")
	}

	// Record the new token and its index entries in one transaction.
	newKey := "refresh:" + newRefresh
	store := data.Redis.TxPipeline()
	store.HSet(ctx, newKey, map[string]any{
		"user_id":   userIdStr,
		"client_id": clientId,
	})
	store.Expire(ctx, newKey, ttl)
	store.SAdd(ctx, userSetKey, newRefresh)
	store.Expire(ctx, userSetKey, ttl)
	store.SAdd(ctx, clientSetKey, newRefresh)
	store.Expire(ctx, clientSetKey, ttl)
	if _, err := store.Exec(ctx); err != nil {
		return "", fmt.Errorf("storing refresh token: %w", err)
	}
	return newRefresh, nil
}
// RevokeRefreshToken deletes refreshToken's record at "refresh:<token>". It
// also removes the token from the "user:<id>:refresh_tokens" and
// "client:<id>:refresh_tokens" index sets. All of this happens in a single
// MULTI/EXEC transaction.
//
// Revoking an unknown, expired or already-revoked token is a no-op and returns
// nil. A failure to read the record from Redis is returned as an error, not
// treated as "already revoked". Otherwise a caller (e.g. logout) could believe
// a still-valid token is gone.
func (self *Token) RevokeRefreshToken(refreshToken string) error {
	ctx := context.Background()
	refreshKey := "refresh:" + refreshToken

	// The record holds the user_id/client_id needed to locate the index sets.
	dataMap, err := data.Redis.HGetAll(ctx, refreshKey).Result()
	if err != nil {
		return fmt.Errorf("reading refresh token: %w", err)
	}
	if len(dataMap) == 0 {
		// Already revoked, expired, or never issued.
		return nil
	}

	pipe := data.Redis.TxPipeline()
	pipe.Del(ctx, refreshKey)
	// Skip index keys we cannot name rather than touching "user::refresh_tokens".
	if userID := dataMap["user_id"]; userID != "" {
		pipe.SRem(ctx, "user:"+userID+":refresh_tokens", refreshToken)
	}
	if clientID := dataMap["client_id"]; clientID != "" {
		pipe.SRem(ctx, "client:"+clientID+":refresh_tokens", refreshToken)
	}
	if _, err := pipe.Exec(ctx); err != nil {
		return fmt.Errorf("revoking refresh token: %w", err)
	}
	return nil
}
// HeaderVerify validates an HTTP Authorization header of the form
// "Bearer <jwt>" and returns the user ID carried by the access token.
//
// An empty header yields ("", nil), so callers can treat the request as
// unauthenticated. Otherwise all of the following must hold, or an error is
// returned:
//   - The scheme is "Bearer" (any letter case).
//   - The token is an HS256 JWT with an exp claim that has not passed.
//   - It is signed with the decrypted secret of the client named in its
//     client_id claim (the key GenerateAccessToken signs with).
//   - It carries a non-nil user_id.
func (self *Token) HeaderVerify(header string) (string, error) {
	if header == "" {
		return "", nil
	}

	// Authentication schemes are case-insensitive (RFC 7235 §2.1); tolerate
	// extra whitespace around the token.
	scheme, tokenStr, ok := strings.Cut(header, " ")
	tokenStr = strings.TrimSpace(tokenStr)
	if !ok || !strings.EqualFold(scheme, "Bearer") || tokenStr == "" {
		return "", errors.New("invalid Authorization header format")
	}

	claims := &JwtClaims{}
	token, err := jwt.ParseWithClaims(
		tokenStr,
		claims,
		func(token *jwt.Token) (any, error) {
			// The parser has already decoded (but not yet verified) the claims.
			// client_id selects which client's secret the signature must match.
			if claims.ClientId == "" {
				return nil, errors.New("client_id missing in token")
			}
			clientData, err := new(data.Client).GetClientByClientId(claims.ClientId)
			if err != nil {
				return nil, fmt.Errorf("error getting client data: %w", err)
			}
			secret, err := clientData.GetDecryptedSecret()
			if err != nil {
				return nil, fmt.Errorf("error getting decrypted secret: %w", err)
			}
			if len(secret) == 0 {
				return nil, errors.New("client secret is empty")
			}
			// jwt's HMAC methods only accept a []byte key; returning the secret
			// as-is makes every verification fail with ErrInvalidKeyType.
			// This matches the []byte(secret) used when signing.
			return []byte(secret), nil
		},
		// Only HS256 is ever issued; reject anything else before key lookup.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		// Every issued token has exp; a token without one is not ours.
		jwt.WithExpirationRequired(),
	)
	if err != nil || !token.Valid {
		return "", errors.New("invalid or expired token")
	}
	if claims.UserID == uuid.Nil {
		return "", errors.New("invalid or expired token")
	}
	return claims.UserID.String(), nil
}