Modify jwt middleware logic

Signed-off-by: Asai Neko <sugar@sne.moe>
This commit is contained in:
2026-01-02 12:36:07 +08:00
parent 3d685b5a86
commit cbec9bf2b3
7 changed files with 85 additions and 50 deletions

View File

@@ -4,8 +4,10 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"nixcn-cms/data" "nixcn-cms/data"
"strings"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@@ -187,3 +189,34 @@ func (self *Token) RevokeRefreshToken(refreshToken string) error {
return err return err
} }
func (self *Token) HeaderVerify(header string) (string, error) {
if header == "" {
return "", nil
}
jwtSecret := []byte(viper.GetString("secrets.jwt_secret"))
// Split header to 2
parts := strings.SplitN(header, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
return "", errors.New("invalid Authorization header format")
}
tokenStr := parts[1]
// Verify access token
claims := &JwtClaims{}
token, err := jwt.ParseWithClaims(
tokenStr,
claims,
func(token *jwt.Token) (any, error) {
return jwtSecret, nil
},
)
if err != nil || !token.Valid {
return "", errors.New("invalid or expired token")
}
return claims.UserID.String(), nil
}

View File

@@ -1,57 +1,30 @@
package middleware package middleware
import ( import (
"net/http"
"strings"
"nixcn-cms/internal/cryptography" "nixcn-cms/internal/cryptography"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/spf13/viper"
) )
func JWTAuth() gin.HandlerFunc { func JWTAuth() gin.HandlerFunc {
jwtSecret := []byte(viper.GetString("secrets.jwt_secret"))
return func(c *gin.Context) { return func(c *gin.Context) {
auth := c.GetHeader("Authorization") auth := c.GetHeader("Authorization")
if auth == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ token := new(cryptography.Token)
"error": "missing Authorization header", uid, err := token.HeaderVerify(auth)
}) if err != nil {
c.JSON(401, gin.H{"status": err.Error()})
return return
} }
// Split header to 2 if err == nil && uid == "" {
parts := strings.SplitN(auth, " ", 2) c.Set("user_id", "")
if len(parts) != 2 || parts[0] != "Bearer" { c.Next()
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid Authorization header format",
})
return return
} }
tokenStr := parts[1] c.Set("user_id", uid)
// Verify access token
claims := &cryptography.JwtClaims{}
token, err := jwt.ParseWithClaims(
tokenStr,
claims,
func(token *jwt.Token) (any, error) {
return jwtSecret, nil
},
)
if err != nil || !token.Valid {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "invalid or expired token",
})
return
}
c.Set("user_id", claims.UserID)
c.Next() c.Next()
} }
} }

View File

