Implement graceful shutdown (#11)

* Implement graceful shutdown

This borrows heavily from the implementation of net/http Server.Shutdown().

Closes #10

* Add Server.Address() function to return listener address of a Server

This is useful for logging, etc.

* Refactor out Server.Wait() and add 'wait' param to Shutdown()

This allows a consumer which runs multiple Servers (decke/smtprelay) to
first shut down all servers (passing wait=false) so no new connections are
accepted, and then wait for outstanding client connections to finish.
This commit is contained in:
Jonathon Reinhart 2021-03-29 04:24:55 -04:00 committed by GitHub
parent 32be721d71
commit b5f17a69f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 239 additions and 10 deletions

139
smtpd.go
View file

@ -4,10 +4,13 @@ package smtpd
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"strings"
"sync"
"sync/atomic"
"time"
)
@ -49,6 +52,14 @@ type Server struct {
ForceTLS bool // Force STARTTLS usage.
ProtocolLogger *log.Logger
// mu guards doneChan and makes closing it and the listener atomic from
// the perspective of Serve().
mu sync.Mutex
doneChan chan struct{}
listener *net.Listener
waitgrp sync.WaitGroup
inShutdown atomicBool // true when server is in shutdown
}
// Protocol represents the protocol used in the SMTP session
@ -82,6 +93,10 @@ type Error struct {
// Error implements the error interface. The result is the numeric Code
// followed by a single space and the Message.
func (e Error) Error() string {
	return fmt.Sprintf("%d %s", e.Code, e.Message)
}
// ErrServerClosed is returned by the Server's Serve and ListenAndServe
// methods once Shutdown has been called on the server.
var ErrServerClosed = errors.New("smtp: Server closed")
type session struct {
server *Server
@ -134,6 +149,9 @@ func (srv *Server) newSession(c net.Conn) (s *session) {
// ListenAndServe starts the SMTP server and listens on the address provided
func (srv *Server) ListenAndServe(addr string) error {
if srv.shuttingDown() {
return ErrServerClosed
}
srv.configureDefaults()
@ -147,23 +165,31 @@ func (srv *Server) ListenAndServe(addr string) error {
// Serve starts the SMTP server and listens on the Listener provided
func (srv *Server) Serve(l net.Listener) error {
if srv.shuttingDown() {
return ErrServerClosed
}
srv.configureDefaults()
l = &onceCloseListener{Listener: l}
defer l.Close()
srv.listener = &l
var limiter chan struct{}
if srv.MaxConnections > 0 {
limiter = make(chan struct{}, srv.MaxConnections)
} else {
limiter = nil
}
for {
conn, e := l.Accept()
if e != nil {
select {
case <-srv.getDoneChan():
return ErrServerClosed
default:
}
if ne, ok := e.(net.Error); ok && ne.Temporary() {
time.Sleep(time.Second)
continue
@ -173,8 +199,10 @@ func (srv *Server) Serve(l net.Listener) error {
session := srv.newSession(conn)
if limiter != nil {
go func() {
srv.waitgrp.Add(1)
go func() {
defer srv.waitgrp.Done()
if limiter != nil {
select {
case limiter <- struct{}{}:
session.serve()
@ -182,15 +210,53 @@ func (srv *Server) Serve(l net.Listener) error {
default:
session.reject()
}
}()
} else {
go session.serve()
}
} else {
session.serve()
}
}()
}
}
// Shutdown begins shutting the server down. It marks the server as shutting
// down, closes the listener (if one has been registered) and closes the done
// channel so that Serve can tell a shutdown apart from an accept failure.
//
// When wait is true, Shutdown additionally blocks until every outstanding
// client connection has finished. When wait is false, the caller is expected
// to call Wait afterwards.
//
// The returned error is whatever closing the listener reported, or nil if
// there was no listener to close.
func (srv *Server) Shutdown(wait bool) error {
	srv.inShutdown.setTrue()

	// Close the listener and the done channel as one step under srv.mu.
	closeErr := func() error {
		srv.mu.Lock()
		defer srv.mu.Unlock()

		var err error
		if ln := srv.listener; ln != nil {
			err = (*ln).Close()
		}
		srv.closeDoneChanLocked()
		return err
	}()

	// Optionally block until all client sessions are done. Wait only fails
	// when the server is not shutting down, which was ruled out above.
	if wait {
		_ = srv.Wait()
	}
	return closeErr
}
// Wait blocks until all client connections tracked by the server's wait
// group have finished. It must only be called after Shutdown; if the server
// is not shutting down, Wait returns an error immediately without blocking.
func (srv *Server) Wait() error {
	if srv.shuttingDown() {
		srv.waitgrp.Wait()
		return nil
	}
	return errors.New("Server has not been Shutdown")
}
// Address returns the address the server's listener is bound to.
//
// It returns nil when no listener has been stored on the server yet (Serve
// records the listener when it starts), so callers that may run before Serve
// should check the result for nil instead of dereferencing it blindly.
//
// NOTE(review): srv.listener is read here without holding srv.mu, while
// Shutdown reads it under the lock — confirm whether callers may invoke
// Address concurrently with Serve starting up.
func (srv *Server) Address() net.Addr {
	if srv.listener == nil {
		return nil
	}
	return (*srv.listener).Addr()
}
func (srv *Server) configureDefaults() {
if srv.MaxMessageSize == 0 {
@ -366,3 +432,56 @@ func (session *session) close() {
time.Sleep(200 * time.Millisecond)
session.conn.Close()
}
// From net/http/server.go
// shuttingDown reports whether the inShutdown flag has been set on the
// server.
func (srv *Server) shuttingDown() bool {
	inShutdown := srv.inShutdown.isSet()
	return inShutdown
}
// getDoneChan returns the server's done channel as a receive-only channel,
// creating it on first use. It acquires srv.mu for the duration of the
// lookup.
func (srv *Server) getDoneChan() <-chan struct{} {
	srv.mu.Lock()
	done := srv.getDoneChanLocked()
	srv.mu.Unlock()
	return done
}
// getDoneChanLocked returns the server's done channel, allocating it if it
// does not exist yet. The "Locked" suffix indicates that the caller is
// expected to already hold srv.mu.
func (srv *Server) getDoneChanLocked() chan struct{} {
	if srv.doneChan != nil {
		return srv.doneChan
	}
	srv.doneChan = make(chan struct{})
	return srv.doneChan
}
// closeDoneChanLocked closes the server's done channel unless it has already
// been closed. The caller is expected to hold srv.mu, which is what makes the
// check-then-close sequence safe.
func (srv *Server) closeDoneChanLocked() {
	done := srv.getDoneChanLocked()

	select {
	case <-done:
		// A receive succeeding means the channel is already closed;
		// closing it a second time would panic.
		return
	default:
	}

	// Not closed yet, and srv.mu guarantees nobody else is closing it.
	close(done)
}
// onceCloseListener embeds a net.Listener and guarantees that the underlying
// listener's Close method runs at most once, no matter how many times Close
// is called on the wrapper.
type onceCloseListener struct {
	net.Listener
	once     sync.Once
	closeErr error
}

// Close closes the wrapped listener on the first call and remembers the
// result; every call, including later ones, returns that remembered error.
func (ln *onceCloseListener) Close() error {
	ln.once.Do(ln.close)
	return ln.closeErr
}

// close performs the real Close on the wrapped listener and records its
// error. It is meant to be run through ln.once.
func (ln *onceCloseListener) close() {
	err := ln.Listener.Close()
	ln.closeErr = err
}
// atomicBool is a boolean flag backed by an int32 that is read and written
// with sync/atomic operations. The zero value is "not set".
type atomicBool int32

// isSet reports whether the flag currently holds a non-zero value.
func (ab *atomicBool) isSet() bool {
	return atomic.LoadInt32((*int32)(ab)) != 0
}

// setTrue atomically stores 1 into the flag.
func (ab *atomicBool) setTrue() {
	atomic.StoreInt32((*int32)(ab), 1)
}

// setFalse atomically stores 0 into the flag.
func (ab *atomicBool) setFalse() {
	atomic.StoreInt32((*int32)(ab), 0)
}

View file

@ -1436,3 +1436,113 @@ func TestTLSListener(t *testing.T) {
}
}
// TestShutdown exercises the graceful shutdown sequence end to end: with one
// client connected, Shutdown(false) must stop new connections from being
// accepted, Wait() must block while that client is still connected and return
// nil once it quits, and Serve() must return smtpd.ErrServerClosed.
func TestShutdown(t *testing.T) {
fmt.Println("Starting test")
server := &smtpd.Server{
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
}
// Port 0 lets the OS pick a free port; the address is read back from ln.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
// srvres carries the return value of Serve() back to the test goroutine.
srvres := make(chan error)
go func() {
t.Log("Starting server")
srvres <- server.Serve(ln)
}()
// Connect a client
c, err := smtp.Dial(ln.Addr().String())
if err != nil {
t.Fatalf("Dial failed: %v", err)
}
if err := c.Hello("localhost"); err != nil {
t.Fatalf("HELO failed: %v", err)
}
// While the client connection is open, shut down the server (without
// waiting for it to finish)
err = server.Shutdown(false)
if err != nil {
t.Fatalf("Shutdown returned error: %v", err)
}
// Verify that Shutdown() worked by attempting to connect another client
_, err = smtp.Dial(ln.Addr().String())
if err == nil {
t.Fatalf("Dial did not fail as expected")
}
if _, typok := err.(*net.OpError); !typok {
t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err)
}
// Wait for shutdown to complete; shutres carries the return value of
// Wait() back to the test goroutine.
shutres := make(chan error)
go func() {
t.Log("Waiting for server shutdown to finish")
shutres <- server.Wait()
}()
// Slight delay so that a Wait() which wrongly returns early has time to
// deliver its result before the check below
time.Sleep(250 * time.Millisecond)
// Wait() should not have returned yet due to open client conn
select {
case shuterr := <-shutres:
t.Fatalf("Wait() returned early w/ error: %v", shuterr)
default:
}
// Now close the client
t.Log("Closing client connection")
if err := c.Quit(); err != nil {
t.Fatalf("QUIT failed: %v", err)
}
c.Close()
// Wait for Wait() to return
t.Log("Waiting for Wait() to return")
select {
case shuterr := <-shutres:
if shuterr != nil {
t.Fatalf("Wait() returned error: %v", shuterr)
}
case <-time.After(15 * time.Second):
t.Fatalf("Timed out waiting for Wait() to return")
}
// Wait for Serve() to return
t.Log("Waiting for Serve() to return")
select {
case srverr := <-srvres:
if srverr != smtpd.ErrServerClosed {
t.Fatalf("Serve() returned error: %v", srverr)
}
case <-time.After(15 * time.Second):
t.Fatalf("Timed out waiting for Serve() to return")
}
}
// TestServeFailsIfShutdown verifies that a server which has already been shut
// down refuses to serve: Serve must return smtpd.ErrServerClosed.
func TestServeFailsIfShutdown(t *testing.T) {
	server := &smtpd.Server{}

	// Shutting down a server that never listened should succeed.
	if err := server.Shutdown(true); err != nil {
		t.Fatalf("Shutdown() failed: %v", err)
	}

	// Serve is expected to bail out before it touches the (nil) listener.
	if err := server.Serve(nil); err != smtpd.ErrServerClosed {
		t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
	}
}
// TestWaitFailsIfNotShutdown verifies that calling Wait on a server that was
// never shut down reports an error.
func TestWaitFailsIfNotShutdown(t *testing.T) {
	server := new(smtpd.Server)
	if err := server.Wait(); err == nil {
		t.Fatalf("Wait() did not fail as expected")
	}
}