Implement graceful shutdown (#11)
* Implement graceful shutdown This borrows heavily from the implementation of net/http Server.Shutdown(). Closes #10 * Add Server.Address() function to return listener address of a Server This is useful for logging, etc. * Refactor out Server.Wait() and add 'wait' param to Shutdown() This allows a consumer which runs multiple Servers (decke/smtprelay) to first shut down all servers (passing wait=false) so no new connections are accepted, and then wait for outstanding client connections to finish.
This commit is contained in:
parent
32be721d71
commit
b5f17a69f6
2 changed files with 239 additions and 10 deletions
139
smtpd.go
139
smtpd.go
|
@ -4,10 +4,13 @@ package smtpd
|
|||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -49,6 +52,14 @@ type Server struct {
|
|||
ForceTLS bool // Force STARTTLS usage.
|
||||
|
||||
ProtocolLogger *log.Logger
|
||||
|
||||
// mu guards doneChan and makes closing it and listener atomic from
|
||||
// perspective of Serve()
|
||||
mu sync.Mutex
|
||||
doneChan chan struct{}
|
||||
listener *net.Listener
|
||||
waitgrp sync.WaitGroup
|
||||
inShutdown atomicBool // true when server is in shutdown
|
||||
}
|
||||
|
||||
// Protocol represents the protocol used in the SMTP session
|
||||
|
@ -82,6 +93,10 @@ type Error struct {
|
|||
// Error returns a string representation of the SMTP error
|
||||
func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) }
|
||||
|
||||
// ErrServerClosed is returned by the Server's Serve and ListenAndServe,
|
||||
// methods after a call to Shutdown.
|
||||
var ErrServerClosed = errors.New("smtp: Server closed")
|
||||
|
||||
type session struct {
|
||||
server *Server
|
||||
|
||||
|
@ -134,6 +149,9 @@ func (srv *Server) newSession(c net.Conn) (s *session) {
|
|||
|
||||
// ListenAndServe starts the SMTP server and listens on the address provided
|
||||
func (srv *Server) ListenAndServe(addr string) error {
|
||||
if srv.shuttingDown() {
|
||||
return ErrServerClosed
|
||||
}
|
||||
|
||||
srv.configureDefaults()
|
||||
|
||||
|
@ -147,23 +165,31 @@ func (srv *Server) ListenAndServe(addr string) error {
|
|||
|
||||
// Serve starts the SMTP server and listens on the Listener provided
|
||||
func (srv *Server) Serve(l net.Listener) error {
|
||||
if srv.shuttingDown() {
|
||||
return ErrServerClosed
|
||||
}
|
||||
|
||||
srv.configureDefaults()
|
||||
|
||||
l = &onceCloseListener{Listener: l}
|
||||
defer l.Close()
|
||||
srv.listener = &l
|
||||
|
||||
var limiter chan struct{}
|
||||
|
||||
if srv.MaxConnections > 0 {
|
||||
limiter = make(chan struct{}, srv.MaxConnections)
|
||||
} else {
|
||||
limiter = nil
|
||||
}
|
||||
|
||||
for {
|
||||
|
||||
conn, e := l.Accept()
|
||||
if e != nil {
|
||||
select {
|
||||
case <-srv.getDoneChan():
|
||||
return ErrServerClosed
|
||||
default:
|
||||
}
|
||||
|
||||
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
|
@ -173,8 +199,10 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||
|
||||
session := srv.newSession(conn)
|
||||
|
||||
if limiter != nil {
|
||||
go func() {
|
||||
srv.waitgrp.Add(1)
|
||||
go func() {
|
||||
defer srv.waitgrp.Done()
|
||||
if limiter != nil {
|
||||
select {
|
||||
case limiter <- struct{}{}:
|
||||
session.serve()
|
||||
|
@ -182,15 +210,53 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||
default:
|
||||
session.reject()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
go session.serve()
|
||||
}
|
||||
|
||||
} else {
|
||||
session.serve()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Shutdown instructs the server to shutdown, starting by closing the
|
||||
// associated listener. If wait is true, it will wait for the shutdown
|
||||
// to complete. If wait is false, Wait must be called afterwards.
|
||||
func (srv *Server) Shutdown(wait bool) error {
|
||||
var lnerr error
|
||||
srv.inShutdown.setTrue()
|
||||
|
||||
// First close the listener
|
||||
srv.mu.Lock()
|
||||
if srv.listener != nil {
|
||||
lnerr = (*srv.listener).Close();
|
||||
}
|
||||
srv.closeDoneChanLocked()
|
||||
srv.mu.Unlock()
|
||||
|
||||
// Now wait for all client connections to close
|
||||
if wait {
|
||||
srv.Wait()
|
||||
}
|
||||
|
||||
return lnerr
|
||||
}
|
||||
|
||||
// Wait waits for all client connections to close and the server to finish
|
||||
// shutting down.
|
||||
func (srv *Server) Wait() error {
|
||||
if !srv.shuttingDown() {
|
||||
return errors.New("Server has not been Shutdown")
|
||||
}
|
||||
|
||||
srv.waitgrp.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Address returns the listening address of the server
|
||||
func (srv *Server) Address() net.Addr {
|
||||
return (*srv.listener).Addr();
|
||||
}
|
||||
|
||||
func (srv *Server) configureDefaults() {
|
||||
|
||||
if srv.MaxMessageSize == 0 {
|
||||
|
@ -366,3 +432,56 @@ func (session *session) close() {
|
|||
time.Sleep(200 * time.Millisecond)
|
||||
session.conn.Close()
|
||||
}
|
||||
|
||||
|
||||
// From net/http/server.go
|
||||
|
||||
func (s *Server) shuttingDown() bool {
|
||||
return s.inShutdown.isSet()
|
||||
}
|
||||
|
||||
func (s *Server) getDoneChan() <-chan struct{} {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.getDoneChanLocked()
|
||||
}
|
||||
|
||||
func (s *Server) getDoneChanLocked() chan struct{} {
|
||||
if s.doneChan == nil {
|
||||
s.doneChan = make(chan struct{})
|
||||
}
|
||||
return s.doneChan
|
||||
}
|
||||
|
||||
func (s *Server) closeDoneChanLocked() {
|
||||
ch := s.getDoneChanLocked()
|
||||
select {
|
||||
case <-ch:
|
||||
// Already closed. Don't close again.
|
||||
default:
|
||||
// Safe to close here. We're the only closer, guarded
|
||||
// by s.mu.
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// onceCloseListener wraps a net.Listener, protecting it from
|
||||
// multiple Close calls.
|
||||
type onceCloseListener struct {
|
||||
net.Listener
|
||||
once sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (oc *onceCloseListener) Close() error {
|
||||
oc.once.Do(oc.close)
|
||||
return oc.closeErr
|
||||
}
|
||||
|
||||
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
||||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
||||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
||||
|
|
110
smtpd_test.go
110
smtpd_test.go
|
@ -1436,3 +1436,113 @@ func TestTLSListener(t *testing.T) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
fmt.Println("Starting test")
|
||||
server := &smtpd.Server{
|
||||
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen failed: %v", err)
|
||||
}
|
||||
|
||||
srvres := make(chan error)
|
||||
go func() {
|
||||
t.Log("Starting server")
|
||||
srvres <- server.Serve(ln)
|
||||
}()
|
||||
|
||||
// Connect a client
|
||||
c, err := smtp.Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial failed: %v", err)
|
||||
}
|
||||
|
||||
if err := c.Hello("localhost"); err != nil {
|
||||
t.Fatalf("HELO failed: %v", err)
|
||||
}
|
||||
|
||||
// While the client connection is open, shut down the server (without
|
||||
// waiting for it to finish)
|
||||
err = server.Shutdown(false)
|
||||
if err != nil {
|
||||
t.Fatalf("Shutdown returned error: %v", err)
|
||||
}
|
||||
|
||||
// Verify that Shutdown() worked by attempting to connect another client
|
||||
_, err = smtp.Dial(ln.Addr().String())
|
||||
if err == nil {
|
||||
t.Fatalf("Dial did not fail as expected")
|
||||
}
|
||||
if _, typok := err.(*net.OpError); !typok {
|
||||
t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err)
|
||||
}
|
||||
|
||||
// Wait for shutdown to complete
|
||||
shutres := make(chan error)
|
||||
go func() {
|
||||
t.Log("Waiting for server shutdown to finish")
|
||||
shutres <- server.Wait()
|
||||
}()
|
||||
|
||||
// Slight delay to ensure Shutdown() blocks
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Wait() should not have returned yet due to open client conn
|
||||
select {
|
||||
case shuterr := <-shutres:
|
||||
t.Fatalf("Wait() returned early w/ error: %v", shuterr)
|
||||
default:
|
||||
}
|
||||
|
||||
// Now close the client
|
||||
t.Log("Closing client connection")
|
||||
if err := c.Quit(); err != nil {
|
||||
t.Fatalf("QUIT failed: %v", err)
|
||||
}
|
||||
c.Close()
|
||||
|
||||
// Wait for Wait() to return
|
||||
t.Log("Waiting for Wait() to return")
|
||||
select {
|
||||
case shuterr := <-shutres:
|
||||
if shuterr != nil {
|
||||
t.Fatalf("Wait() returned error: %v", shuterr)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatalf("Timed out waiting for Wait() to return")
|
||||
}
|
||||
|
||||
// Wait for Serve() to return
|
||||
t.Log("Waiting for Serve() to return")
|
||||
select {
|
||||
case srverr := <-srvres:
|
||||
if srverr != smtpd.ErrServerClosed {
|
||||
t.Fatalf("Serve() returned error: %v", srverr)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatalf("Timed out waiting for Serve() to return")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeFailsIfShutdown(t *testing.T) {
|
||||
server := &smtpd.Server{}
|
||||
err := server.Shutdown(true)
|
||||
if err != nil {
|
||||
t.Fatalf("Shutdown() failed: %v", err)
|
||||
}
|
||||
err = server.Serve(nil)
|
||||
if err != smtpd.ErrServerClosed {
|
||||
t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitFailsIfNotShutdown(t *testing.T) {
|
||||
server := &smtpd.Server{}
|
||||
err := server.Wait()
|
||||
if err == nil {
|
||||
t.Fatalf("Wait() did not fail as expected")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue