Files
cms-server/middleware/middleware_test.go
Asai Neko 82c412d839
All checks were successful
Server Check Build (NixCN CMS) TeamCity build finished
Add more tests for modules co worked by claude
Signed-off-by: Asai Neko <sugar@sne.moe>
2026-03-26 23:36:40 +08:00

285 lines
7.0 KiB
Go

package middleware
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nixcn-cms/internal/authtoken"
"nixcn-cms/testutil"
)
// init switches gin into test mode when this test package is initialized.
func init() {
	gin.SetMode(gin.TestMode)
}
// issueToken returns the access token issued for userId on behalf of
// testutil.TestClientID. The issuer's Application is read from the
// "server.application" viper key, the second value returned by IssueTokens is
// discarded, and any issuance error fails the calling test immediately.
func issueToken(t *testing.T, userId uuid.UUID) string {
	t.Helper()

	issuer := &authtoken.Token{
		Application: viper.GetString("server.application"),
	}
	accessToken, _, err := issuer.IssueTokens(context.Background(), testutil.TestClientID, userId)
	require.NoError(t, err)

	return accessToken
}
// ---- GinLogger ----
// TestGinLogger200 sends GET /ping through a router with GinLogger installed
// and expects the handler's 200 status and "pong" body in the response.
func TestGinLogger200(t *testing.T) {
	testutil.Setup(t)

	router := gin.New()
	router.Use(GinLogger())
	router.GET("/ping", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "pong")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/ping", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "pong", rec.Body.String())
}
// TestGinLogger400 expects a 400 status set by the handler to be reported
// unchanged when GinLogger is installed on the router.
func TestGinLogger400(t *testing.T) {
	testutil.Setup(t)

	router := gin.New()
	router.Use(GinLogger())
	router.GET("/bad", func(ctx *gin.Context) {
		ctx.Status(http.StatusBadRequest)
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/bad", nil))

	assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestGinLogger500 expects a 500 status set by the handler to be reported
// unchanged when GinLogger is installed on the router.
func TestGinLogger500(t *testing.T) {
	testutil.Setup(t)

	router := gin.New()
	router.Use(GinLogger())
	router.GET("/err", func(ctx *gin.Context) {
		ctx.Status(http.StatusInternalServerError)
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/err", nil))

	assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
// TestGinLoggerBodyPreserved posts a JSON document through a router with
// GinLogger installed and expects the route handler to still be able to bind
// the request body and read the "msg" field from it.
func TestGinLoggerBodyPreserved(t *testing.T) {
	testutil.Setup(t)

	var captured string
	r := gin.New()
	r.Use(GinLogger())
	r.POST("/echo", func(c *gin.Context) {
		var m map[string]string
		require.NoError(t, c.ShouldBindJSON(&m))
		captured = m["msg"]
		c.String(http.StatusOK, "ok")
	})

	// Check the encoding error so a broken fixture fails here, at setup,
	// rather than surfacing later as a confusing bind failure in the handler.
	body, err := json.Marshal(map[string]string{"msg": "hello"})
	require.NoError(t, err)

	req := httptest.NewRequest(http.MethodPost, "/echo", bytes.NewBuffer(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "hello", captured)
}
// ---- JWTAuth ----
// TestJWTAuthNoHeader expects a request without an Authorization header to
// receive 401 from a route guarded by JWTAuth.
func TestJWTAuthNoHeader(t *testing.T) {
	testutil.Setup(t)
	testutil.SeedClient(t)

	router := gin.New()
	router.Use(JWTAuth())
	router.GET("/protected", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/protected", nil))

	assert.Equal(t, http.StatusUnauthorized, rec.Code)
}
// TestJWTAuthBadFormat expects an Authorization header that lacks the
// "Bearer " prefix to receive 401 from a route guarded by JWTAuth.
func TestJWTAuthBadFormat(t *testing.T) {
	testutil.Setup(t)
	testutil.SeedClient(t)

	router := gin.New()
	router.Use(JWTAuth())
	router.GET("/protected", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	request := httptest.NewRequest(http.MethodGet, "/protected", nil)
	request.Header.Set("Authorization", "not-a-valid-token")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, request)

	assert.Equal(t, http.StatusUnauthorized, rec.Code)
}
// TestJWTAuthValidToken issues an access token for a freshly seeded user,
// sends it as a Bearer token, and expects JWTAuth to let the request through
// with the user's UUID string stored under the "user_id" context key.
func TestJWTAuthValidToken(t *testing.T) {
	testutil.Setup(t)
	testutil.SeedClient(t)
	user := testutil.SeedUser(t, testutil.RandomEmail(), 0)
	token := issueToken(t, user.UUID)

	r := gin.New()
	r.Use(JWTAuth())
	r.GET("/protected", func(c *gin.Context) {
		// Use the checked forms of the lookup and the type assertion: the
		// router has no recovery middleware, so a missing or non-string value
		// would otherwise panic instead of producing a readable test failure.
		raw, exists := c.Get("user_id")
		if !exists {
			t.Errorf("user_id was not set in the gin context")
			c.Status(http.StatusInternalServerError)
			return
		}
		uid, ok := raw.(string)
		if !ok {
			t.Errorf("user_id has type %T, want string", raw)
			c.Status(http.StatusInternalServerError)
			return
		}
		c.String(http.StatusOK, uid)
	})

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, user.UUID.String(), w.Body.String())
}
// ---- Permission ----
// TestPermissionPresetSufficient places permission_level 10 in the context
// ahead of Permission(5) and expects the request to reach the handler (200).
func TestPermissionPresetSufficient(t *testing.T) {
	testutil.Setup(t)

	// Stand-in for upstream middleware: supplies the identity and level directly.
	preset := func(ctx *gin.Context) {
		ctx.Set("user_id", uuid.New().String())
		ctx.Set("permission_level", uint(10))
		ctx.Next()
	}

	router := gin.New()
	router.Use(preset)
	router.Use(Permission(5))
	router.GET("/admin", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/admin", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
}
// TestPermissionPresetInsufficient places permission_level 1 in the context
// ahead of Permission(99) and expects the request to be refused with 403.
func TestPermissionPresetInsufficient(t *testing.T) {
	testutil.Setup(t)

	// Stand-in for upstream middleware: supplies the identity and level directly.
	preset := func(ctx *gin.Context) {
		ctx.Set("user_id", uuid.New().String())
		ctx.Set("permission_level", uint(1))
		ctx.Next()
	}

	router := gin.New()
	router.Use(preset)
	router.Use(Permission(99))
	router.GET("/admin", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/admin", nil))

	assert.Equal(t, http.StatusForbidden, rec.Code)
}
// TestPermissionNoUserId runs Permission(0) with nothing stored in the
// context beforehand and expects the request to be answered with 401.
func TestPermissionNoUserId(t *testing.T) {
	testutil.Setup(t)

	router := gin.New()
	router.Use(Permission(0))
	router.GET("/any", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/any", nil))

	assert.Equal(t, http.StatusUnauthorized, rec.Code)
}
// TestPermissionFromDBSufficient seeds a user with level 10, supplies only
// that user's id in the context, and expects Permission(5) to allow the
// request (200).
func TestPermissionFromDBSufficient(t *testing.T) {
	testutil.Setup(t)
	seeded := testutil.SeedUser(t, testutil.RandomEmail(), 10)

	// The Permission middleware looks users up by the user_id column, so the
	// seeded user's UserId is what gets injected here.
	identify := func(ctx *gin.Context) {
		ctx.Set("user_id", seeded.UserId.String())
		ctx.Next()
	}

	router := gin.New()
	router.Use(identify)
	router.Use(Permission(5))
	router.GET("/admin", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/admin", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
}
// TestPermissionFromDBInsufficient seeds a user with level 1, supplies only
// that user's id in the context, and expects Permission(99) to refuse the
// request with 403.
func TestPermissionFromDBInsufficient(t *testing.T) {
	testutil.Setup(t)
	seeded := testutil.SeedUser(t, testutil.RandomEmail(), 1)

	// Only the id is provided; no permission_level is placed in the context.
	identify := func(ctx *gin.Context) {
		ctx.Set("user_id", seeded.UserId.String())
		ctx.Next()
	}

	router := gin.New()
	router.Use(identify)
	router.Use(Permission(99))
	router.GET("/admin", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/admin", nil))

	assert.Equal(t, http.StatusForbidden, rec.Code)
}
// TestPermissionUserNotFound supplies a freshly generated UUID as user_id,
// with no user seeded for it, and expects Permission(0) to answer with 404.
func TestPermissionUserNotFound(t *testing.T) {
	testutil.Setup(t)

	// A random id that was never passed to testutil.SeedUser.
	identify := func(ctx *gin.Context) {
		ctx.Set("user_id", uuid.New().String())
		ctx.Next()
	}

	router := gin.New()
	router.Use(identify)
	router.Use(Permission(0))
	router.GET("/any", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/any", nil))

	assert.Equal(t, http.StatusNotFound, rec.Code)
}
// TestPermissionCachesLevelFromDB seeds a user with level 10, supplies only
// that user's id in the context, and expects the handler behind Permission(5)
// to find "permission_level" set to uint(10) in the context.
func TestPermissionCachesLevelFromDB(t *testing.T) {
	testutil.Setup(t)
	user := testutil.SeedUser(t, testutil.RandomEmail(), 10)

	r := gin.New()
	r.Use(func(c *gin.Context) {
		c.Set("user_id", user.UserId.String())
		c.Next()
	})
	r.Use(Permission(5))
	r.GET("/any", func(c *gin.Context) {
		lvl, ok := c.Get("permission_level")
		assert.True(t, ok)
		// Compare the stored interface value directly: a missing key or a
		// value of another type then fails the assertion instead of panicking
		// on an unchecked type assertion inside the handler.
		assert.Equal(t, uint(10), lvl)
		c.String(http.StatusOK, "ok")
	})

	req := httptest.NewRequest(http.MethodGet, "/any", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
}
// TestJWTAuthBadToken sends a Bearer credential whose token text is not a
// token issued by this test setup and expects JWTAuth to answer with 401.
func TestJWTAuthBadToken(t *testing.T) {
	testutil.Setup(t)
	testutil.SeedClient(t)

	router := gin.New()
	router.Use(JWTAuth())
	router.GET("/protected", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "ok")
	})

	request := httptest.NewRequest(http.MethodGet, "/protected", nil)
	request.Header.Set("Authorization", "Bearer completely.invalid.token")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, request)

	assert.Equal(t, http.StatusUnauthorized, rec.Code)
}