package cryptography import ( "context" "crypto/rand" "encoding/base64" "errors" "nixcn-cms/data" "time" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/redis/go-redis/v9" "github.com/spf13/viper" ) type Token struct { UserID uuid.UUID Application string } type JwtClaims struct { UserID uuid.UUID `json:"user_id"` jwt.RegisteredClaims } // Generate jwt clames func (self *Token) NewClaims() JwtClaims { return JwtClaims{ UserID: self.UserID, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(viper.GetDuration("ttl.jwt_ttl"))), IssuedAt: jwt.NewNumericDate(time.Now()), Issuer: self.Application, }, } } // Generate access token func (self *Token) GenerateAccessToken() (string, error) { claims := self.NewClaims() token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) secret := viper.GetString("secrets.jwt_secret") return token.SignedString(secret) } // Generate refresh token func (self *Token) GenerateRefreshToken() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(b), nil } // Issue both access and refresh token func (self *Token) IssueTokens() (string, string, error) { // Gen atk access, err := self.GenerateAccessToken() if err != nil { return "", "", err } // Gen rtk refresh, err := self.GenerateRefreshToken() if err != nil { return "", "", err } // Store to redis ctx := context.Background() ttl := viper.GetDuration("ttl.refresh_ttl") // refresh -> user if err := data.Redis.Set( ctx, "refresh:"+refresh, self.UserID.String(), ttl, ).Err(); err != nil { return "", "", err } // user -> refresh tokens userSetKey := "user:" + self.UserID.String() + ":refresh_tokens" if err := data.Redis.SAdd( ctx, userSetKey, refresh, ).Err(); err != nil { return "", "", err } // set user ttl >= all refresh token _ = data.Redis.Expire(ctx, userSetKey, ttl).Err() return access, refresh, nil } // Refresh access token func (self *Token) RefreshAccessToken(refreshToken string) (string, error) { // Read rtk:userid from redis ctx := context.Background() key := "refresh:" + refreshToken userIDStr, err := data.Redis.Get(ctx, key).Result() if err != nil { if err == redis.Nil { return "", errors.New("invalid refresh token") } return "", err } userID, err := uuid.Parse(userIDStr) if err != nil { return "", err } self.UserID = userID // Generate access token return self.GenerateAccessToken() } func (self *Token) RevokeRefreshToken(refreshToken string) error { ctx := context.Background() key := "refresh:" + refreshToken userIDStr, err := data.Redis.Get(ctx, key).Result() if err != nil { return nil } userSetKey := "user:" + userIDStr + ":refresh_tokens" // Delete rtk from redis pipe := data.Redis.TxPipeline() pipe.Del(ctx, key) // rtk:userid index pipe.SRem(ctx, userSetKey, refreshToken) // userid:rtk index _, err = pipe.Exec(ctx) return err }