smtpd/smtpd.go

301 lines
7.2 KiB
Go

// Package smtpd implements an SMTP server with support for STARTTLS, authentication (PLAIN/LOGIN), XCLIENT and optional restrictions on the different stages of the SMTP session.
package smtpd
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"log"
"net"
"sync"
"time"
)
// Server defines the parameters for running the SMTP server.
// Zero-valued tunables are replaced with the defaults noted below when
// ListenAndServe or Serve is called.
type Server struct {
	Hostname       string        // Server hostname. (default: "localhost.localdomain")
	WelcomeMessage string        // Initial server banner. (default: "<hostname> ESMTP ready.")
	ReadTimeout    time.Duration // Socket timeout for read operations. (default: 60s)
	WriteTimeout   time.Duration // Socket timeout for write operations. (default: 60s)
	DataTimeout    time.Duration // Socket timeout for DATA command (default: 5m)
	// Max concurrent connections; any negative value (e.g. -1) disables the
	// limit. Connections beyond the limit are rejected. (default: 100)
	MaxConnections int
	MaxMessageSize int // Max message size in bytes. (default: 10240000)
	MaxRecipients  int // Max RCPT TO calls for each envelope. (default: 100)
	// New e-mails are handed off to this function.
	// Can be left empty for a NOOP server.
	// If an error is returned, it will be reported in the SMTP session.
	Handler func(peer Peer, env Envelope) error
	// Enable various checks during the SMTP session.
	// Can be left empty for no restrictions.
	// If an error is returned, it will be reported in the SMTP session.
	// Use the Error struct for access to error codes.
	ConnectionChecker func(peer Peer) error              // Called upon new connection.
	HeloChecker       func(peer Peer, name string) error // Called after HELO/EHLO.
	SenderChecker     func(peer Peer, addr string) error // Called after MAIL FROM.
	RecipientChecker  func(peer Peer, addr string) error // Called after each RCPT TO.
	// Enable PLAIN/LOGIN authentication, only available after STARTTLS.
	// Can be left empty for no authentication support.
	Authenticator       func(peer Peer, username, password string) error
	EnableXCLIENT       bool        // Enable XCLIENT support (default: false)
	EnableProxyProtocol bool        // Enable proxy protocol support (default: false)
	TLSConfig           *tls.Config // Enable STARTTLS support.
	// Force STARTTLS usage. Requires TLSConfig: starting the server with
	// ForceTLS set and TLSConfig nil terminates the process via log.Fatal.
	ForceTLS bool
	// Optional logger for the SMTP exchange.
	// NOTE(review): consumed by the session code outside this section;
	// confirm that nil disables protocol logging.
	ProtocolLogger *slog.Logger
	// mu guards doneChan. Shutdown holds it while closing both the listener
	// and doneChan, so the two closes appear atomic from Serve's perspective.
	mu       sync.Mutex
	doneChan chan struct{} // closed by Shutdown; created lazily
	// listener points at the listener Serve is accepting on (nil before Serve).
	listener *net.Listener
	// waitgrp counts the session goroutines started by Serve; Wait blocks on it.
	waitgrp sync.WaitGroup
	// inShutdown is set by Shutdown; once true, Serve/ListenAndServe refuse to start.
	inShutdown atomicBool // true when server is in shutdown
}
// Protocol represents the protocol used in the SMTP session.
type Protocol string

