Files
nixcn-cms/pkgs/email/email.go
Asai Neko aaedddfd2f
All checks were successful
Build Backend (NixCN CMS) TeamCity build finished
Build Frontend (NixCN CMS) TeamCity build finished
Add Exchange SMTP Oauth2 Support (not verified)
Signed-off-by: Asai Neko <sugar@sne.moe>
2026-01-07 18:32:27 +08:00

314 lines
6.8 KiB
Go

package email
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/smtp"
"strings"
"sync"
"time"
"github.com/spf13/viper"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
gomail "gopkg.in/gomail.v2"
)
// Client sends e-mail over SMTP. Depending on authMode it either delegates to
// a gomail dialer (username/password) or speaks SMTP directly and
// authenticates with an OAuth2 access token.
type Client struct {
	// dialer is the gomail dialer used by the "basic" auth mode.
	dialer *gomail.Dialer

	// Settings shared by both auth modes.
	from     string // value written to the From header
	host     string // SMTP server host name
	port     int    // SMTP server port
	username string // login name presented to the server
	security string // transport security: "ssl", "starttls", "plain" or ""
	insecure bool   // when true, TLS certificate verification is skipped

	// authMode selects the send path: "basic" or "oauth2".
	authMode string

	// oauth supplies access tokens for the "oauth2" auth mode.
	oauth *oauthTokenProvider
}
// oauthTokenProvider obtains OAuth2 client-credentials tokens and keeps the
// most recent one so it can be reused across sends.
type oauthTokenProvider struct {
	// cfg holds the client credentials, token endpoint and scopes.
	cfg clientcredentials.Config

	// mu is held by getToken while it reads or writes the fields below.
	mu sync.Mutex
	// token is the most recently fetched token; nil until the first fetch.
	token *oauth2.Token
	// fetchErr is the error of the latest fetch attempt; nil after a success.
	fetchErr error
}
// getToken returns an access token, reusing the cached one when it is still
// valid and not about to expire, and fetching a new one from the token
// endpoint otherwise.
//
// ctx bounds the token request. On failure the error is recorded in fetchErr
// and returned unchanged; the previously cached token is left in place.
// The mutex is held for the whole call, so concurrent callers wait for a
// single fetch instead of each issuing their own.
func (p *oauthTokenProvider) getToken(ctx context.Context) (string, error) {
	// Refresh this long before the expiry time so a token does not lapse
	// between being handed out and being used for authentication.
	const refreshMargin = 60 * time.Second

	p.mu.Lock()
	defer p.mu.Unlock()

	if tok := p.token; tok != nil && tok.Valid() {
		// A zero Expiry marks a token that never expires. time.Until on the
		// zero time is hugely negative, so without this check such a token
		// would be thrown away and refetched on every call.
		if tok.Expiry.IsZero() || time.Until(tok.Expiry) > refreshMargin {
			return tok.AccessToken, nil
		}
	}

	tok, err := p.cfg.Token(ctx)
	if err != nil {
		p.fetchErr = err
		return "", err
	}
	p.token = tok
	p.fetchErr = nil
	return tok.AccessToken, nil
}
// NewSMTPClient builds a Client from the "email.*" viper configuration.
//
// Keys read:
//   - email.host, email.port, email.username: required.
//   - email.from: value of the From header.
//   - email.security: "ssl", "starttls" or "plain" (case-insensitive).
//   - email.insecure_skip_verify: skip TLS certificate verification.
//   - email.auth: "basic" (default when empty) or "oauth2".
//   - email.password: required for basic auth.
//   - email.oauth2.tenant_id, email.oauth2.client_id,
//     email.oauth2.client_secret: required for oauth2.
//   - email.oauth2.scope: optional, defaults to the Office 365 scope.
//
// An error is returned when a required key is missing or when the auth or
// security mode is not recognised. Both auth modes validate the security
// mode here, so a misconfiguration is reported at start-up rather than on the
// first Send.
func NewSMTPClient() (*Client, error) {
	host := viper.GetString("email.host")
	port := viper.GetInt("email.port")
	user := viper.GetString("email.username")
	pass := viper.GetString("email.password")
	from := viper.GetString("email.from")
	security := strings.ToLower(viper.GetString("email.security"))
	insecure := viper.GetBool("email.insecure_skip_verify")

	authMode := strings.ToLower(viper.GetString("email.auth"))
	if authMode == "" {
		authMode = "basic"
	}

	if host == "" || port == 0 || user == "" {
		return nil, errors.New("SMTP config not set")
	}

	c := &Client{
		from:     from,
		host:     host,
		port:     port,
		username: user,
		security: security,
		insecure: insecure,
		authMode: authMode,
	}

	switch authMode {
	case "basic":
		if pass == "" {
			return nil, errors.New("SMTP basic auth requires email.password")
		}

		dialer := gomail.NewDialer(host, port, user, pass)
		// The TLS config is kept for every security mode. When SSL is off,
		// gomail still upgrades the connection with STARTTLS if the server
		// advertises it, and with a nil TLSConfig it falls back to a default
		// config, which would silently ignore email.insecure_skip_verify.
		dialer.TLSConfig = &tls.Config{
			ServerName:         host,
			InsecureSkipVerify: insecure,
		}

		switch security {
		case "ssl":
			dialer.SSL = true
		case "starttls", "plain", "":
			// NewDialer enables SSL on its own for some ports; force it off.
			dialer.SSL = false
		default:
			return nil, errors.New("unknown smtp security mode: " + security)
		}

		c.dialer = dialer
		return c, nil

	case "oauth2":
		switch security {
		case "":
			// Default to STARTTLS when nothing is configured.
			c.security = "starttls"
		case "ssl", "starttls":
			// Supported as configured.
		case "plain":
			return nil, errors.New("oauth2 requires TLS (starttls or ssl); plain is not allowed")
		default:
			return nil, errors.New("unknown smtp security mode: " + security)
		}

		tenantID := viper.GetString("email.oauth2.tenant_id")
		clientID := viper.GetString("email.oauth2.client_id")
		clientSecret := viper.GetString("email.oauth2.client_secret")

		scope := viper.GetString("email.oauth2.scope")
		if scope == "" {
			// Scope used for the client-credentials flow against
			// Exchange Online SMTP.
			scope = "https://outlook.office365.com/.default"
		}

		if tenantID == "" || clientID == "" || clientSecret == "" {
			return nil, errors.New("oauth2 requires email.oauth2.tenant_id/client_id/client_secret")
		}

		c.oauth = &oauthTokenProvider{
			cfg: clientcredentials.Config{
				ClientID:     clientID,
				ClientSecret: clientSecret,
				TokenURL:     fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", tenantID),
				Scopes:       []string{scope},
			},
		}
		return c, nil

	default:
		return nil, errors.New("unknown email.auth: " + authMode)
	}
}
// Send builds an HTML message addressed to "to" with the given subject and
// body and delivers it through the path selected by the client's auth mode.
//
// On success it returns the current time formatted as RFC 3339 with
// nanoseconds; on failure it returns an empty string and the error.
func (c *Client) Send(to, subject, html string) (string, error) {
	msg := gomail.NewMessage()
	msg.SetHeader("From", c.from)
	msg.SetHeader("To", to)
	msg.SetHeader("Subject", subject)
	msg.SetBody("text/html", html)

	var sendErr error
	switch c.authMode {
	case "basic":
		if c.dialer == nil {
			return "", errors.New("basic dialer not initialized")
		}
		sendErr = c.dialer.DialAndSend(msg)
	case "oauth2":
		sendErr = c.sendWithXOAUTH2(msg, to)
	default:
		return "", errors.New("unsupported auth mode: " + c.authMode)
	}

	if sendErr != nil {
		return "", sendErr
	}
	return time.Now().Format(time.RFC3339Nano), nil
}
// xoauth2Auth implements smtp.Auth for the XOAUTH2 mechanism, which
// authenticates with an OAuth2 bearer token instead of a password.
type xoauth2Auth struct {
	username string // account the token is presented for
	token    string // OAuth2 access token
}

// Compile-time check that xoauth2Auth satisfies smtp.Auth.
var _ smtp.Auth = (*xoauth2Auth)(nil)

// Start returns the mechanism name and the initial client response.
// It refuses to run on a connection that is not protected by TLS, because the
// response carries the bearer token.
func (a *xoauth2Auth) Start(server *smtp.ServerInfo) (string, []byte, error) {
	if !server.TLS {
		return "", nil, errors.New("refusing to authenticate over insecure connection")
	}
	// XOAUTH2 initial response: user=<user>\x01auth=Bearer <token>\x01\x01
	resp := "user=" + a.username + "\x01auth=Bearer " + a.token + "\x01\x01"
	return "XOAUTH2", []byte(resp), nil
}

// Next handles the server's reply to the initial response. XOAUTH2 needs no
// further round trip, so any additional challenge (more == true) is an error.
//
// XOAUTH2 servers typically use such a challenge to carry the reason the
// token was rejected, so its content is included in the returned error
// instead of being discarded.
func (a *xoauth2Auth) Next(fromServer []byte, more bool) ([]byte, error) {
	if !more {
		return nil, nil
	}
	if len(fromServer) == 0 {
		return nil, errors.New("unexpected server challenge during XOAUTH2 auth")
	}
	return nil, fmt.Errorf("unexpected server challenge during XOAUTH2 auth: %q", fromServer)
}
// sendWithXOAUTH2 delivers m to rcpt over a TLS-protected SMTP session that
// authenticates with an OAuth2 access token (AUTH XOAUTH2).
//
// rcpt and the configured sender may be bare addresses or use the
// "Name <address>" form; only the address part is used for the envelope.
// The token request is limited to 20 seconds. Errors from each step are
// wrapped with the step name.
func (c *Client) sendWithXOAUTH2(m *gomail.Message, rcpt string) error {
	if c.oauth == nil {
		return errors.New("oauth2 provider not initialized")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()

	token, err := c.oauth.getToken(ctx)
	if err != nil {
		return fmt.Errorf("oauth2 token error: %w", err)
	}

	// Serialise the message to its wire form before opening the connection.
	var buf bytes.Buffer
	if _, err := m.WriteTo(&buf); err != nil {
		return fmt.Errorf("encoding message: %w", err)
	}

	cl, err := c.dialXOAUTH2()
	if err != nil {
		return err
	}
	defer func() {
		// Quit only closes the connection when the QUIT command succeeds;
		// close explicitly otherwise so the socket is not leaked.
		if err := cl.Quit(); err != nil {
			_ = cl.Close()
		}
	}()

	if err := cl.Auth(&xoauth2Auth{username: c.username, token: token}); err != nil {
		return fmt.Errorf("smtp auth: %w", err)
	}
	if err := cl.Mail(extractAddress(c.from)); err != nil {
		return fmt.Errorf("smtp mail from: %w", err)
	}
	if err := cl.Rcpt(extractAddress(rcpt)); err != nil {
		return fmt.Errorf("smtp rcpt to: %w", err)
	}

	w, err := cl.Data()
	if err != nil {
		return fmt.Errorf("smtp data: %w", err)
	}
	if _, err := w.Write(buf.Bytes()); err != nil {
		_ = w.Close()
		return fmt.Errorf("writing message: %w", err)
	}
	if err := w.Close(); err != nil {
		return fmt.Errorf("finishing message: %w", err)
	}
	return nil
}

// dialXOAUTH2 opens an SMTP session to the configured server and makes sure
// it is protected by TLS, either implicitly ("ssl") or through STARTTLS
// ("starttls" or ""). The caller owns the returned client and must close it.
func (c *Client) dialXOAUTH2() (*smtp.Client, error) {
	const (
		// dialTimeout bounds connection set-up (and the TLS handshake in
		// "ssl" mode).
		dialTimeout = 10 * time.Second
		// sessionTimeout bounds the whole SMTP conversation so a stalled
		// server cannot block the caller forever.
		sessionTimeout = 2 * time.Minute
	)

	// JoinHostPort brackets IPv6 literals, which plain "%s:%d" would not.
	addr := net.JoinHostPort(c.host, fmt.Sprintf("%d", c.port))
	tlsCfg := &tls.Config{
		ServerName:         c.host,
		InsecureSkipVerify: c.insecure,
	}
	dialer := &net.Dialer{Timeout: dialTimeout}

	var (
		conn net.Conn
		err  error
	)
	switch c.security {
	case "ssl":
		conn, err = tls.DialWithDialer(dialer, "tcp", addr, tlsCfg)
	case "starttls", "":
		conn, err = dialer.Dial("tcp", addr)
	default:
		return nil, errors.New("unknown smtp security mode: " + c.security)
	}
	if err != nil {
		return nil, fmt.Errorf("connecting to %s: %w", addr, err)
	}

	if err := conn.SetDeadline(time.Now().Add(sessionTimeout)); err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("setting connection deadline: %w", err)
	}

	cl, err := smtp.NewClient(conn, c.host)
	if err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("smtp greeting: %w", err)
	}
	if c.security == "ssl" {
		return cl, nil
	}

	// STARTTLS is mandatory here: the token must never travel in clear text.
	if ok, _ := cl.Extension("STARTTLS"); !ok {
		_ = cl.Close()
		return nil, errors.New("server does not support STARTTLS")
	}
	if err := cl.StartTLS(tlsCfg); err != nil {
		_ = cl.Close()
		return nil, fmt.Errorf("starttls: %w", err)
	}
	return cl, nil
}
// extractAddress returns the bare address from a "Name <address>" string.
// The text between the last '<' and the last '>' is used when the '>' comes
// after the '<'; otherwise the whole input is returned. Surrounding
// whitespace is trimmed in both cases.
func extractAddress(from string) string {
	lt := strings.LastIndexByte(from, '<')
	if lt < 0 {
		return strings.TrimSpace(from)
	}
	gt := strings.LastIndexByte(from, '>')
	if gt <= lt {
		return strings.TrimSpace(from)
	}
	return strings.TrimSpace(from[lt+1 : gt])
}