WIP: Full restruct, seprate service and api
Signed-off-by: Asai Neko <sugar@sne.moe>
This commit is contained in:
291
internal/authtoken/authtoken.go
Normal file
291
internal/authtoken/authtoken.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package authtoken
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"nixcn-cms/data"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Application string
|
||||
}
|
||||
|
||||
type JwtClaims struct {
|
||||
ClientId string `json:"client_id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// Generate jwt clames
|
||||
func (self *Token) NewClaims(clientId string, userId uuid.UUID) JwtClaims {
|
||||
return JwtClaims{
|
||||
ClientId: clientId,
|
||||
UserID: userId,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(viper.GetDuration("ttl.access_ttl"))),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: self.Application,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Generate access token
|
||||
func (self *Token) GenerateAccessToken(ctx context.Context, clientId string, userId uuid.UUID) (string, error) {
|
||||
claims := self.NewClaims(clientId, userId)
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
clientData, err := new(data.Client).GetClientByClientId(ctx, clientId)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting client data: %v", err)
|
||||
}
|
||||
|
||||
secret, err := clientData.GetDecryptedSecret()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting decrypted secret: %v", err)
|
||||
}
|
||||
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error signing token: %v", err)
|
||||
}
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
// Generate refresh token
|
||||
func (self *Token) GenerateRefreshToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// Issue both access and refresh token
|
||||
func (self *Token) IssueTokens(ctx context.Context, clientId string, userId uuid.UUID) (string, string, error) {
|
||||
// access token
|
||||
access, err := self.GenerateAccessToken(ctx, clientId, userId)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// refresh token
|
||||
refresh, err := self.GenerateRefreshToken()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
ttl := viper.GetDuration("ttl.refresh_ttl")
|
||||
|
||||
refreshKey := "refresh:" + refresh
|
||||
|
||||
// refresh -> user + client
|
||||
if err := data.Redis.HSet(
|
||||
ctx,
|
||||
refreshKey,
|
||||
map[string]any{
|
||||
"user_id": userId.String(),
|
||||
"client_id": clientId,
|
||||
},
|
||||
).Err(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if err := data.Redis.Expire(ctx, refreshKey, ttl).Err(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// user -> refresh tokens
|
||||
userSetKey := "user:" + userId.String() + ":refresh_tokens"
|
||||
|
||||
if err := data.Redis.SAdd(ctx, userSetKey, refresh).Err(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
_ = data.Redis.Expire(ctx, userSetKey, ttl).Err()
|
||||
|
||||
// client -> refresh tokens
|
||||
clientSetKey := "client:" + clientId + ":refresh_tokens"
|
||||
_ = data.Redis.SAdd(ctx, clientSetKey, refresh).Err()
|
||||
_ = data.Redis.Expire(ctx, clientSetKey, ttl).Err()
|
||||
|
||||
return access, refresh, nil
|
||||
}
|
||||
|
||||
// Refresh access token
|
||||
func (self *Token) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) {
|
||||
key := "refresh:" + refreshToken
|
||||
|
||||
// read refresh token bind data
|
||||
dataMap, err := data.Redis.HGetAll(ctx, key).Result()
|
||||
if err != nil || len(dataMap) == 0 {
|
||||
return "", errors.New("[Auth Token] invalid refresh token")
|
||||
}
|
||||
|
||||
userIdStr := dataMap["user_id"]
|
||||
clientId := dataMap["client_id"]
|
||||
|
||||
if userIdStr == "" || clientId == "" {
|
||||
return "", errors.New("[Auth Token] refresh token corrupted")
|
||||
}
|
||||
|
||||
userId, err := uuid.Parse(userIdStr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate new access token
|
||||
return self.GenerateAccessToken(ctx, clientId, userId)
|
||||
}
|
||||
|
||||
func (self *Token) RenewRefreshToken(ctx context.Context, refreshToken string) (string, error) {
|
||||
ttl := viper.GetDuration("ttl.refresh_ttl")
|
||||
|
||||
oldKey := "refresh:" + refreshToken
|
||||
|
||||
// read old refresh token bind data
|
||||
dataMap, err := data.Redis.HGetAll(ctx, oldKey).Result()
|
||||
if err != nil || len(dataMap) == 0 {
|
||||
return "", errors.New("[Auth Token] invalid refresh token")
|
||||
}
|
||||
|
||||
userIdStr := dataMap["user_id"]
|
||||
clientId := dataMap["client_id"]
|
||||
|
||||
if userIdStr == "" || clientId == "" {
|
||||
return "", errors.New("[Auth Token] refresh token corrupted")
|
||||
}
|
||||
|
||||
// generate new refresh token
|
||||
newRefresh, err := self.GenerateRefreshToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// revoke old refresh token
|
||||
if err := self.RevokeRefreshToken(ctx, refreshToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
newKey := "refresh:" + newRefresh
|
||||
|
||||
// refresh -> user + client
|
||||
if err := data.Redis.HSet(
|
||||
ctx,
|
||||
newKey,
|
||||
map[string]any{
|
||||
"user_id": userIdStr,
|
||||
"client_id": clientId,
|
||||
},
|
||||
).Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := data.Redis.Expire(ctx, newKey, ttl).Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// user -> refresh tokens
|
||||
userSetKey := "user:" + userIdStr + ":refresh_tokens"
|
||||
if err := data.Redis.SAdd(ctx, userSetKey, newRefresh).Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
_ = data.Redis.Expire(ctx, userSetKey, ttl).Err()
|
||||
|
||||
// client -> refresh tokens
|
||||
clientSetKey := "client:" + clientId + ":refresh_tokens"
|
||||
_ = data.Redis.SAdd(ctx, clientSetKey, newRefresh).Err()
|
||||
_ = data.Redis.Expire(ctx, clientSetKey, ttl).Err()
|
||||
|
||||
return newRefresh, nil
|
||||
}
|
||||
|
||||
func (self *Token) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
|
||||
refreshKey := "refresh:" + refreshToken
|
||||
|
||||
// read refresh token metadata (user_id, client_id)
|
||||
dataMap, err := data.Redis.HGetAll(ctx, refreshKey).Result()
|
||||
if err != nil || len(dataMap) == 0 {
|
||||
// Token already revoked or not found
|
||||
return nil
|
||||
}
|
||||
|
||||
userID := dataMap["user_id"]
|
||||
clientID := dataMap["client_id"]
|
||||
|
||||
// build index keys
|
||||
userSetKey := "user:" + userID + ":refresh_tokens"
|
||||
clientSetKey := "client:" + clientID + ":refresh_tokens"
|
||||
|
||||
// remove refresh token and all related indexes atomically
|
||||
pipe := data.Redis.TxPipeline()
|
||||
|
||||
// remove main refresh token record
|
||||
pipe.Del(ctx, refreshKey)
|
||||
|
||||
// remove refresh token from user's active refresh token set
|
||||
pipe.SRem(ctx, userSetKey, refreshToken)
|
||||
|
||||
// remove refresh token from client's active refresh token set
|
||||
pipe.SRem(ctx, clientSetKey, refreshToken)
|
||||
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (self *Token) HeaderVerify(ctx context.Context, header string) (string, error) {
|
||||
if header == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Split header to 2
|
||||
parts := strings.SplitN(header, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
return "", errors.New("[Auth Token] invalid Authorization header format")
|
||||
}
|
||||
|
||||
tokenStr := parts[1]
|
||||
|
||||
// Verify access token
|
||||
claims := &JwtClaims{}
|
||||
token, err := jwt.ParseWithClaims(
|
||||
tokenStr,
|
||||
claims,
|
||||
func(token *jwt.Token) (any, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("[Auth Token] unexpected signing method")
|
||||
}
|
||||
|
||||
if claims.ClientId == "" {
|
||||
return nil, errors.New("[Auth Token] client_id missing in token")
|
||||
}
|
||||
|
||||
clientData, err := new(data.Client).GetClientByClientId(ctx, claims.ClientId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting client data: %v", err)
|
||||
}
|
||||
|
||||
secret, err := clientData.GetDecryptedSecret()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting decrypted secret: %v", err)
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
fmt.Println(err)
|
||||
return "", errors.New("[Auth Token] invalid or expired token")
|
||||
}
|
||||
|
||||
return claims.UserID.String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user