diff --git a/smtpd.go b/smtpd.go index 0dbf38c..852ff30 100644 --- a/smtpd.go +++ b/smtpd.go @@ -4,10 +4,13 @@ package smtpd import ( "bufio" "crypto/tls" + "errors" "fmt" "log" "net" "strings" + "sync" + "sync/atomic" "time" ) @@ -49,6 +52,14 @@ type Server struct { ForceTLS bool // Force STARTTLS usage. ProtocolLogger *log.Logger + + // mu guards doneChan and makes closing it and listener atomic from + // perspective of Serve() + mu sync.Mutex + doneChan chan struct{} + listener *net.Listener + waitgrp sync.WaitGroup + inShutdown atomicBool // true when server is in shutdown } // Protocol represents the protocol used in the SMTP session @@ -82,6 +93,10 @@ type Error struct { // Error returns a string representation of the SMTP error func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } +// ErrServerClosed is returned by the Server's Serve and ListenAndServe, +// methods after a call to Shutdown. +var ErrServerClosed = errors.New("smtp: Server closed") + type session struct { server *Server @@ -134,6 +149,9 @@ func (srv *Server) newSession(c net.Conn) (s *session) { // ListenAndServe starts the SMTP server and listens on the address provided func (srv *Server) ListenAndServe(addr string) error { + if srv.shuttingDown() { + return ErrServerClosed + } srv.configureDefaults() @@ -147,23 +165,31 @@ func (srv *Server) ListenAndServe(addr string) error { // Serve starts the SMTP server and listens on the Listener provided func (srv *Server) Serve(l net.Listener) error { + if srv.shuttingDown() { + return ErrServerClosed + } srv.configureDefaults() + l = &onceCloseListener{Listener: l} defer l.Close() + srv.listener = &l var limiter chan struct{} if srv.MaxConnections > 0 { limiter = make(chan struct{}, srv.MaxConnections) - } else { - limiter = nil } for { - conn, e := l.Accept() if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } + if ne, ok := e.(net.Error); ok && ne.Temporary() { time.Sleep(time.Second) continue @@ -173,8 +199,10 @@ func (srv *Server) Serve(l net.Listener) error { session := srv.newSession(conn) - if limiter != nil { - go func() { + srv.waitgrp.Add(1) + go func() { + defer srv.waitgrp.Done() + if limiter != nil { select { case limiter <- struct{}{}: session.serve() @@ -182,15 +210,53 @@ func (srv *Server) Serve(l net.Listener) error { default: session.reject() } - }() - } else { - go session.serve() - } - + } else { + session.serve() + } + }() } } +// Shutdown instructs the server to shutdown, starting by closing the +// associated listener. If wait is true, it will wait for the shutdown +// to complete. If wait is false, Wait must be called afterwards. +func (srv *Server) Shutdown(wait bool) error { + var lnerr error + srv.inShutdown.setTrue() + + // First close the listener + srv.mu.Lock() + if srv.listener != nil { + lnerr = (*srv.listener).Close(); + } + srv.closeDoneChanLocked() + srv.mu.Unlock() + + // Now wait for all client connections to close + if wait { + srv.Wait() + } + + return lnerr +} + +// Wait waits for all client connections to close and the server to finish +// shutting down. +func (srv *Server) Wait() error { + if !srv.shuttingDown() { + return errors.New("Server has not been Shutdown") + } + + srv.waitgrp.Wait() + return nil +} + +// Address returns the listening address of the server +func (srv *Server) Address() net.Addr { + return (*srv.listener).Addr(); +} + func (srv *Server) configureDefaults() { if srv.MaxMessageSize == 0 { @@ -366,3 +432,56 @@ func (session *session) close() { time.Sleep(200 * time.Millisecond) session.conn.Close() } + + +// From net/http/server.go + +func (s *Server) shuttingDown() bool { + return s.inShutdown.isSet() +} + +func (s *Server) getDoneChan() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.getDoneChanLocked() +} + +func (s *Server) getDoneChanLocked() chan struct{} { + if s.doneChan == nil { + s.doneChan = make(chan struct{}) + } + return s.doneChan +} + +func (s *Server) closeDoneChanLocked() { + ch := s.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by s.mu. + close(ch) + } +} + +// onceCloseListener wraps a net.Listener, protecting it from +// multiple Close calls. +type onceCloseListener struct { + net.Listener + once sync.Once + closeErr error +} + +func (oc *onceCloseListener) Close() error { + oc.once.Do(oc.close) + return oc.closeErr +} + +func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } + +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } +func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } diff --git a/smtpd_test.go b/smtpd_test.go index dbe8d85..3f8488a 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -1436,3 +1436,113 @@ func TestTLSListener(t *testing.T) { } } + +func TestShutdown(t *testing.T) { + fmt.Println("Starting test") + server := &smtpd.Server{ + ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + srvres := make(chan error) + go func() { + t.Log("Starting server") + srvres <- server.Serve(ln) + }() + + // Connect a client + c, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("HELO failed: %v", err) + } + + // While the client connection is open, shut down the server (without + // waiting for it to finish) + err = server.Shutdown(false) + if err != nil { + t.Fatalf("Shutdown returned error: %v", err) + } + + // Verify that Shutdown() worked by attempting to connect another client + _, err = smtp.Dial(ln.Addr().String()) + if err == nil { + t.Fatalf("Dial did not fail as expected") + } + if _, typok := err.(*net.OpError); !typok { + t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err) + } + + // Wait for shutdown to complete + shutres := make(chan error) + go func() { + t.Log("Waiting for server shutdown to finish") + shutres <- server.Wait() + }() + + // Slight delay to ensure Shutdown() blocks + time.Sleep(250 * time.Millisecond) + + // Wait() should not have returned yet due to open client conn + select { + case shuterr := <-shutres: + t.Fatalf("Wait() returned early w/ error: %v", shuterr) + default: + } + + // Now close the client + t.Log("Closing client connection") + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + c.Close() + + // Wait for Wait() to return + t.Log("Waiting for Wait() to return") + select { + case shuterr := <-shutres: + if shuterr != nil { + t.Fatalf("Wait() returned error: %v", shuterr) + } + case <-time.After(15 * time.Second): + t.Fatalf("Timed out waiting for Wait() to return") + } + + // Wait for Serve() to return + t.Log("Waiting for Serve() to return") + select { + case srverr := <-srvres: + if srverr != smtpd.ErrServerClosed { + t.Fatalf("Serve() returned error: %v", srverr) + } + case <-time.After(15 * time.Second): + t.Fatalf("Timed out waiting for Serve() to return") + } +} + +func TestServeFailsIfShutdown(t *testing.T) { + server := &smtpd.Server{} + err := server.Shutdown(true) + if err != nil { + t.Fatalf("Shutdown() failed: %v", err) + } + err = server.Serve(nil) + if err != smtpd.ErrServerClosed { + t.Fatalf("Serve() did not return ErrServerClosed: %v", err) + } +} + +func TestWaitFailsIfNotShutdown(t *testing.T) { + server := &smtpd.Server{} + err := server.Wait() + if err == nil { + t.Fatalf("Wait() did not fail as expected") + } +}