// Package email sends HTML mail over SMTP using either username/password
// ("basic") authentication through gomail, or Microsoft identity platform
// OAuth2 client-credentials tokens presented with the XOAUTH2 SASL mechanism.
package email

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/mail"
	"net/smtp"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/spf13/viper"
	"golang.org/x/oauth2"
	"golang.org/x/oauth2/clientcredentials"
	gomail "gopkg.in/gomail.v2"
)

const (
	// oauthTokenTimeout bounds a single OAuth2 token request.
	oauthTokenTimeout = 20 * time.Second
	// smtpDialTimeout bounds the TCP connect (and, for implicit TLS, the
	// handshake) in the OAuth2 send path.
	smtpDialTimeout = 10 * time.Second
	// smtpSessionTimeout bounds the whole SMTP conversation once connected so
	// a stalled server cannot block Send indefinitely.
	smtpSessionTimeout = 60 * time.Second
	// tokenRefreshSkew: a cached token closer than this to expiry is refreshed.
	tokenRefreshSkew = 60 * time.Second
)

// Client sends mail using the settings captured by NewSMTPClient.
type Client struct {
	// dialer is the gomail dialer used when authMode is "basic".
	dialer *gomail.Dialer

	// Settings shared by both auth modes.
	from     string
	host     string
	port     int
	username string
	security string
	insecure bool

	// authMode is "basic" or "oauth2".
	authMode string

	// oauth supplies access tokens when authMode is "oauth2".
	oauth *oauthTokenProvider
}

// oauthTokenProvider caches a client-credentials access token.
type oauthTokenProvider struct {
	cfg clientcredentials.Config

	// mu guards token and fetchErr.
	mu    sync.Mutex
	token *oauth2.Token
	// fetchErr holds the most recent fetch error; it is cleared on success.
	fetchErr error
}

// getToken returns the cached access token while it remains valid for more
// than tokenRefreshSkew (or has no expiry at all), and otherwise fetches a new
// one. The mutex is held across the fetch so concurrent callers share a single
// refresh instead of each requesting a token.
func (p *oauthTokenProvider) getToken(ctx context.Context) (string, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if t := p.token; t != nil && t.Valid() &&
		(t.Expiry.IsZero() || time.Until(t.Expiry) > tokenRefreshSkew) {
		return t.AccessToken, nil
	}

	tok, err := p.cfg.Token(ctx)
	if err != nil {
		p.fetchErr = err
		return "", err
	}
	p.token = tok
	p.fetchErr = nil
	return tok.AccessToken, nil
}

// NewSMTPClient builds a Client from viper configuration under "email.*".
//
// Keys read: email.host, email.port, email.username, email.password,
// email.from, email.security ("ssl", "starttls", "plain" or empty),
// email.insecure_skip_verify and email.auth ("basic", the default, or
// "oauth2"). In oauth2 mode it also reads email.oauth2.tenant_id,
// email.oauth2.client_id, email.oauth2.client_secret and email.oauth2.scope
// (default "https://outlook.office365.com/.default"); an empty security mode
// defaults to "starttls" and "plain" is rejected.
//
// It returns an error when host, port or username is missing, when the chosen
// auth mode lacks its credentials, or when the auth or security mode is
// unknown or not allowed for that auth mode.
func NewSMTPClient() (*Client, error) {
	host := viper.GetString("email.host")
	port := viper.GetInt("email.port")
	user := viper.GetString("email.username")
	pass := viper.GetString("email.password")
	from := viper.GetString("email.from")
	security := strings.ToLower(viper.GetString("email.security"))
	insecure := viper.GetBool("email.insecure_skip_verify")
	authMode := strings.ToLower(viper.GetString("email.auth"))
	if authMode == "" {
		authMode = "basic"
	}

	if host == "" || port == 0 || user == "" {
		return nil, errors.New("SMTP config not set")
	}

	c := &Client{
		from:     from,
		host:     host,
		port:     port,
		username: user,
		security: security,
		insecure: insecure,
		authMode: authMode,
	}

	switch authMode {
	case "basic":
		if pass == "" {
			return nil, errors.New("SMTP basic auth requires email.password")
		}
		dialer := gomail.NewDialer(host, port, user, pass)
		dialer.TLSConfig = &tls.Config{
			ServerName:         host,
			InsecureSkipVerify: insecure,
		}
		switch security {
		case "ssl":
			dialer.SSL = true
		case "starttls":
			dialer.SSL = false
		case "plain", "":
			dialer.SSL = false
			dialer.TLSConfig = nil
		default:
			return nil, errors.New("unknown smtp security mode: " + security)
		}
		c.dialer = dialer
		return c, nil

	case "oauth2":
		// Validate the security mode up front (as basic mode does) instead of
		// failing on the first Send.
		switch security {
		case "":
			security = "starttls"
		case "ssl", "starttls":
		case "plain":
			return nil, errors.New("oauth2 requires TLS (starttls or ssl); plain is not allowed")
		default:
			return nil, errors.New("unknown smtp security mode: " + security)
		}
		c.security = security

		tenantID := viper.GetString("email.oauth2.tenant_id")
		clientID := viper.GetString("email.oauth2.client_id")
		clientSecret := viper.GetString("email.oauth2.client_secret")
		scope := viper.GetString("email.oauth2.scope")
		if scope == "" {
			// Scope used for SMTP with the client-credentials grant
			// (Microsoft Learn: "Authenticate an IMAP, POP or SMTP connection
			// using OAuth").
			scope = "https://outlook.office365.com/.default"
		}
		if tenantID == "" || clientID == "" || clientSecret == "" {
			return nil, errors.New("oauth2 requires email.oauth2.tenant_id/client_id/client_secret")
		}

		c.oauth = &oauthTokenProvider{
			cfg: clientcredentials.Config{
				ClientID:     clientID,
				ClientSecret: clientSecret,
				TokenURL:     fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", tenantID),
				Scopes:       []string{scope},
			},
		}
		return c, nil

	default:
		return nil, errors.New("unknown email.auth: " + authMode)
	}
}

// Send delivers one HTML message to the recipient "to", which may be a bare
// address or in "Name <addr>" form. On success it returns the local time the
// send completed, formatted with time.RFC3339Nano; this is not a
// server-assigned message ID.
func (c *Client) Send(to, subject, html string) (string, error) {
	m := gomail.NewMessage()
	m.SetHeader("From", c.from)
	m.SetHeader("To", to)
	m.SetHeader("Subject", subject)
	m.SetBody("text/html", html)

	switch c.authMode {
	case "basic":
		if c.dialer == nil {
			return "", errors.New("basic dialer not initialized")
		}
		if err := c.dialer.DialAndSend(m); err != nil {
			return "", err
		}
		return time.Now().Format(time.RFC3339Nano), nil

	case "oauth2":
		if err := c.sendWithXOAUTH2(m, to); err != nil {
			return "", err
		}
		return time.Now().Format(time.RFC3339Nano), nil

	default:
		return "", errors.New("unsupported auth mode: " + c.authMode)
	}
}

// xoauth2Auth implements smtp.Auth for the XOAUTH2 SASL mechanism.
type xoauth2Auth struct {
	username string
	token    string
}

var _ smtp.Auth = (*xoauth2Auth)(nil)

// Start refuses to run over a non-TLS connection and otherwise returns the
// XOAUTH2 initial response "user={username}\x01auth=Bearer {token}\x01\x01".
// net/smtp base64-encodes the response before sending it.
func (a *xoauth2Auth) Start(server *smtp.ServerInfo) (string, []byte, error) {
	if !server.TLS {
		return "", nil, errors.New("refusing to authenticate over insecure connection")
	}
	resp := fmt.Sprintf("user=%s\x01auth=Bearer %s\x01\x01", a.username, a.token)
	return "XOAUTH2", []byte(resp), nil
}

// Next treats any further server challenge as a failure. With XOAUTH2 the
// server sends a challenge only to report an error, so its (already decoded)
// text is included in the returned error to aid diagnosis.
func (a *xoauth2Auth) Next(fromServer []byte, more bool) ([]byte, error) {
	if more {
		return nil, fmt.Errorf("unexpected server challenge during XOAUTH2 auth: %s", fromServer)
	}
	return nil, nil
}

// sendWithXOAUTH2 obtains an access token, connects to the SMTP server with
// TLS, authenticates via XOAUTH2 and submits m to the single recipient rcpt.
func (c *Client) sendWithXOAUTH2(m *gomail.Message, rcpt string) error {
	if c.oauth == nil {
		return errors.New("oauth2 provider not initialized")
	}

	ctx, cancel := context.WithTimeout(context.Background(), oauthTokenTimeout)
	defer cancel()
	token, err := c.oauth.getToken(ctx)
	if err != nil {
		return fmt.Errorf("oauth2 token error: %w", err)
	}

	// Render the gomail message to RFC 5322 bytes for the DATA command.
	var buf bytes.Buffer
	if _, err := m.WriteTo(&buf); err != nil {
		return err
	}

	cl, err := c.dialSMTP()
	if err != nil {
		return err
	}
	defer func() {
		// QUIT is a courtesy: the outcome is already decided, so its error is
		// ignored. Close runs unconditionally because Quit only closes the
		// connection when the QUIT exchange succeeds; a second close after a
		// successful Quit merely returns an ignored error.
		_ = cl.Quit()
		_ = cl.Close()
	}()

	if err := cl.Auth(&xoauth2Auth{username: c.username, token: token}); err != nil {
		return err
	}
	if err := cl.Mail(extractAddress(c.from)); err != nil {
		return err
	}
	// RCPT TO takes a bare addr-spec, so strip any display name.
	if err := cl.Rcpt(extractAddress(rcpt)); err != nil {
		return err
	}
	w, err := cl.Data()
	if err != nil {
		return err
	}
	if _, err := w.Write(buf.Bytes()); err != nil {
		_ = w.Close()
		return err
	}
	return w.Close()
}

// dialSMTP connects to c.host:c.port according to c.security: "ssl" uses
// implicit TLS, and "starttls" or "" require a STARTTLS upgrade (an error is
// returned if the server does not offer it). The connection carries an
// overall deadline of smtpSessionTimeout.
func (c *Client) dialSMTP() (*smtp.Client, error) {
	// JoinHostPort brackets IPv6 literals, which "%s:%d" does not.
	addr := net.JoinHostPort(c.host, strconv.Itoa(c.port))
	tlsCfg := &tls.Config{
		ServerName:         c.host,
		InsecureSkipVerify: c.insecure,
	}
	dialer := &net.Dialer{Timeout: smtpDialTimeout}

	var (
		conn net.Conn
		err  error
	)
	switch c.security {
	case "ssl":
		conn, err = tls.DialWithDialer(dialer, "tcp", addr, tlsCfg)
	case "starttls", "":
		conn, err = dialer.Dial("tcp", addr)
	default:
		return nil, errors.New("unknown smtp security mode: " + c.security)
	}
	if err != nil {
		return nil, err
	}

	// The deadline also applies after a STARTTLS upgrade, since the TLS
	// connection wraps this one.
	if err := conn.SetDeadline(time.Now().Add(smtpSessionTimeout)); err != nil {
		_ = conn.Close()
		return nil, err
	}

	cl, err := smtp.NewClient(conn, c.host)
	if err != nil {
		_ = conn.Close()
		return nil, err
	}
	if c.security == "ssl" {
		return cl, nil
	}

	// XOAUTH2 carries a bearer token, so STARTTLS is mandatory here.
	if ok, _ := cl.Extension("STARTTLS"); !ok {
		_ = cl.Close()
		return nil, errors.New("server does not support STARTTLS")
	}
	if err := cl.StartTLS(tlsCfg); err != nil {
		_ = cl.Close()
		return nil, err
	}
	return cl, nil
}

// extractAddress returns the bare addr-spec from an address that may be in
// "Display Name <user@example.com>" form, for use in MAIL FROM / RCPT TO.
// It prefers net/mail parsing. If that fails, it falls back to the text
// between the last '<' and the last '>', or else the whitespace-trimmed input.
func extractAddress(addr string) string {
	if a, err := mail.ParseAddress(addr); err == nil {
		return a.Address
	}
	if i := strings.LastIndex(addr, "<"); i >= 0 {
		if j := strings.LastIndex(addr, ">"); j > i {
			return strings.TrimSpace(addr[i+1 : j])
		}
	}
	return strings.TrimSpace(addr)
}