From aee434025e22ecf9c8b7f0593db704e91c5f7223 Mon Sep 17 00:00:00 2001 From: Jonathon Reinhart Date: Sun, 14 Mar 2021 22:38:58 -0400 Subject: [PATCH] Implement graceful shutdown This borrows heavily from the implementation of net/http Server.Shutdown(). Closes #10 --- smtpd.go | 118 +++++++++++++++++++++++++++++++++++++++++++++----- smtpd_test.go | 86 ++++++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+), 10 deletions(-) diff --git a/smtpd.go b/smtpd.go index 0dbf38c..ca7ac31 100644 --- a/smtpd.go +++ b/smtpd.go @@ -4,10 +4,13 @@ package smtpd import ( "bufio" "crypto/tls" + "errors" "fmt" "log" "net" "strings" + "sync" + "sync/atomic" "time" ) @@ -49,6 +52,14 @@ type Server struct { ForceTLS bool // Force STARTTLS usage. ProtocolLogger *log.Logger + + // mu guards doneChan and makes closing it and listener atomic from + // perspective of Serve() + mu sync.Mutex + doneChan chan struct{} + listener *net.Listener + waitgrp sync.WaitGroup + inShutdown atomicBool // true when server is in shutdown } // Protocol represents the protocol used in the SMTP session @@ -82,6 +93,10 @@ type Error struct { // Error returns a string representation of the SMTP error func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } +// ErrServerClosed is returned by the Server's Serve and ListenAndServe, +// methods after a call to Shutdown. +var ErrServerClosed = errors.New("smtp: Server closed") + type session struct { server *Server @@ -134,6 +149,9 @@ func (srv *Server) newSession(c net.Conn) (s *session) { // ListenAndServe starts the SMTP server and listens on the address provided func (srv *Server) ListenAndServe(addr string) error { + if srv.shuttingDown() { + return ErrServerClosed + } srv.configureDefaults() @@ -147,23 +165,31 @@ func (srv *Server) ListenAndServe(addr string) error { // Serve starts the SMTP server and listens on the Listener provided func (srv *Server) Serve(l net.Listener) error { + if srv.shuttingDown() { + return ErrServerClosed + } srv.configureDefaults() + l = &onceCloseListener{Listener: l} defer l.Close() + srv.listener = &l var limiter chan struct{} if srv.MaxConnections > 0 { limiter = make(chan struct{}, srv.MaxConnections) - } else { - limiter = nil } for { - conn, e := l.Accept() if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } + if ne, ok := e.(net.Error); ok && ne.Temporary() { time.Sleep(time.Second) continue @@ -173,8 +199,10 @@ func (srv *Server) Serve(l net.Listener) error { session := srv.newSession(conn) - if limiter != nil { - go func() { + srv.waitgrp.Add(1) + go func() { + defer srv.waitgrp.Done() + if limiter != nil { select { case limiter <- struct{}{}: session.serve() @@ -182,15 +210,32 @@ func (srv *Server) Serve(l net.Listener) error { default: session.reject() } - }() - } else { - go session.serve() - } - + } else { + session.serve() + } + }() } } +func (srv *Server) Shutdown() error { + var lnerr error + srv.inShutdown.setTrue() + + // First close the listener + srv.mu.Lock() + if srv.listener != nil { + lnerr = (*srv.listener).Close(); + } + srv.closeDoneChanLocked() + srv.mu.Unlock() + + // Now wait for all client connections to close + srv.waitgrp.Wait() + + return lnerr +} + func (srv *Server) configureDefaults() { if srv.MaxMessageSize == 0 { @@ -366,3 +411,56 @@ func (session *session) close() { time.Sleep(200 * time.Millisecond) session.conn.Close() } + + +// From net/http/server.go + +func (s *Server) shuttingDown() bool { + return s.inShutdown.isSet() +} + +func (s *Server) getDoneChan() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.getDoneChanLocked() +} + +func (s *Server) getDoneChanLocked() chan struct{} { + if s.doneChan == nil { + s.doneChan = make(chan struct{}) + } + return s.doneChan +} + +func (s *Server) closeDoneChanLocked() { + ch := s.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by s.mu. + close(ch) + } +} + +// onceCloseListener wraps a net.Listener, protecting it from +// multiple Close calls. +type onceCloseListener struct { + net.Listener + once sync.Once + closeErr error +} + +func (oc *onceCloseListener) Close() error { + oc.once.Do(oc.close) + return oc.closeErr +} + +func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } + +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } +func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } diff --git a/smtpd_test.go b/smtpd_test.go index dbe8d85..048de5d 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -1436,3 +1436,89 @@ func TestTLSListener(t *testing.T) { } } + +func TestShutdown(t *testing.T) { + fmt.Println("Starting test") + server := &smtpd.Server{ + ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + srvres := make(chan error) + go func() { + t.Log("Starting server") + srvres <- server.Serve(ln) + }() + + // Connect a client + c, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("HELO failed: %v", err) + } + + // While the client connection is open, shut down the server + shutres := make(chan error) + go func() { + t.Log("Waiting for server shutdown") + shutres <- server.Shutdown() + }() + + // Slight delay to ensure Shutdown() blocks + time.Sleep(250 * time.Millisecond) + + // Shutdown() should not have returned yet due to open client conn + select { + case shuterr := <-shutres: + t.Fatalf("Shutdown() returned early w/ error: %v", shuterr) + default: + } + + // Now close the client + t.Log("Closing client connection") + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + c.Close() + + // Wait for Shutdown() to return + t.Log("Waiting for Shutdown() to return") + select { + case shuterr := <-shutres: + if shuterr != nil { + t.Fatalf("Shutdown() returned error: %v", shuterr) + } + case <-time.After(15 * time.Second): + t.Fatalf("Timed out waiting for Shutdown() to return") + } + + // Wait for Serve() to return + t.Log("Waiting for Serve() to return") + select { + case srverr := <-srvres: + if srverr != smtpd.ErrServerClosed { + t.Fatalf("Serve() returned error: %v", srverr) + } + case <-time.After(15 * time.Second): + t.Fatalf("Timed out waiting for Serve() to return") + } +} + +func TestServeFailsIfShutdown(t *testing.T) { + server := &smtpd.Server{} + err := server.Shutdown() + if err != nil { + t.Fatalf("Shutdown() failed: %v", err) + } + err = server.Serve(nil) + if err != smtpd.ErrServerClosed { + t.Fatalf("Serve() did not return ErrServerClosed: %v", err) + } +}