package middleware

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/spf13/viper"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"nixcn-cms/internal/authtoken"
	"nixcn-cms/testutil"
)

// init switches gin into test mode for every test in this package.
func init() {
	gin.SetMode(gin.TestMode)
}

// issueToken generates a real JWT access token for the given user UUID.
//
// The token is issued for testutil.TestClientID using the application name
// read from the "server.application" viper key. Any issuing error fails the
// calling test immediately.
func issueToken(t *testing.T, userID uuid.UUID) string {
	t.Helper()

	tok := &authtoken.Token{Application: viper.GetString("server.application")}
	access, _, err := tok.IssueTokens(context.Background(), testutil.TestClientID, userID)
	require.NoError(t, err)

	return access
}

// ---- GinLogger ----

// TestGinLogger200 checks that a 200 response and its body pass through
// GinLogger unchanged.
func TestGinLogger200(t *testing.T) {
	testutil.Setup(t)

	r := gin.New()
	r.Use(GinLogger())
	r.GET("/ping", func(c *gin.Context) { c.String(http.StatusOK, "pong") })

	req := httptest.NewRequest(http.MethodGet, "/ping", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "pong", w.Body.String())
}

// TestGinLogger400 checks that a 400 status set by the handler is preserved
// when GinLogger is installed.
func TestGinLogger400(t *testing.T) {
	testutil.Setup(t)

	r := gin.New()
	r.Use(GinLogger())
	r.GET("/bad", func(c *gin.Context) { c.Status(http.StatusBadRequest) })

	req := httptest.NewRequest(http.MethodGet, "/bad", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusBadRequest, w.Code)
}

// TestGinLogger500 checks that a 500 status set by the handler is preserved
// when GinLogger is installed.
func TestGinLogger500(t *testing.T) {
	testutil.Setup(t)

	r := gin.New()
	r.Use(GinLogger())
	r.GET("/err", func(c *gin.Context) { c.Status(http.StatusInternalServerError) })

	req := httptest.NewRequest(http.MethodGet, "/err", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusInternalServerError, w.Code)
}

// TestGinLoggerBodyPreserved checks that the handler can still bind the JSON
// request body when GinLogger runs ahead of it.
func TestGinLoggerBodyPreserved(t *testing.T) {
	testutil.Setup(t)

	var captured string

	r := gin.New()
	r.Use(GinLogger())
	r.POST("/echo", func(c *gin.Context) {
		var m map[string]string
		require.NoError(t, c.ShouldBindJSON(&m))
		captured = m["msg"]
		c.String(http.StatusOK, "ok")
	})

	// A marshal failure would otherwise surface as a confusing bind error
	// inside the handler, so fail here instead.
	body, err := json.Marshal(map[string]string{"msg": "hello"})
	require.NoError(t, err)

	req := httptest.NewRequest(http.MethodPost, "/echo", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "hello", captured)
}

// ---- JWTAuth ----

// TestJWTAuthNoHeader checks that a request without an Authorization header
// is answered with 401.
func TestJWTAuthNoHeader(t *testing.T) {
	testutil.Setup(t)
	testutil.SeedClient(t)

	r := gin.New()
	r.Use(JWTAuth())
	r.GET("/protected", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusUnauthorized, w.Code)
}

// TestJWTAuthBadFormat checks that an Authorization header without the
// "Bearer " prefix is answered with 401.
func TestJWTAuthBadFormat(t *testing.T) {
	testutil.Setup(t)
	testutil.SeedClient(t)

	r := gin.New()
	r.Use(JWTAuth())
	r.GET("/protected", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("Authorization", "not-a-valid-token")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusUnauthorized, w.Code)
}

// TestJWTAuthBadToken checks that a Bearer header carrying a string that is
// not a valid token is answered with 401.
func TestJWTAuthBadToken(t *testing.T) {
	testutil.Setup(t)
	testutil.SeedClient(t)

	r := gin.New()
	r.Use(JWTAuth())
	r.GET("/protected", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("Authorization", "Bearer completely.invalid.token")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusUnauthorized, w.Code)
}