@@ -46,8 +46,8 @@ func Magic(c *gin.Context) {
uri := viper.GetString("server.external_url") + uri := viper.GetString("server.external_url") +
"/api/v1/auth/redirect?" + "/api/v1/auth/redirect?" +
"code=" + code + "code=" + code +
"&redirect_uri" + req.RedirectUri + "&redirect_uri=" + req.RedirectUri +
"&state" + req.State "&state=" + req.State
debugMode := viper.GetString("server.debug_mode") debugMode := viper.GetString("server.debug_mode")
if debugMode == "true" { if debugMode == "true" {

View File

@@ -9,13 +9,19 @@ import (
func Checkin(c *gin.Context) { func Checkin(c *gin.Context) {
data := new(data.Attendance) data := new(data.Attendance)
userId, ok := c.Get("user_id") userIdOrig, ok := c.Get("user_id")
if !ok { if !ok {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": "unauthorized", "status": "unauthorized",
}) })
return return
} }
userId, err := uuid.Parse(userIdOrig.(string))
if err != nil {
c.JSON(500, gin.H{
"status": "failed to parse uuid",
})
}
// Get event id from query // Get event id from query
eventIdOrig, ok := c.GetQuery("event_id") eventIdOrig, ok := c.GetQuery("event_id")
@@ -34,8 +40,7 @@ func Checkin(c *gin.Context) {
}) })
return return
} }
data.UserId = userId
data.UserId = userId.(uuid.UUID)
code, err := data.GenCheckinCode(eventId) code, err := data.GenCheckinCode(eventId)
if err != nil { if err != nil {
c.JSON(500, gin.H{ c.JSON(500, gin.H{
@@ -50,15 +55,21 @@ func Checkin(c *gin.Context) {
} }
func CheckinSubmit(c *gin.Context) { func CheckinSubmit(c *gin.Context) {
userId, ok := c.Get("user_id") userIdOrig, ok := c.Get("user_id")
if !ok { if !ok {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"status": "unauthorized", "status": "unauthorized",
}) })
} }
userId, err := uuid.Parse(userIdOrig.(string))
if err != nil {
c.JSON(500, gin.H{
"status": "failed to parse uuid",
})
}
userData := new(data.User) userData := new(data.User)
userData.GetByUserId(userId.(uuid.UUID)) userData.GetByUserId(userId)
if userData.PermissionLevel <= 20 { if userData.PermissionLevel <= 20 {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"status": "access denied", "status": "access denied",

View File

@@ -9,16 +9,22 @@ import (
func Info(c *gin.Context) { func Info(c *gin.Context) {
userData := new(data.User) userData := new(data.User)
userId, ok := c.Get("user_id") userIdOrig, ok := c.Get("user_id")
if !ok { if !ok {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": "user not found", "status": "user not found",
}) })
return return
} }
userId, err := uuid.Parse(userIdOrig.(string))
if err != nil {
c.JSON(500, gin.H{
"status": "failed to parse uuid",
})
}
// Get user from database // Get user from database
user, err := userData.GetByUserId(userId.(uuid.UUID)) user, err := userData.GetByUserId(userId)
if err != nil { if err != nil {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": "user not found", "status": "user not found",

View File

@@ -8,11 +8,17 @@ import (
) )
func Query(c *gin.Context) { func Query(c *gin.Context) {
userId, ok := c.Get("user_id") userIdOrig, ok := c.Get("user_id")
if !ok { if !ok {
c.JSON(400, gin.H{"status": "could not found user_id"}) c.JSON(400, gin.H{"status": "could not found user_id"})
return return
} }
userId, err := uuid.Parse(userIdOrig.(string))
if err != nil {
c.JSON(500, gin.H{
"status": "failed to parse uuid",
})
}
eventIdOrig, ok := c.GetQuery("event_id") eventIdOrig, ok := c.GetQuery("event_id")
if !ok { if !ok {
@@ -26,7 +32,7 @@ func Query(c *gin.Context) {
} }
attendanceData := new(data.Attendance) attendanceData := new(data.Attendance)
attendance, err := attendanceData.GetAttendance(userId.(uuid.UUID), eventId) attendance, err := attendanceData.GetAttendance(userId, eventId)
if err != nil { if err != nil {
c.JSON(500, gin.H{"status": "database error"}) c.JSON(500, gin.H{"status": "database error"})
return return

View File

@@ -13,16 +13,22 @@ func Update(c *gin.Context) {
// New user model // New user model
user := new(data.User) user := new(data.User)
userId, ok := c.Get("user_id") userIdOrig, ok := c.Get("user_id")
if !ok { if !ok {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"status": "can not found user id", "status": "can not found user id",
}) })
return return
} }
userId, err := uuid.Parse(userIdOrig.(string))
if err != nil {
c.JSON(500, gin.H{
"status": "failed to parse uuid",
})
}
// Get user info // Get user info
user.GetByUserId(userId.(uuid.UUID)) user.GetByUserId(userId)
// Reject permission 0 user // Reject permission 0 user
if user.PermissionLevel == 0 { if user.PermissionLevel == 0 {
@@ -38,7 +44,7 @@ func Update(c *gin.Context) {
user.Subtitle = ReqInfo.Subtitle user.Subtitle = ReqInfo.Subtitle
// Update user info // Update user info
user.UpdateByUserID(userId.(uuid.UUID)) user.UpdateByUserID(userId)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": "success", "status": "success",