package authtoken

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/spf13/viper"

	"nixcn-cms/data"
)

var (
	// ErrInvalidRefreshToken is returned when a refresh token has no record in
	// Redis (never issued, expired, or already revoked/rotated).
	ErrInvalidRefreshToken = errors.New("invalid refresh token")
	// ErrCorruptRefreshToken is returned when a refresh-token record exists but
	// is missing its user_id or client_id field.
	ErrCorruptRefreshToken = errors.New("refresh token corrupted")
)

// Token issues and verifies tokens. Application is written into the "iss"
// claim of every access token it generates.
type Token struct {
	Application string
}

// JwtClaims is the claim set carried by access tokens: the client ID the
// token was issued for, the user ID, and the registered claims (exp, iat, iss).
type JwtClaims struct {
	ClientId string    `json:"client_id"`
	UserID   uuid.UUID `json:"user_id"`
	jwt.RegisteredClaims
}

// NewClaims builds the access-token claims for userId issued via clientId.
// The token expires after the configured "ttl.access_ttl" duration.
func (t *Token) NewClaims(clientId string, userId uuid.UUID) JwtClaims {
	// One timestamp so iat and exp are derived from the same instant.
	now := time.Now()
	return JwtClaims{
		ClientId: clientId,
		UserID:   userId,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(viper.GetDuration("ttl.access_ttl"))),
			IssuedAt:  jwt.NewNumericDate(now),
			Issuer:    t.Application,
		},
	}
}

// GenerateAccessToken returns an HS256-signed JWT for userId, keyed with the
// decrypted secret of clientId.
func (t *Token) GenerateAccessToken(clientId string, userId uuid.UUID) (string, error) {
	key, err := clientSigningKey(clientId)
	if err != nil {
		return "", err
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, t.NewClaims(clientId, userId))
	signed, err := token.SignedString(key)
	if err != nil {
		return "", fmt.Errorf("error signing token: %w", err)
	}
	return signed, nil
}

// GenerateRefreshToken returns an opaque refresh token: 32 bytes from
// crypto/rand, encoded with padded URL-safe base64.
func (t *Token) GenerateRefreshToken() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(b), nil
}

// IssueTokens generates an access token and a refresh token for userId via
// clientId. The refresh token is persisted in Redis with the configured
// "ttl.refresh_ttl" before either token is returned.
func (t *Token) IssueTokens(clientId string, userId uuid.UUID) (string, string, error) {
	access, err := t.GenerateAccessToken(clientId, userId)
	if err != nil {
		return "", "", err
	}

	refresh, err := t.GenerateRefreshToken()
	if err != nil {
		return "", "", err
	}

	ttl := viper.GetDuration("ttl.refresh_ttl")
	if err := storeRefreshToken(context.Background(), refresh, userId.String(), clientId, ttl); err != nil {
		return "", "", err
	}
	return access, refresh, nil
}

// RefreshAccessToken mints a new access token for the user and client bound
// to refreshToken. The refresh token itself is left unchanged.
//
// Returns ErrInvalidRefreshToken if no record exists, ErrCorruptRefreshToken
// if the record lacks user_id or client_id, or a wrapped error on Redis,
// parsing, or signing failures.
func (t *Token) RefreshAccessToken(refreshToken string) (string, error) {
	ctx := context.Background()

	record, err := data.Redis.HGetAll(ctx, refreshTokenKey(refreshToken)).Result()
	if err != nil {
		return "", fmt.Errorf("reading refresh token: %w", err)
	}
	if len(record) == 0 {
		return "", ErrInvalidRefreshToken
	}

	userIdStr, clientId := record["user_id"], record["client_id"]
	if userIdStr == "" || clientId == "" {
		return "", ErrCorruptRefreshToken
	}

	userId, err := uuid.Parse(userIdStr)
	if err != nil {
		return "", fmt.Errorf("parsing refresh token user_id: %w", err)
	}
	return t.GenerateAccessToken(clientId, userId)
}

// RenewRefreshToken rotates refreshToken. The old token is consumed, and a new
// token bound to the same user and client is stored and returned.
//
// The old record is read and deleted in a single MULTI/EXEC. When the same
// token is presented concurrently, only one caller can redeem it; the others
// get ErrInvalidRefreshToken. If storing the new token fails after the old
// one was consumed, the old token is not restored.
func (t *Token) RenewRefreshToken(refreshToken string) (string, error) {
	ctx := context.Background()
	ttl := viper.GetDuration("ttl.refresh_ttl")

	// Generate before touching Redis so a rand failure leaves the old token usable.
	newRefresh, err := t.GenerateRefreshToken()
	if err != nil {
		return "", err
	}

	record, err := consumeRefreshToken(ctx, refreshToken)
	if err != nil {
		return "", err
	}
	if len(record) == 0 {
		return "", ErrInvalidRefreshToken
	}
	userID, clientID := record["user_id"], record["client_id"]

	// The record is already gone; drop the old token from its index sets too.
	pipe := data.Redis.TxPipeline()
	pipe.SRem(ctx, userRefreshSetKey(userID), refreshToken)
	pipe.SRem(ctx, clientRefreshSetKey(clientID), refreshToken)
	if _, err := pipe.Exec(ctx); err != nil {
		return "", fmt.Errorf("unindexing old refresh token: %w", err)
	}

	if userID == "" || clientID == "" {
		return "", ErrCorruptRefreshToken
	}

	if err := storeRefreshToken(ctx, newRefresh, userID, clientID, ttl); err != nil {
		return "", err
	}
	return newRefresh, nil
}

