Refactor out Server.Wait() and add 'wait' param to Shutdown()

This allows a consumer which runs multiple Servers (decke/smtprelay) to
first shut down all servers (passing wait=false) so no new connections are
accepted, and then wait for outstanding client connections to finish.
This commit is contained in:
Jonathon Reinhart 2021-03-17 23:58:22 -04:00 committed by Jonathon Reinhart
parent 9fc41453a7
commit d9909fe818
2 changed files with 52 additions and 12 deletions

View file

@ -218,7 +218,10 @@ func (srv *Server) Serve(l net.Listener) error {
} }
func (srv *Server) Shutdown() error { // Shutdown instructs the server to shutdown, starting by closing the
// associated listener. If wait is true, it will wait for the shutdown
// to complete. If wait is false, Wait must be called afterwards.
func (srv *Server) Shutdown(wait bool) error {
var lnerr error var lnerr error
srv.inShutdown.setTrue() srv.inShutdown.setTrue()
@ -231,11 +234,24 @@ func (srv *Server) Shutdown() error {
srv.mu.Unlock() srv.mu.Unlock()
// Now wait for all client connections to close // Now wait for all client connections to close
srv.waitgrp.Wait() if wait {
srv.Wait()
}
return lnerr return lnerr
} }
// Wait waits for all client connections to close and the server to finish
// shutting down.
func (srv *Server) Wait() error {
if !srv.shuttingDown() {
return errors.New("Server has not been Shutdown")
}
srv.waitgrp.Wait()
return nil
}
// Address returns the listening address of the server // Address returns the listening address of the server
func (srv *Server) Address() net.Addr { func (srv *Server) Address() net.Addr {
return (*srv.listener).Addr(); return (*srv.listener).Addr();

View file

@ -1464,20 +1464,36 @@ func TestShutdown(t *testing.T) {
t.Fatalf("HELO failed: %v", err) t.Fatalf("HELO failed: %v", err)
} }
// While the client connection is open, shut down the server // While the client connection is open, shut down the server (without
// waiting for it to finish)
err = server.Shutdown(false)
if err != nil {
t.Fatalf("Shutdown returned error: %v", err)
}
// Verify that Shutdown() worked by attempting to connect another client
_, err = smtp.Dial(ln.Addr().String())
if err == nil {
t.Fatalf("Dial did not fail as expected")
}
if _, typok := err.(*net.OpError); !typok {
t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err)
}
// Wait for shutdown to complete
shutres := make(chan error) shutres := make(chan error)
go func() { go func() {
t.Log("Waiting for server shutdown") t.Log("Waiting for server shutdown to finish")
shutres <- server.Shutdown() shutres <- server.Wait()
}() }()
// Slight delay to ensure Shutdown() blocks // Slight delay to ensure Shutdown() blocks
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
// Shutdown() should not have returned yet due to open client conn // Wait() should not have returned yet due to open client conn
select { select {
case shuterr := <-shutres: case shuterr := <-shutres:
t.Fatalf("Shutdown() returned early w/ error: %v", shuterr) t.Fatalf("Wait() returned early w/ error: %v", shuterr)
default: default:
} }
@ -1488,15 +1504,15 @@ func TestShutdown(t *testing.T) {
} }
c.Close() c.Close()
// Wait for Shutdown() to return // Wait for Wait() to return
t.Log("Waiting for Shutdown() to return") t.Log("Waiting for Wait() to return")
select { select {
case shuterr := <-shutres: case shuterr := <-shutres:
if shuterr != nil { if shuterr != nil {
t.Fatalf("Shutdown() returned error: %v", shuterr) t.Fatalf("Wait() returned error: %v", shuterr)
} }
case <-time.After(15 * time.Second): case <-time.After(15 * time.Second):
t.Fatalf("Timed out waiting for Shutdown() to return") t.Fatalf("Timed out waiting for Wait() to return")
} }
// Wait for Serve() to return // Wait for Serve() to return
@ -1513,7 +1529,7 @@ func TestShutdown(t *testing.T) {
func TestServeFailsIfShutdown(t *testing.T) { func TestServeFailsIfShutdown(t *testing.T) {
server := &smtpd.Server{} server := &smtpd.Server{}
err := server.Shutdown() err := server.Shutdown(true)
if err != nil { if err != nil {
t.Fatalf("Shutdown() failed: %v", err) t.Fatalf("Shutdown() failed: %v", err)
} }
@ -1522,3 +1538,11 @@ func TestServeFailsIfShutdown(t *testing.T) {
t.Fatalf("Serve() did not return ErrServerClosed: %v", err) t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
} }
} }
func TestWaitFailsIfNotShutdown(t *testing.T) {
server := &smtpd.Server{}
err := server.Wait()
if err == nil {
t.Fatalf("Wait() did not fail as expected")
}
}