Implement graceful shutdown (#11)
* Implement graceful shutdown This borrows heavily from the implementation of net/http Server.Shutdown(). Closes #10 * Add Server.Address() function to return listener address of a Server This is useful for logging, etc. * Refactor out Server.Wait() and add 'wait' param to Shutdown() This allows a consumer which runs multiple Servers (decke/smtprelay) to first shut down all servers (passing wait=false) so no new connections are accepted, and then wait for outstanding client connections to finish.
This commit is contained in:
parent
32be721d71
commit
b5f17a69f6
2 changed files with 239 additions and 10 deletions
131
smtpd.go
131
smtpd.go
|
@ -4,10 +4,13 @@ package smtpd
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,6 +52,14 @@ type Server struct {
|
||||||
ForceTLS bool // Force STARTTLS usage.
|
ForceTLS bool // Force STARTTLS usage.
|
||||||
|
|
||||||
ProtocolLogger *log.Logger
|
ProtocolLogger *log.Logger
|
||||||
|
|
||||||
|
// mu guards doneChan and makes closing it and listener atomic from
|
||||||
|
// perspective of Serve()
|
||||||
|
mu sync.Mutex
|
||||||
|
doneChan chan struct{}
|
||||||
|
listener *net.Listener
|
||||||
|
waitgrp sync.WaitGroup
|
||||||
|
inShutdown atomicBool // true when server is in shutdown
|
||||||
}
|
}
|
||||||
|
|
||||||
// Protocol represents the protocol used in the SMTP session
|
// Protocol represents the protocol used in the SMTP session
|
||||||
|
@ -82,6 +93,10 @@ type Error struct {
|
||||||
// Error returns a string representation of the SMTP error
|
// Error returns a string representation of the SMTP error
|
||||||
func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) }
|
func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) }
|
||||||
|
|
||||||
|
// ErrServerClosed is returned by the Server's Serve and ListenAndServe,
|
||||||
|
// methods after a call to Shutdown.
|
||||||
|
var ErrServerClosed = errors.New("smtp: Server closed")
|
||||||
|
|
||||||
type session struct {
|
type session struct {
|
||||||
server *Server
|
server *Server
|
||||||
|
|
||||||
|
@ -134,6 +149,9 @@ func (srv *Server) newSession(c net.Conn) (s *session) {
|
||||||
|
|
||||||
// ListenAndServe starts the SMTP server and listens on the address provided
|
// ListenAndServe starts the SMTP server and listens on the address provided
|
||||||
func (srv *Server) ListenAndServe(addr string) error {
|
func (srv *Server) ListenAndServe(addr string) error {
|
||||||
|
if srv.shuttingDown() {
|
||||||
|
return ErrServerClosed
|
||||||
|
}
|
||||||
|
|
||||||
srv.configureDefaults()
|
srv.configureDefaults()
|
||||||
|
|
||||||
|
@ -147,23 +165,31 @@ func (srv *Server) ListenAndServe(addr string) error {
|
||||||
|
|
||||||
// Serve starts the SMTP server and listens on the Listener provided
|
// Serve starts the SMTP server and listens on the Listener provided
|
||||||
func (srv *Server) Serve(l net.Listener) error {
|
func (srv *Server) Serve(l net.Listener) error {
|
||||||
|
if srv.shuttingDown() {
|
||||||
|
return ErrServerClosed
|
||||||
|
}
|
||||||
|
|
||||||
srv.configureDefaults()
|
srv.configureDefaults()
|
||||||
|
|
||||||
|
l = &onceCloseListener{Listener: l}
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
srv.listener = &l
|
||||||
|
|
||||||
var limiter chan struct{}
|
var limiter chan struct{}
|
||||||
|
|
||||||
if srv.MaxConnections > 0 {
|
if srv.MaxConnections > 0 {
|
||||||
limiter = make(chan struct{}, srv.MaxConnections)
|
limiter = make(chan struct{}, srv.MaxConnections)
|
||||||
} else {
|
|
||||||
limiter = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
||||||
conn, e := l.Accept()
|
conn, e := l.Accept()
|
||||||
if e != nil {
|
if e != nil {
|
||||||
|
select {
|
||||||
|
case <-srv.getDoneChan():
|
||||||
|
return ErrServerClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
continue
|
continue
|
||||||
|
@ -173,8 +199,10 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||||
|
|
||||||
session := srv.newSession(conn)
|
session := srv.newSession(conn)
|
||||||
|
|
||||||
if limiter != nil {
|
srv.waitgrp.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer srv.waitgrp.Done()
|
||||||
|
if limiter != nil {
|
||||||
select {
|
select {
|
||||||
case limiter <- struct{}{}:
|
case limiter <- struct{}{}:
|
||||||
session.serve()
|
session.serve()
|
||||||
|
@ -182,13 +210,51 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||||
default:
|
default:
|
||||||
session.reject()
|
session.reject()
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
} else {
|
} else {
|
||||||
go session.serve()
|
session.serve()
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shutdown instructs the server to shutdown, starting by closing the
|
||||||
|
// associated listener. If wait is true, it will wait for the shutdown
|
||||||
|
// to complete. If wait is false, Wait must be called afterwards.
|
||||||
|
func (srv *Server) Shutdown(wait bool) error {
|
||||||
|
var lnerr error
|
||||||
|
srv.inShutdown.setTrue()
|
||||||
|
|
||||||
|
// First close the listener
|
||||||
|
srv.mu.Lock()
|
||||||
|
if srv.listener != nil {
|
||||||
|
lnerr = (*srv.listener).Close();
|
||||||
|
}
|
||||||
|
srv.closeDoneChanLocked()
|
||||||
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
// Now wait for all client connections to close
|
||||||
|
if wait {
|
||||||
|
srv.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
return lnerr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait waits for all client connections to close and the server to finish
|
||||||
|
// shutting down.
|
||||||
|
func (srv *Server) Wait() error {
|
||||||
|
if !srv.shuttingDown() {
|
||||||
|
return errors.New("Server has not been Shutdown")
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.waitgrp.Wait()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Address returns the listening address of the server
|
||||||
|
func (srv *Server) Address() net.Addr {
|
||||||
|
return (*srv.listener).Addr();
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) configureDefaults() {
|
func (srv *Server) configureDefaults() {
|
||||||
|
@ -366,3 +432,56 @@ func (session *session) close() {
|
||||||
time.Sleep(200 * time.Millisecond)
|
time.Sleep(200 * time.Millisecond)
|
||||||
session.conn.Close()
|
session.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// From net/http/server.go
|
||||||
|
|
||||||
|
func (s *Server) shuttingDown() bool {
|
||||||
|
return s.inShutdown.isSet()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getDoneChan() <-chan struct{} {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.getDoneChanLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getDoneChanLocked() chan struct{} {
|
||||||
|
if s.doneChan == nil {
|
||||||
|
s.doneChan = make(chan struct{})
|
||||||
|
}
|
||||||
|
return s.doneChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeDoneChanLocked() {
|
||||||
|
ch := s.getDoneChanLocked()
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
// Already closed. Don't close again.
|
||||||
|
default:
|
||||||
|
// Safe to close here. We're the only closer, guarded
|
||||||
|
// by s.mu.
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// onceCloseListener wraps a net.Listener, protecting it from
|
||||||
|
// multiple Close calls.
|
||||||
|
type onceCloseListener struct {
|
||||||
|
net.Listener
|
||||||
|
once sync.Once
|
||||||
|
closeErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *onceCloseListener) Close() error {
|
||||||
|
oc.once.Do(oc.close)
|
||||||
|
return oc.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
|
||||||
|
|
||||||
|
type atomicBool int32
|
||||||
|
|
||||||
|
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
||||||
|
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
||||||
|
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
||||||
|
|
110
smtpd_test.go
110
smtpd_test.go
|
@ -1436,3 +1436,113 @@ func TestTLSListener(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestShutdown(t *testing.T) {
|
||||||
|
fmt.Println("Starting test")
|
||||||
|
server := &smtpd.Server{
|
||||||
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srvres := make(chan error)
|
||||||
|
go func() {
|
||||||
|
t.Log("Starting server")
|
||||||
|
srvres <- server.Serve(ln)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Connect a client
|
||||||
|
c, err := smtp.Dial(ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Hello("localhost"); err != nil {
|
||||||
|
t.Fatalf("HELO failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// While the client connection is open, shut down the server (without
|
||||||
|
// waiting for it to finish)
|
||||||
|
err = server.Shutdown(false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Shutdown returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that Shutdown() worked by attempting to connect another client
|
||||||
|
_, err = smtp.Dial(ln.Addr().String())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Dial did not fail as expected")
|
||||||
|
}
|
||||||
|
if _, typok := err.(*net.OpError); !typok {
|
||||||
|
t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for shutdown to complete
|
||||||
|
shutres := make(chan error)
|
||||||
|
go func() {
|
||||||
|
t.Log("Waiting for server shutdown to finish")
|
||||||
|
shutres <- server.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Slight delay to ensure Shutdown() blocks
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
|
||||||
|
// Wait() should not have returned yet due to open client conn
|
||||||
|
select {
|
||||||
|
case shuterr := <-shutres:
|
||||||
|
t.Fatalf("Wait() returned early w/ error: %v", shuterr)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now close the client
|
||||||
|
t.Log("Closing client connection")
|
||||||
|
if err := c.Quit(); err != nil {
|
||||||
|
t.Fatalf("QUIT failed: %v", err)
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
|
||||||
|
// Wait for Wait() to return
|
||||||
|
t.Log("Waiting for Wait() to return")
|
||||||
|
select {
|
||||||
|
case shuterr := <-shutres:
|
||||||
|
if shuterr != nil {
|
||||||
|
t.Fatalf("Wait() returned error: %v", shuterr)
|
||||||
|
}
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
t.Fatalf("Timed out waiting for Wait() to return")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for Serve() to return
|
||||||
|
t.Log("Waiting for Serve() to return")
|
||||||
|
select {
|
||||||
|
case srverr := <-srvres:
|
||||||
|
if srverr != smtpd.ErrServerClosed {
|
||||||
|
t.Fatalf("Serve() returned error: %v", srverr)
|
||||||
|
}
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
t.Fatalf("Timed out waiting for Serve() to return")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeFailsIfShutdown(t *testing.T) {
|
||||||
|
server := &smtpd.Server{}
|
||||||
|
err := server.Shutdown(true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Shutdown() failed: %v", err)
|
||||||
|
}
|
||||||
|
err = server.Serve(nil)
|
||||||
|
if err != smtpd.ErrServerClosed {
|
||||||
|
t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitFailsIfNotShutdown(t *testing.T) {
|
||||||
|
server := &smtpd.Server{}
|
||||||
|
err := server.Wait()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Wait() did not fail as expected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue