Normalize all email fields to lower chars #14

Merged
sugar merged 1 commits from develop into main 2026-06-12 21:30:38 +00:00
3 changed files with 54 additions and 2 deletions

View File

@@ -2,11 +2,16 @@ package data
import (
"context"
"strings"
"github.com/google/uuid"
"gorm.io/gorm"
)
func NormalizeEmail(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}
type User struct {
Id uint `json:"id" gorm:"primarykey;autoincrement"`
UUID uuid.UUID `json:"uuid" gorm:"type:uuid;uniqueindex;not null"`
@@ -62,7 +67,10 @@ type userOpts struct {
type UserOption func(*userOpts)
func WithEmail(v string) UserOption { return func(o *userOpts) { o.Email = &v } }
func WithEmail(v string) UserOption {
v = NormalizeEmail(v)
return func(o *userOpts) { o.Email = &v }
}
func WithUsername(v string) UserOption { return func(o *userOpts) { o.Username = &v } }
func WithNickname(v string) UserOption { return func(o *userOpts) { o.Nickname = &v } }
func WithSubtitle(v string) UserOption { return func(o *userOpts) { o.Subtitle = &v } }
@@ -112,8 +120,13 @@ func NewUser(opts ...UserOption) *User {
func (self *User) GetByEmail(ctx context.Context, email *string) (*User, error) {
var user User
if email == nil {
return nil, gorm.ErrRecordNotFound
}
normalized := NormalizeEmail(*email)
err := Database.WithContext(ctx).
Where("email = ?", email).
Where("email = ?", normalized).
First(&user).Error
if err != nil {

View File

@@ -38,6 +38,42 @@ func TestUserGetByEmailNotFound(t *testing.T) {
require.Error(t, err)
}
// TestUserEmailIsCaseInsensitive guards against the regression where an
// upper-cased email was treated as a different identifier from the lower-cased
// version, allowing duplicate user rows to be inserted with the same logical
// address.
func TestUserEmailIsCaseInsensitive(t *testing.T) {
testutil.Setup(t)
ctx := context.Background()
u := data.NewUser(
data.WithEmail(" Mixed.Case@Example.COM "),
data.WithUsername("mixedcase"),
data.WithPermissionLevel(10),
)
require.NoError(t, u.Create(ctx))
assert.Equal(t, "mixed.case@example.com", u.Email,
"WithEmail must store the canonical (trimmed + lower-cased) form")
for _, variant := range []string{
"mixed.case@example.com",
"MIXED.CASE@EXAMPLE.COM",
" Mixed.Case@Example.com ",
} {
got, err := new(data.User).GetByEmail(ctx, &variant)
require.NoErrorf(t, err, "GetByEmail(%q) must find the row", variant)
assert.Equal(t, u.UserId, got.UserId)
}
dup := data.NewUser(
data.WithEmail("MIXED.CASE@example.com"),
data.WithUsername("mixedcase-dup"),
data.WithPermissionLevel(10),
)
require.Error(t, dup.Create(ctx),
"the unique index on email must reject case-variant duplicates")
}
func TestUserGetByUserId(t *testing.T) {
testutil.Setup(t)
ctx := context.Background()

View File

@@ -3,6 +3,7 @@ package service_auth
import (
"context"
"net/url"
"nixcn-cms/data"
"nixcn-cms/internal/authcode"
"nixcn-cms/internal/email"
"nixcn-cms/internal/exception"
@@ -46,6 +47,8 @@ func (self *AuthServiceImpl) Magic(payload *MagicPayload) (result *MagicResult)
ctx = exception.ContextWithService(ctx, exception.ServiceAuthMagic)
payload.Data.Email = data.NormalizeEmail(payload.Data.Email)
var ok bool
var err error