// RevokeRefreshToken deletes refreshToken's record and removes it from the
// user and client index sets in one transaction.
//
// Revoking an unknown or already-revoked token is a no-op. Redis errors are
// returned, so a failed revocation is never reported as success.
func (t *Token) RevokeRefreshToken(refreshToken string) error {
	ctx := context.Background()
	key := refreshTokenKey(refreshToken)

	// The record's user_id/client_id are needed to locate the index sets.
	record, err := data.Redis.HGetAll(ctx, key).Result()
	if err != nil {
		return fmt.Errorf("reading refresh token: %w", err)
	}
	if len(record) == 0 {
		return nil
	}

	pipe := data.Redis.TxPipeline()
	pipe.Del(ctx, key)
	pipe.SRem(ctx, userRefreshSetKey(record["user_id"]), refreshToken)
	pipe.SRem(ctx, clientRefreshSetKey(record["client_id"]), refreshToken)
	if _, err := pipe.Exec(ctx); err != nil {
		return fmt.Errorf("revoking refresh token: %w", err)
	}
	return nil
}

// HeaderVerify validates an "Authorization: Bearer <jwt>" header value and
// returns the user ID from a valid access token.
//
// An empty header is not an error: it yields ("", nil), so callers must check
// the returned ID. The scheme is matched case-insensitively. Only HS256 tokens
// are accepted, verified with the secret of the client named in client_id.
func (t *Token) HeaderVerify(header string) (string, error) {
	if header == "" {
		return "", nil
	}

	scheme, tokenStr, ok := strings.Cut(header, " ")
	if !ok || !strings.EqualFold(scheme, "Bearer") {
		return "", errors.New("invalid Authorization header format")
	}

	claims := &JwtClaims{}
	token, err := jwt.ParseWithClaims(
		tokenStr,
		claims,
		func(tok *jwt.Token) (any, error) {
			if _, ok := tok.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, errors.New("unexpected signing method")
			}
			// jwt decodes the (still unverified) claims before calling the
			// keyfunc, so ClientId is available to select the key.
			if claims.ClientId == "" {
				return nil, errors.New("client_id missing in token")
			}
			// HMAC verification requires a []byte key; returning any other
			// type makes jwt reject every token with ErrInvalidKeyType.
			key, err := clientSigningKey(claims.ClientId)
			if err != nil {
				return nil, err
			}
			return key, nil
		},
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
	)
	if err != nil || !token.Valid {
		return "", errors.New("invalid or expired token")
	}

	return claims.UserID.String(), nil
}

// clientSigningKey loads the HMAC key for clientId: the client's decrypted
// secret as raw bytes. Signing and verification both go through here, so they
// always agree on the key and its type.
func clientSigningKey(clientId string) ([]byte, error) {
	client, err := new(data.Client).GetClientByClientId(clientId)
	if err != nil {
		return nil, fmt.Errorf("error getting client data: %w", err)
	}
	secret, err := client.GetDecryptedSecret()
	if err != nil {
		return nil, fmt.Errorf("error getting decrypted secret: %w", err)
	}
	return []byte(secret), nil
}

// refreshTokenKey is the Redis hash holding a refresh token's user_id and client_id.
func refreshTokenKey(refresh string) string {
	return "refresh:" + refresh
}

// userRefreshSetKey is the Redis set indexing the refresh tokens issued to a user.
func userRefreshSetKey(userID string) string {
	return "user:" + userID + ":refresh_tokens"
}

// clientRefreshSetKey is the Redis set indexing the refresh tokens issued via a client.
func clientRefreshSetKey(clientID string) string {
	return "client:" + clientID + ":refresh_tokens"
}

// storeRefreshToken writes the refresh-token record and its user/client index
// entries in a single MULTI/EXEC, so the record is never left without a TTL.
//
// Each index set's TTL is reset to ttl on every write, so a set lives as long
// as its most recently stored token.
func storeRefreshToken(ctx context.Context, refresh, userID, clientID string, ttl time.Duration) error {
	key := refreshTokenKey(refresh)
	userSetKey := userRefreshSetKey(userID)
	clientSetKey := clientRefreshSetKey(clientID)

	pipe := data.Redis.TxPipeline()
	pipe.HSet(ctx, key, map[string]any{
		"user_id":   userID,
		"client_id": clientID,
	})
	pipe.Expire(ctx, key, ttl)
	pipe.SAdd(ctx, userSetKey, refresh)
	pipe.Expire(ctx, userSetKey, ttl)
	pipe.SAdd(ctx, clientSetKey, refresh)
	pipe.Expire(ctx, clientSetKey, ttl)
	if _, err := pipe.Exec(ctx); err != nil {
		return fmt.Errorf("storing refresh token: %w", err)
	}
	return nil
}

// consumeRefreshToken atomically reads and deletes a refresh-token record and
// returns its fields. An empty map means the token did not exist, or another
// caller consumed it first.
func consumeRefreshToken(ctx context.Context, refreshToken string) (map[string]string, error) {
	key := refreshTokenKey(refreshToken)

	pipe := data.Redis.TxPipeline()
	read := pipe.HGetAll(ctx, key)
	pipe.Del(ctx, key)
	if _, err := pipe.Exec(ctx); err != nil {
		return nil, fmt.Errorf("consuming refresh token: %w", err)
	}
	return read.Val(), nil
}