diff --git a/smtpd.go b/smtpd.go index 516fcea..852ff30 100644 --- a/smtpd.go +++ b/smtpd.go @@ -218,7 +218,10 @@ func (srv *Server) Serve(l net.Listener) error { } -func (srv *Server) Shutdown() error { +// Shutdown instructs the server to shutdown, starting by closing the +// associated listener. If wait is true, it will wait for the shutdown +// to complete. If wait is false, Wait must be called afterwards. +func (srv *Server) Shutdown(wait bool) error { var lnerr error srv.inShutdown.setTrue() @@ -231,11 +234,24 @@ func (srv *Server) Shutdown() error { srv.mu.Unlock() // Now wait for all client connections to close - srv.waitgrp.Wait() + if wait { + srv.Wait() + } return lnerr } +// Wait waits for all client connections to close and the server to finish +// shutting down. +func (srv *Server) Wait() error { + if !srv.shuttingDown() { + return errors.New("Server has not been Shutdown") + } + + srv.waitgrp.Wait() + return nil +} + // Address returns the listening address of the server func (srv *Server) Address() net.Addr { return (*srv.listener).Addr(); diff --git a/smtpd_test.go b/smtpd_test.go index 048de5d..3f8488a 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -1464,20 +1464,36 @@ func TestShutdown(t *testing.T) { t.Fatalf("HELO failed: %v", err) } - // While the client connection is open, shut down the server + // While the client connection is open, shut down the server (without + // waiting for it to finish) + err = server.Shutdown(false) + if err != nil { + t.Fatalf("Shutdown returned error: %v", err) + } + + // Verify that Shutdown() worked by attempting to connect another client + _, err = smtp.Dial(ln.Addr().String()) + if err == nil { + t.Fatalf("Dial did not fail as expected") + } + if _, typok := err.(*net.OpError); !typok { + t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err) + } + + // Wait for shutdown to complete shutres := make(chan error) go func() { - t.Log("Waiting for server shutdown") - shutres <- server.Shutdown() + t.Log("Waiting for server shutdown to finish") + shutres <- server.Wait() }() // Slight delay to ensure Shutdown() blocks time.Sleep(250 * time.Millisecond) - // Shutdown() should not have returned yet due to open client conn + // Wait() should not have returned yet due to open client conn select { case shuterr := <-shutres: - t.Fatalf("Shutdown() returned early w/ error: %v", shuterr) + t.Fatalf("Wait() returned early w/ error: %v", shuterr) default: } @@ -1488,15 +1504,15 @@ func TestShutdown(t *testing.T) { } c.Close() - // Wait for Shutdown() to return - t.Log("Waiting for Shutdown() to return") + // Wait for Wait() to return + t.Log("Waiting for Wait() to return") select { case shuterr := <-shutres: if shuterr != nil { - t.Fatalf("Shutdown() returned error: %v", shuterr) + t.Fatalf("Wait() returned error: %v", shuterr) } case <-time.After(15 * time.Second): - t.Fatalf("Timed out waiting for Shutdown() to return") + t.Fatalf("Timed out waiting for Wait() to return") } // Wait for Serve() to return @@ -1513,7 +1529,7 @@ func TestShutdown(t *testing.T) { func TestServeFailsIfShutdown(t *testing.T) { server := &smtpd.Server{} - err := server.Shutdown() + err := server.Shutdown(true) if err != nil { t.Fatalf("Shutdown() failed: %v", err) } @@ -1522,3 +1538,11 @@ func TestServeFailsIfShutdown(t *testing.T) { t.Fatalf("Serve() did not return ErrServerClosed: %v", err) } } + +func TestWaitFailsIfNotShutdown(t *testing.T) { + server := &smtpd.Server{} + err := server.Wait() + if err == nil { + t.Fatalf("Wait() did not fail as expected") + } +}