// Both constants carry the Protocol type explicitly. In a const block a
// constant with its own "= value" does NOT inherit the previous line's type,
// so ESMTP used to be an untyped string (e.g. `p := ESMTP` inferred string).
const (
	// SMTP is the original Simple Mail Transfer Protocol.
	SMTP Protocol = "SMTP"
	// ESMTP is Extended SMTP (SMTP with service extensions).
	ESMTP Protocol = "ESMTP"
)
// Peer represents the client connecting to the server. It is the value passed
// to Handler and to each of the *Checker and Authenticator callbacks.
type Peer struct {
	HeloName   string               // Name sent by the client in its HELO/EHLO command
	Username   string               // Username from authentication, if authenticated
	Password   string               // Password from authentication, if authenticated
	Protocol   Protocol             // Protocol used, SMTP or ESMTP
	ServerName string               // A copy of Server.Hostname, taken when the session is created
	Addr       net.Addr             // Client network address (initialized from the connection's RemoteAddr)
	TLS        *tls.ConnectionState // TLS Connection details, if on TLS
}
// newSession wraps the accepted connection c in a session bound to srv, with
// buffered reader/writer and a line scanner over the reader.
//
// If c is already a *tls.Conn (the listener given to Serve came from
// tls.Listen or tls.NewListener), the TLS handshake is completed here so that
// s.peer.TLS holds a populated connection state.
func (srv *Server) newSession(c net.Conn) (s *session) {
	s = &session{
		server: srv,
		conn:   c,
		reader: bufio.NewReader(c),
		writer: bufio.NewWriter(c),
		peer: Peer{
			Addr:       c.RemoteAddr(),
			ServerName: srv.Hostname,
		},
	}

	// Check if the underlying connection is already TLS.
	// This will happen if the Listener provided to Serve()
	// is from tls.Listen().
	var tlsConn *tls.Conn
	tlsConn, s.tls = c.(*tls.Conn)
	if s.tls {
		// Bound the handshake: without a deadline, a client that connects and
		// never sends a ClientHello would block this call indefinitely.
		if srv.ReadTimeout > 0 {
			_ = c.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
		}
		if srv.WriteTimeout > 0 {
			_ = c.SetWriteDeadline(time.Now().Add(srv.WriteTimeout))
		}

		// Run the handshake now; otherwise it only happens on the first
		// read/write and ConnectionState would not be populated yet.
		// A failure is deliberately not acted on here: crypto/tls records the
		// handshake error and returns it from every subsequent Read/Write, so
		// the session fails on its first I/O.
		_ = tlsConn.Handshake()

		// Restore "no deadline", the state a freshly accepted connection starts in.
		_ = c.SetDeadline(time.Time{})

		state := tlsConn.ConnectionState()
		s.peer.TLS = &state
	}

	s.scanner = bufio.NewScanner(s.reader)
	return s
}
// ListenAndServe opens a TCP listener on addr and hands it to Serve.
// It returns ErrServerClosed without listening if the server has already
// been shut down; otherwise it returns the net.Listen error or whatever
// Serve eventually returns.
func (srv *Server) ListenAndServe(addr string) error {
	if srv.shuttingDown() {
		return ErrServerClosed
	}

	// Apply defaults (and the ForceTLS/TLSConfig sanity check) before the
	// port is bound.
	srv.configureDefaults()

	tcpListener, listenErr := net.Listen("tcp", addr)
	if listenErr != nil {
		return listenErr
	}
	return srv.Serve(tcpListener)
}
// Serve accepts connections on l and handles each one as an SMTP session in
// its own goroutine. When MaxConnections > 0, at most that many sessions are
// served concurrently; further connections are passed to session.reject.
//
// Serve blocks until Accept fails. After Shutdown it returns ErrServerClosed;
// otherwise it returns the Accept error. Accept timeouts are retried after a
// one-second pause. l is closed when Serve returns.
func (srv *Server) Serve(l net.Listener) error {
	if srv.shuttingDown() {
		return ErrServerClosed
	}
	srv.configureDefaults()

	l = &onceCloseListener{Listener: l}
	defer l.Close()

	// Publish the listener under mu, the lock Shutdown holds while closing it.
	// Re-checking for shutdown under that lock closes the window in which a
	// concurrent Shutdown runs after the check above but before this
	// assignment, finds no listener to close, and leaves Serve accepting
	// forever.
	srv.mu.Lock()
	if srv.shuttingDown() {
		srv.mu.Unlock()
		return ErrServerClosed
	}
	srv.listener = &l
	srv.mu.Unlock()

	var limiter chan struct{}
	if srv.MaxConnections > 0 {
		limiter = make(chan struct{}, srv.MaxConnections)
	}

	for {
		conn, err := l.Accept()
		if err != nil {
			select {
			case <-srv.getDoneChan():
				return ErrServerClosed
			default:
			}
			if ne, ok := err.(net.Error); ok && ne.Timeout() {
				time.Sleep(time.Second)
				continue
			}
			return err
		}

		srv.waitgrp.Add(1)
		go func(conn net.Conn) {
			defer srv.waitgrp.Done()

			// Build the session here, not in the accept loop: newSession runs
			// the TLS handshake for connections from a TLS listener, and one
			// slow client must not stall Accept for everyone else.
			session := srv.newSession(conn)

			if limiter == nil {
				session.serve()
				return
			}
			select {
			case limiter <- struct{}{}:
				session.serve()
				<-limiter
			default:
				session.reject()
			}
		}(conn)
	}
}
// Shutdown stops the server from accepting new connections. It marks the
// server as shutting down, closes the listener registered by Serve (if any),
// and closes the done channel so the accept loop returns ErrServerClosed.
// Sessions that are already running are not interrupted. If wait is true,
// Shutdown then blocks until they have all finished; if wait is false, the
// caller must call Wait afterwards. The returned error is whatever closing
// the listener reported.
func (srv *Server) Shutdown(wait bool) error {
	srv.inShutdown.setTrue()

	// Close the listener and the done channel inside one critical section.
	closeErr := func() error {
		srv.mu.Lock()
		defer srv.mu.Unlock()

		var err error
		if ln := srv.listener; ln != nil {
			err = (*ln).Close()
		}
		srv.closeDoneChanLocked()
		return err
	}()

	if wait {
		// Wait only fails when the server is not shutting down, which cannot
		// be the case here.
		_ = srv.Wait()
	}
	return closeErr
}
// Wait blocks until every session goroutine started by Serve has returned.
// It is only valid after Shutdown; called earlier, it returns an error
// immediately instead of waiting.
func (srv *Server) Wait() error {
	if srv.shuttingDown() {
		srv.waitgrp.Wait()
		return nil
	}
	return errors.New("Server has not been Shutdown")
}
// Address returns the network address the server is listening on, or nil if
// Serve has not registered a listener yet. (Previously this dereferenced a
// nil listener and panicked when called before Serve.)
func (srv *Server) Address() net.Addr {
	// srv.listener is written by Serve and read by Shutdown under mu.
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.listener == nil {
		return nil
	}
	return (*srv.listener).Addr()
}
// configureDefaults replaces every zero-valued tunable on srv with its
// documented default. It terminates the process via log.Fatal when ForceTLS
// is set without a TLSConfig, since STARTTLS could never be offered.
func (srv *Server) configureDefaults() {
	defaultInt := func(field *int, fallback int) {
		if *field == 0 {
			*field = fallback
		}
	}
	defaultDuration := func(field *time.Duration, fallback time.Duration) {
		if *field == 0 {
			*field = fallback
		}
	}

	defaultInt(&srv.MaxMessageSize, 10240000)
	defaultInt(&srv.MaxConnections, 100)
	defaultInt(&srv.MaxRecipients, 100)

	defaultDuration(&srv.ReadTimeout, 60*time.Second)
	defaultDuration(&srv.WriteTimeout, 60*time.Second)
	defaultDuration(&srv.DataTimeout, 5*time.Minute)

	if srv.TLSConfig == nil && srv.ForceTLS {
		log.Fatal("Cannot use ForceTLS with no TLSConfig")
	}

	if srv.Hostname == "" {
		srv.Hostname = "localhost.localdomain"
	}
	// The banner embeds Hostname, so it is defaulted after Hostname.
	if srv.WelcomeMessage == "" {
		srv.WelcomeMessage = fmt.Sprintf("%s ESMTP ready.", srv.Hostname)
	}
}
// The shutdown helpers below are adapted from net/http/server.go.
// shuttingDown reports whether Shutdown has been called on srv.
func (srv *Server) shuttingDown() bool {
	inShutdown := srv.inShutdown.isSet()
	return inShutdown
}
// getDoneChan returns the receive end of the channel that Shutdown closes,
// creating the channel on first use. It acquires srv.mu itself.
func (srv *Server) getDoneChan() <-chan struct{} {
	srv.mu.Lock()
	done := srv.getDoneChanLocked()
	srv.mu.Unlock()
	return done
}
// getDoneChanLocked returns srv.doneChan, allocating it lazily.
// The caller must hold srv.mu.
func (srv *Server) getDoneChanLocked() chan struct{} {
	if srv.doneChan != nil {
		return srv.doneChan
	}
	srv.doneChan = make(chan struct{})
	return srv.doneChan
}
// closeDoneChanLocked closes the done channel (creating it first if needed)
// and is a no-op when it is already closed. The caller must hold srv.mu;
// that lock is what makes the check-then-close below safe, because this is
// the only place the channel is ever closed.
func (srv *Server) closeDoneChanLocked() {
	done := srv.getDoneChanLocked()
	select {
	case <-done:
		// Already closed; a second close would panic.
		return
	default:
	}
	close(done)
}