296 lines
7.0 KiB
Go
296 lines
7.0 KiB
Go
package authtoken
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"nixcn-cms/data"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// Token issues and verifies the access and refresh tokens of one application.
type Token struct {
	// Application is written into the "iss" (issuer) claim of every access
	// token built by NewClaims.
	Application string
}
|
|
|
|
type JwtClaims struct {
|
|
ClientId string `json:"client_id"`
|
|
UserID uuid.UUID `json:"user_id"`
|
|
jwt.RegisteredClaims
|
|
}
|
|
|
|
// Generate jwt clames
|
|
func (self *Token) NewClaims(clientId string, userId uuid.UUID) JwtClaims {
|
|
return JwtClaims{
|
|
ClientId: clientId,
|
|
UserID: userId,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(viper.GetDuration("ttl.access_ttl"))),
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|
Issuer: self.Application,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Generate access token
|
|
func (self *Token) GenerateAccessToken(clientId string, userId uuid.UUID) (string, error) {
|
|
claims := self.NewClaims(clientId, userId)
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
|
|
clientData, err := new(data.Client).GetClientByClientId(clientId)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error getting client data: %v", err)
|
|
}
|
|
|
|
secret, err := clientData.GetDecryptedSecret()
|
|
if err != nil {
|
|
return "", fmt.Errorf("error getting decrypted secret: %v", err)
|
|
}
|
|
|
|
signedToken, err := token.SignedString([]byte(secret))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error signing token: %v", err)
|
|
}
|
|
return signedToken, nil
|
|
}
|
|
|
|
// Generate refresh token
|
|
func (self *Token) GenerateRefreshToken() (string, error) {
|
|
b := make([]byte, 32)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", err
|
|
}
|
|
return base64.URLEncoding.EncodeToString(b), nil
|
|
}
|
|
|
|
// Issue both access and refresh token
|
|
func (self *Token) IssueTokens(clientId string, userId uuid.UUID) (string, string, error) {
|
|
// access token
|
|
access, err := self.GenerateAccessToken(clientId, userId)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
// refresh token
|
|
refresh, err := self.GenerateRefreshToken()
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
ctx := context.Background()
|
|
ttl := viper.GetDuration("ttl.refresh_ttl")
|
|
|
|
refreshKey := "refresh:" + refresh
|
|
|
|
// refresh -> user + client
|
|
if err := data.Redis.HSet(
|
|
ctx,
|
|
refreshKey,
|
|
map[string]any{
|
|
"user_id": userId.String(),
|
|
"client_id": clientId,
|
|
},
|
|
).Err(); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
if err := data.Redis.Expire(ctx, refreshKey, ttl).Err(); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
// user -> refresh tokens
|
|
userSetKey := "user:" + userId.String() + ":refresh_tokens"
|
|
|
|
if err := data.Redis.SAdd(ctx, userSetKey, refresh).Err(); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
_ = data.Redis.Expire(ctx, userSetKey, ttl).Err()
|
|
|
|
// client -> refresh tokens
|
|
clientSetKey := "client:" + clientId + ":refresh_tokens"
|
|
_ = data.Redis.SAdd(ctx, clientSetKey, refresh).Err()
|
|
_ = data.Redis.Expire(ctx, clientSetKey, ttl).Err()
|
|
|
|
return access, refresh, nil
|
|
}
|
|
|
|
// Refresh access token
|
|
func (self *Token) RefreshAccessToken(refreshToken string) (string, error) {
|
|
ctx := context.Background()
|
|
key := "refresh:" + refreshToken
|
|
|
|
// read refresh token bind data
|
|
dataMap, err := data.Redis.HGetAll(ctx, key).Result()
|
|
if err != nil || len(dataMap) == 0 {
|
|
return "", errors.New("invalid refresh token")
|
|
}
|
|
|
|
userIdStr := dataMap["user_id"]
|
|
clientId := dataMap["client_id"]
|
|
|
|
if userIdStr == "" || clientId == "" {
|
|
return "", errors.New("refresh token corrupted")
|
|
}
|
|
|
|
userId, err := uuid.Parse(userIdStr)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Generate new access token
|
|
return self.GenerateAccessToken(clientId, userId)
|
|
}
|
|
|
|
func (self *Token) RenewRefreshToken(refreshToken string) (string, error) {
|
|
ctx := context.Background()
|
|
ttl := viper.GetDuration("ttl.refresh_ttl")
|
|
|
|
oldKey := "refresh:" + refreshToken
|
|
|
|
// read old refresh token bind data
|
|
dataMap, err := data.Redis.HGetAll(ctx, oldKey).Result()
|
|
if err != nil || len(dataMap) == 0 {
|
|
return "", errors.New("invalid refresh token")
|
|
}
|
|
|
|
userIdStr := dataMap["user_id"]
|
|
clientId := dataMap["client_id"]
|
|
|
|
if userIdStr == "" || clientId == "" {
|
|
return "", errors.New("refresh token corrupted")
|
|
}
|
|
|
|
// generate new refresh token
|
|
newRefresh, err := self.GenerateRefreshToken()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// revoke old refresh token
|
|
if err := self.RevokeRefreshToken(refreshToken); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
newKey := "refresh:" + newRefresh
|
|
|
|
// refresh -> user + client
|
|
if err := data.Redis.HSet(
|
|
ctx,
|
|
newKey,
|
|
map[string]any{
|
|
"user_id": userIdStr,
|
|
"client_id": clientId,
|
|
},
|
|
).Err(); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if err := data.Redis.Expire(ctx, newKey, ttl).Err(); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// user -> refresh tokens
|
|
userSetKey := "user:" + userIdStr + ":refresh_tokens"
|
|
if err := data.Redis.SAdd(ctx, userSetKey, newRefresh).Err(); err != nil {
|
|
return "", err
|
|
}
|
|
_ = data.Redis.Expire(ctx, userSetKey, ttl).Err()
|
|
|
|
// client -> refresh tokens
|
|
clientSetKey := "client:" + clientId + ":refresh_tokens"
|
|
_ = data.Redis.SAdd(ctx, clientSetKey, newRefresh).Err()
|
|
_ = data.Redis.Expire(ctx, clientSetKey, ttl).Err()
|
|
|
|
return newRefresh, nil
|
|
}
|
|
|
|
func (self *Token) RevokeRefreshToken(refreshToken string) error {
|
|
ctx := context.Background()
|
|
|
|
refreshKey := "refresh:" + refreshToken
|
|
|
|
// read refresh token metadata (user_id, client_id)
|
|
dataMap, err := data.Redis.HGetAll(ctx, refreshKey).Result()
|
|
if err != nil || len(dataMap) == 0 {
|
|
// Token already revoked or not found
|
|
return nil
|
|
}
|
|
|
|
userID := dataMap["user_id"]
|
|
clientID := dataMap["client_id"]
|
|
|
|
// build index keys
|
|
userSetKey := "user:" + userID + ":refresh_tokens"
|
|
clientSetKey := "client:" + clientID + ":refresh_tokens"
|
|
|
|
// remove refresh token and all related indexes atomically
|
|
pipe := data.Redis.TxPipeline()
|
|
|
|
// remove main refresh token record
|
|
pipe.Del(ctx, refreshKey)
|
|
|
|
// remove refresh token from user's active refresh token set
|
|
pipe.SRem(ctx, userSetKey, refreshToken)
|
|
|
|
// remove refresh token from client's active refresh token set
|
|
pipe.SRem(ctx, clientSetKey, refreshToken)
|
|
|
|
_, err = pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
func (self *Token) HeaderVerify(header string) (string, error) {
|
|
if header == "" {
|
|
return "", nil
|
|
}
|
|
|
|
// Split header to 2
|
|
parts := strings.SplitN(header, " ", 2)
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
return "", errors.New("invalid Authorization header format")
|
|
}
|
|
|
|
tokenStr := parts[1]
|
|
|
|
// Verify access token
|
|
claims := &JwtClaims{}
|
|
token, err := jwt.ParseWithClaims(
|
|
tokenStr,
|
|
claims,
|
|
func(token *jwt.Token) (any, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, errors.New("unexpected signing method")
|
|
}
|
|
|
|
if claims.ClientId == "" {
|
|
return nil, errors.New("client_id missing in token")
|
|
}
|
|
|
|
clientData, err := new(data.Client).GetClientByClientId(claims.ClientId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting client data: %v", err)
|
|
}
|
|
|
|
secret, err := clientData.GetDecryptedSecret()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting decrypted secret: %v", err)
|
|
}
|
|
|
|
return secret, nil
|
|
},
|
|
)
|
|
|
|
if err != nil || !token.Valid {
|
|
return "", errors.New("invalid or expired token")
|
|
}
|
|
|
|
return claims.UserID.String(), nil
|
|
}
|