// TestJWTAuthValidToken checks that a freshly issued access token is accepted
// and that the handler sees the user's UUID under the "user_id" context key.
func TestJWTAuthValidToken(t *testing.T) {
	testutil.Setup(t)
	testutil.SeedClient(t)
	user := testutil.SeedUser(t, testutil.RandomEmail(), 0)
	token := issueToken(t, user.UUID)

	r := gin.New()
	r.Use(JWTAuth())
	r.GET("/protected", func(c *gin.Context) {
		// GetString yields "" when the key is missing or not a string, so a
		// misbehaving middleware fails the body assertion below instead of
		// panicking inside the handler (no recovery middleware is installed).
		c.String(http.StatusOK, c.GetString("user_id"))
	})

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, user.UUID.String(), w.Body.String())
}

// ---- Permission ----

// TestPermissionPresetSufficient checks that a request whose context already
// carries permission_level 10 passes Permission(5).
func TestPermissionPresetSufficient(t *testing.T) {
	testutil.Setup(t)

	r := gin.New()
	r.Use(func(c *gin.Context) {
		c.Set("user_id", uuid.New().String())
		c.Set("permission_level", uint(10))
		c.Next()
	})
	r.Use(Permission(5))
	r.GET("/admin", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/admin", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
}

// TestPermissionPresetInsufficient checks that a request whose context
// already carries permission_level 1 is rejected by Permission(99) with 403.
func TestPermissionPresetInsufficient(t *testing.T) {
	testutil.Setup(t)

	r := gin.New()
	r.Use(func(c *gin.Context) {
		c.Set("user_id", uuid.New().String())
		c.Set("permission_level", uint(1))
		c.Next()
	})
	r.Use(Permission(99))
	r.GET("/admin", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/admin", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusForbidden, w.Code)
}

// TestPermissionNoUserId checks that Permission answers 401 when no user_id
// has been placed in the context.
func TestPermissionNoUserId(t *testing.T) {
	testutil.Setup(t)

	r := gin.New()
	r.Use(Permission(0))
	r.GET("/any", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/any", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusUnauthorized, w.Code)
}

// TestPermissionFromDBSufficient checks that a seeded user (seed argument 10,
// presumably the permission level) passes Permission(5) when only user_id is
// present in the context.
func TestPermissionFromDBSufficient(t *testing.T) {
	testutil.Setup(t)
	user := testutil.SeedUser(t, testutil.RandomEmail(), 10)

	r := gin.New()
	r.Use(func(c *gin.Context) {
		// Permission middleware looks up by the user_id column.
		c.Set("user_id", user.UserId.String())
		c.Next()
	})
	r.Use(Permission(5))
	r.GET("/admin", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/admin", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
}

// TestPermissionFromDBInsufficient checks that a seeded user (seed argument 1,
// presumably the permission level) is rejected by Permission(99) with 403
// when only user_id is present in the context.
func TestPermissionFromDBInsufficient(t *testing.T) {
	testutil.Setup(t)
	user := testutil.SeedUser(t, testutil.RandomEmail(), 1)

	r := gin.New()
	r.Use(func(c *gin.Context) {
		c.Set("user_id", user.UserId.String())
		c.Next()
	})
	r.Use(Permission(99))
	r.GET("/admin", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/admin", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusForbidden, w.Code)
}

// TestPermissionUserNotFound checks that Permission answers 404 when the
// context carries a random user_id that was never seeded.
func TestPermissionUserNotFound(t *testing.T) {
	testutil.Setup(t)

	r := gin.New()
	r.Use(func(c *gin.Context) {
		c.Set("user_id", uuid.New().String())
		c.Next()
	})
	r.Use(Permission(0))
	r.GET("/any", func(c *gin.Context) { c.String(http.StatusOK, "ok") })

	req := httptest.NewRequest(http.MethodGet, "/any", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusNotFound, w.Code)
}

// TestPermissionCachesLevelFromDB checks that, after Permission runs with
// only user_id in the context, the handler finds permission_level set to
// uint(10) for a user seeded with 10.
func TestPermissionCachesLevelFromDB(t *testing.T) {
	testutil.Setup(t)
	user := testutil.SeedUser(t, testutil.RandomEmail(), 10)

	r := gin.New()
	r.Use(func(c *gin.Context) {
		c.Set("user_id", user.UserId.String())
		c.Next()
	})
	r.Use(Permission(5))
	r.GET("/any", func(c *gin.Context) {
		lvl, ok := c.Get("permission_level")
		assert.True(t, ok)
		// Compare the stored value directly: a missing key or a value of a
		// different type is then reported as a failed assertion rather than
		// a type-assertion panic.
		assert.Equal(t, uint(10), lvl)
		c.String(http.StatusOK, "ok")
	})

	req := httptest.NewRequest(http.MethodGet, "/any", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
}