Add context for everything

Signed-off-by: Asai Neko <sugar@sne.moe>
This commit is contained in:
2026-01-21 16:43:46 +08:00
parent 83df018d34
commit b8f89ab655
27 changed files with 309 additions and 127 deletions

View File

@@ -39,11 +39,11 @@ func (self *Token) NewClaims(clientId string, userId uuid.UUID) JwtClaims {
}
// Generate access token
func (self *Token) GenerateAccessToken(clientId string, userId uuid.UUID) (string, error) {
func (self *Token) GenerateAccessToken(ctx context.Context, clientId string, userId uuid.UUID) (string, error) {
claims := self.NewClaims(clientId, userId)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
clientData, err := new(data.Client).GetClientByClientId(clientId)
clientData, err := new(data.Client).GetClientByClientId(ctx, clientId)
if err != nil {
return "", fmt.Errorf("error getting client data: %v", err)
}
@@ -70,9 +70,9 @@ func (self *Token) GenerateRefreshToken() (string, error) {
}
// Issue both access and refresh token
func (self *Token) IssueTokens(clientId string, userId uuid.UUID) (string, string, error) {
func (self *Token) IssueTokens(ctx context.Context, clientId string, userId uuid.UUID) (string, string, error) {
// access token
access, err := self.GenerateAccessToken(clientId, userId)
access, err := self.GenerateAccessToken(ctx, clientId, userId)
if err != nil {
return "", "", err
}
@@ -83,7 +83,6 @@ func (self *Token) IssueTokens(clientId string, userId uuid.UUID) (string, strin
return "", "", err
}
ctx := context.Background()
ttl := viper.GetDuration("ttl.refresh_ttl")
refreshKey := "refresh:" + refresh
@@ -122,8 +121,7 @@ func (self *Token) IssueTokens(clientId string, userId uuid.UUID) (string, strin
}
// Refresh access token
func (self *Token) RefreshAccessToken(refreshToken string) (string, error) {
ctx := context.Background()
func (self *Token) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) {
key := "refresh:" + refreshToken
// read refresh token bind data
@@ -145,11 +143,10 @@ func (self *Token) RefreshAccessToken(refreshToken string) (string, error) {
}
// Generate new access token
return self.GenerateAccessToken(clientId, userId)
return self.GenerateAccessToken(ctx, clientId, userId)
}
func (self *Token) RenewRefreshToken(refreshToken string) (string, error) {
ctx := context.Background()
func (self *Token) RenewRefreshToken(ctx context.Context, refreshToken string) (string, error) {
ttl := viper.GetDuration("ttl.refresh_ttl")
oldKey := "refresh:" + refreshToken
@@ -174,7 +171,7 @@ func (self *Token) RenewRefreshToken(refreshToken string) (string, error) {
}
// revoke old refresh token
if err := self.RevokeRefreshToken(refreshToken); err != nil {
if err := self.RevokeRefreshToken(ctx, refreshToken); err != nil {
return "", err
}
@@ -211,9 +208,7 @@ func (self *Token) RenewRefreshToken(refreshToken string) (string, error) {
return newRefresh, nil
}
func (self *Token) RevokeRefreshToken(refreshToken string) error {
ctx := context.Background()
func (self *Token) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
refreshKey := "refresh:" + refreshToken
// read refresh token metadata (user_id, client_id)
@@ -246,7 +241,7 @@ func (self *Token) RevokeRefreshToken(refreshToken string) error {
return err
}
func (self *Token) HeaderVerify(header string) (string, error) {
func (self *Token) HeaderVerify(ctx context.Context, header string) (string, error) {
if header == "" {
return "", nil
}
@@ -273,7 +268,7 @@ func (self *Token) HeaderVerify(header string) (string, error) {
return nil, errors.New("[Auth Token] client_id missing in token")
}
clientData, err := new(data.Client).GetClientByClientId(claims.ClientId)
clientData, err := new(data.Client).GetClientByClientId(ctx, claims.ClientId)
if err != nil {
return nil, fmt.Errorf("error getting client data: %v", err)
}