Implement graceful shutdown

This borrows heavily from the implementation of net/http Server.Shutdown().

Closes #10
This commit is contained in:
Jonathon Reinhart 2021-03-14 22:38:58 -04:00 committed by Jonathon Reinhart
parent 32be721d71
commit aee434025e
2 changed files with 194 additions and 10 deletions

112
smtpd.go
View file

@ -4,10 +4,13 @@ package smtpd
import ( import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"log" "log"
"net" "net"
"strings" "strings"
"sync"
"sync/atomic"
"time" "time"
) )
@ -49,6 +52,14 @@ type Server struct {
ForceTLS bool // Force STARTTLS usage. ForceTLS bool // Force STARTTLS usage.
ProtocolLogger *log.Logger ProtocolLogger *log.Logger
// mu guards doneChan and makes closing it and listener atomic from
// perspective of Serve()
mu sync.Mutex
doneChan chan struct{}
listener *net.Listener
waitgrp sync.WaitGroup
inShutdown atomicBool // true when server is in shutdown
} }
// Protocol represents the protocol used in the SMTP session // Protocol represents the protocol used in the SMTP session
@ -82,6 +93,10 @@ type Error struct {
// Error returns a string representation of the SMTP error // Error returns a string representation of the SMTP error
func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) }
// ErrServerClosed is returned by the Server's Serve and ListenAndServe,
// methods after a call to Shutdown.
var ErrServerClosed = errors.New("smtp: Server closed")
type session struct { type session struct {
server *Server server *Server
@ -134,6 +149,9 @@ func (srv *Server) newSession(c net.Conn) (s *session) {
// ListenAndServe starts the SMTP server and listens on the address provided // ListenAndServe starts the SMTP server and listens on the address provided
func (srv *Server) ListenAndServe(addr string) error { func (srv *Server) ListenAndServe(addr string) error {
if srv.shuttingDown() {
return ErrServerClosed
}
srv.configureDefaults() srv.configureDefaults()
@ -147,23 +165,31 @@ func (srv *Server) ListenAndServe(addr string) error {
// Serve starts the SMTP server and listens on the Listener provided // Serve starts the SMTP server and listens on the Listener provided
func (srv *Server) Serve(l net.Listener) error { func (srv *Server) Serve(l net.Listener) error {
if srv.shuttingDown() {
return ErrServerClosed
}
srv.configureDefaults() srv.configureDefaults()
l = &onceCloseListener{Listener: l}
defer l.Close() defer l.Close()
srv.listener = &l
var limiter chan struct{} var limiter chan struct{}
if srv.MaxConnections > 0 { if srv.MaxConnections > 0 {
limiter = make(chan struct{}, srv.MaxConnections) limiter = make(chan struct{}, srv.MaxConnections)
} else {
limiter = nil
} }
for { for {
conn, e := l.Accept() conn, e := l.Accept()
if e != nil { if e != nil {
select {
case <-srv.getDoneChan():
return ErrServerClosed
default:
}
if ne, ok := e.(net.Error); ok && ne.Temporary() { if ne, ok := e.(net.Error); ok && ne.Temporary() {
time.Sleep(time.Second) time.Sleep(time.Second)
continue continue
@ -173,8 +199,10 @@ func (srv *Server) Serve(l net.Listener) error {
session := srv.newSession(conn) session := srv.newSession(conn)
if limiter != nil { srv.waitgrp.Add(1)
go func() { go func() {
defer srv.waitgrp.Done()
if limiter != nil {
select { select {
case limiter <- struct{}{}: case limiter <- struct{}{}:
session.serve() session.serve()
@ -182,13 +210,30 @@ func (srv *Server) Serve(l net.Listener) error {
default: default:
session.reject() session.reject()
} }
}()
} else { } else {
go session.serve() session.serve()
}
}()
} }
} }
func (srv *Server) Shutdown() error {
var lnerr error
srv.inShutdown.setTrue()
// First close the listener
srv.mu.Lock()
if srv.listener != nil {
lnerr = (*srv.listener).Close();
}
srv.closeDoneChanLocked()
srv.mu.Unlock()
// Now wait for all client connections to close
srv.waitgrp.Wait()
return lnerr
} }
func (srv *Server) configureDefaults() { func (srv *Server) configureDefaults() {
@ -366,3 +411,56 @@ func (session *session) close() {
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
session.conn.Close() session.conn.Close()
} }
// From net/http/server.go
func (s *Server) shuttingDown() bool {
return s.inShutdown.isSet()
}
func (s *Server) getDoneChan() <-chan struct{} {
s.mu.Lock()
defer s.mu.Unlock()
return s.getDoneChanLocked()
}
func (s *Server) getDoneChanLocked() chan struct{} {
if s.doneChan == nil {
s.doneChan = make(chan struct{})
}
return s.doneChan
}
func (s *Server) closeDoneChanLocked() {
ch := s.getDoneChanLocked()
select {
case <-ch:
// Already closed. Don't close again.
default:
// Safe to close here. We're the only closer, guarded
// by s.mu.
close(ch)
}
}
// onceCloseListener wraps a net.Listener, protecting it from
// multiple Close calls.
type onceCloseListener struct {
net.Listener
once sync.Once
closeErr error
}
func (oc *onceCloseListener) Close() error {
oc.once.Do(oc.close)
return oc.closeErr
}
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
type atomicBool int32
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }

View file

@ -1436,3 +1436,89 @@ func TestTLSListener(t *testing.T) {
} }
} }
func TestShutdown(t *testing.T) {
fmt.Println("Starting test")
server := &smtpd.Server{
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
srvres := make(chan error)
go func() {
t.Log("Starting server")
srvres <- server.Serve(ln)
}()
// Connect a client
c, err := smtp.Dial(ln.Addr().String())
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
if err := c.Hello("localhost"); err != nil {
t.Fatalf("HELO failed: %v", err)
}
// While the client connection is open, shut down the server
shutres := make(chan error)
go func() {
t.Log("Waiting for server shutdown")
shutres <- server.Shutdown()
}()
// Slight delay to ensure Shutdown() blocks
time.Sleep(250 * time.Millisecond)
// Shutdown() should not have returned yet due to open client conn
select {
case shuterr := <-shutres:
t.Fatalf("Shutdown() returned early w/ error: %v", shuterr)
default:
}
// Now close the client
t.Log("Closing client connection")
if err := c.Quit(); err != nil {
t.Fatalf("QUIT failed: %v", err)
}
c.Close()
// Wait for Shutdown() to return
t.Log("Waiting for Shutdown() to return")
select {
case shuterr := <-shutres:
if shuterr != nil {
t.Fatalf("Shutdown() returned error: %v", shuterr)
}
case <-time.After(15 * time.Second):
t.Fatalf("Timed out waiting for Shutdown() to return")
}
// Wait for Serve() to return
t.Log("Waiting for Serve() to return")
select {
case srverr := <-srvres:
if srverr != smtpd.ErrServerClosed {
t.Fatalf("Serve() returned error: %v", srverr)
}
case <-time.After(15 * time.Second):
t.Fatalf("Timed out waiting for Serve() to return")
}
}
func TestServeFailsIfShutdown(t *testing.T) {
server := &smtpd.Server{}
err := server.Shutdown()
if err != nil {
t.Fatalf("Shutdown() failed: %v", err)
}
err = server.Serve(nil)
if err != smtpd.ErrServerClosed {
t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
}
}