diff --git a/internal/cryptography/token.go b/internal/cryptography/token.go index e49551a..612b079 100644 --- a/internal/cryptography/token.go +++ b/internal/cryptography/token.go @@ -4,8 +4,10 @@ import ( "context" "crypto/rand" "encoding/base64" + "errors" "fmt" "nixcn-cms/data" + "strings" "time" "github.com/golang-jwt/jwt/v5" @@ -187,3 +189,34 @@ func (self *Token) RevokeRefreshToken(refreshToken string) error { return err } + +func (self *Token) HeaderVerify(header string) (string, error) { + if header == "" { + return "", nil + } + + jwtSecret := []byte(viper.GetString("secrets.jwt_secret")) + // Split header to 2 + parts := strings.SplitN(header, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" { + return "", errors.New("invalid Authorization header format") + } + + tokenStr := parts[1] + + // Verify access token + claims := &JwtClaims{} + token, err := jwt.ParseWithClaims( + tokenStr, + claims, + func(token *jwt.Token) (any, error) { + return jwtSecret, nil + }, + ) + + if err != nil || !token.Valid { + return "", errors.New("invalid or expired token") + } + + return claims.UserID.String(), nil +} diff --git a/middleware/jwt.go b/middleware/jwt.go index ca6351e..5328b7e 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,57 +1,30 @@ package middleware import ( - "net/http" - "strings" - "nixcn-cms/internal/cryptography" "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - "github.com/spf13/viper" ) func JWTAuth() gin.HandlerFunc { - jwtSecret := []byte(viper.GetString("secrets.jwt_secret")) return func(c *gin.Context) { auth := c.GetHeader("Authorization") - if auth == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "missing Authorization header", - }) + + token := new(cryptography.Token) + uid, err := token.HeaderVerify(auth) + if err != nil { + c.JSON(401, gin.H{"status": err.Error()}) return } - // Split header to 2 - parts := strings.SplitN(auth, " ", 2) - if len(parts) != 2 || parts[0] != "Bearer" { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "invalid Authorization header format", - }) + if err == nil && uid == "" { + c.Set("user_id", "") + c.Next() return } - tokenStr := parts[1] - - // Verify access token - claims := &cryptography.JwtClaims{} - token, err := jwt.ParseWithClaims( - tokenStr, - claims, - func(token *jwt.Token) (any, error) { - return jwtSecret, nil - }, - ) - - if err != nil || !token.Valid { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "invalid or expired token", - }) - return - } - - c.Set("user_id", claims.UserID) + c.Set("user_id", uid) c.Next() } } diff --git a/service/auth/magic.go b/service/auth/magic.go index e11bb33..17c8943 100644 --- a/service/auth/magic.go +++ b/service/auth/magic.go @@ -46,8 +46,8 @@ func Magic(c *gin.Context) { uri := viper.GetString("server.external_url") + "/api/v1/auth/redirect?" + "code=" + code + - "&redirect_uri" + req.RedirectUri + - "&state" + req.State + "&redirect_uri=" + req.RedirectUri + + "&state=" + req.State debugMode := viper.GetString("server.debug_mode") if debugMode == "true" { diff --git a/service/user/checkin.go b/service/user/checkin.go index eff88b0..b2d9a97 100644 --- a/service/user/checkin.go +++ b/service/user/checkin.go @@ -9,13 +9,19 @@ import ( func Checkin(c *gin.Context) { data := new(data.Attendance) - userId, ok := c.Get("user_id") + userIdOrig, ok := c.Get("user_id") if !ok { c.JSON(401, gin.H{ "status": "unauthorized", }) return } + userId, err := uuid.Parse(userIdOrig.(string)) + if err != nil { + c.JSON(500, gin.H{ + "status": "failed to parse uuid", + }) + } // Get event id from query eventIdOrig, ok := c.GetQuery("event_id") @@ -34,8 +40,7 @@ func Checkin(c *gin.Context) { }) return } - - data.UserId = userId.(uuid.UUID) + data.UserId = userId code, err := data.GenCheckinCode(eventId) if err != nil { c.JSON(500, gin.H{ @@ -50,15 +55,21 @@ func Checkin(c *gin.Context) { } func CheckinSubmit(c *gin.Context) { - userId, ok := c.Get("user_id") + userIdOrig, ok := c.Get("user_id") if !ok { c.JSON(403, gin.H{ "status": "unauthorized", }) } + userId, err := uuid.Parse(userIdOrig.(string)) + if err != nil { + c.JSON(500, gin.H{ + "status": "failed to parse uuid", + }) + } userData := new(data.User) - userData.GetByUserId(userId.(uuid.UUID)) + userData.GetByUserId(userId) if userData.PermissionLevel <= 20 { c.JSON(403, gin.H{ "status": "access denied", diff --git a/service/user/info.go b/service/user/info.go index 68059d2..068f796 100644 --- a/service/user/info.go +++ b/service/user/info.go @@ -9,16 +9,22 @@ import ( func Info(c *gin.Context) { userData := new(data.User) - userId, ok := c.Get("user_id") + userIdOrig, ok := c.Get("user_id") if !ok { c.JSON(404, gin.H{ "status": "user not found", }) return } + userId, err := uuid.Parse(userIdOrig.(string)) + if err != nil { + c.JSON(500, gin.H{ + "status": "failed to parse uuid", + }) + } // Get user from database - user, err := userData.GetByUserId(userId.(uuid.UUID)) + user, err := userData.GetByUserId(userId) if err != nil { c.JSON(404, gin.H{ "status": "user not found", diff --git a/service/user/query.go b/service/user/query.go index 71036d9..b01eaaf 100644 --- a/service/user/query.go +++ b/service/user/query.go @@ -8,11 +8,17 @@ import ( ) func Query(c *gin.Context) { - userId, ok := c.Get("user_id") + userIdOrig, ok := c.Get("user_id") if !ok { c.JSON(400, gin.H{"status": "could not found user_id"}) return } + userId, err := uuid.Parse(userIdOrig.(string)) + if err != nil { + c.JSON(500, gin.H{ + "status": "failed to parse uuid", + }) + } eventIdOrig, ok := c.GetQuery("event_id") if !ok { @@ -26,7 +32,7 @@ func Query(c *gin.Context) { } attendanceData := new(data.Attendance) - attendance, err := attendanceData.GetAttendance(userId.(uuid.UUID), eventId) + attendance, err := attendanceData.GetAttendance(userId, eventId) if err != nil { c.JSON(500, gin.H{"status": "database error"}) return diff --git a/service/user/update.go b/service/user/update.go index e7136b3..fe395db 100644 --- a/service/user/update.go +++ b/service/user/update.go @@ -13,16 +13,22 @@ func Update(c *gin.Context) { // New user model user := new(data.User) - userId, ok := c.Get("user_id") + userIdOrig, ok := c.Get("user_id") if !ok { c.JSON(403, gin.H{ "status": "can not found user id", }) return } + userId, err := uuid.Parse(userIdOrig.(string)) + if err != nil { + c.JSON(500, gin.H{ + "status": "failed to parse uuid", + }) + } // Get user info - user.GetByUserId(userId.(uuid.UUID)) + user.GetByUserId(userId) // Reject permission 0 user if user.PermissionLevel == 0 { @@ -38,7 +44,7 @@ func Update(c *gin.Context) { user.Subtitle = ReqInfo.Subtitle // Update user info - user.UpdateByUserID(userId.(uuid.UUID)) + user.UpdateByUserID(userId) c.JSON(200, gin.H{ "status": "success",