Refactor out Server.Wait() and add 'wait' param to Shutdown()
This allows a consumer which runs multiple Servers (decke/smtprelay) to first shut down all servers (passing wait=false) so no new connections are accepted, and then wait for outstanding client connections to finish.
This commit is contained in:
parent
9fc41453a7
commit
d9909fe818
2 changed files with 52 additions and 12 deletions
20
smtpd.go
20
smtpd.go
|
@ -218,7 +218,10 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) Shutdown() error {
|
// Shutdown instructs the server to shutdown, starting by closing the
|
||||||
|
// associated listener. If wait is true, it will wait for the shutdown
|
||||||
|
// to complete. If wait is false, Wait must be called afterwards.
|
||||||
|
func (srv *Server) Shutdown(wait bool) error {
|
||||||
var lnerr error
|
var lnerr error
|
||||||
srv.inShutdown.setTrue()
|
srv.inShutdown.setTrue()
|
||||||
|
|
||||||
|
@ -231,11 +234,24 @@ func (srv *Server) Shutdown() error {
|
||||||
srv.mu.Unlock()
|
srv.mu.Unlock()
|
||||||
|
|
||||||
// Now wait for all client connections to close
|
// Now wait for all client connections to close
|
||||||
srv.waitgrp.Wait()
|
if wait {
|
||||||
|
srv.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
return lnerr
|
return lnerr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait waits for all client connections to close and the server to finish
|
||||||
|
// shutting down.
|
||||||
|
func (srv *Server) Wait() error {
|
||||||
|
if !srv.shuttingDown() {
|
||||||
|
return errors.New("Server has not been Shutdown")
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.waitgrp.Wait()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Address returns the listening address of the server
|
// Address returns the listening address of the server
|
||||||
func (srv *Server) Address() net.Addr {
|
func (srv *Server) Address() net.Addr {
|
||||||
return (*srv.listener).Addr();
|
return (*srv.listener).Addr();
|
||||||
|
|
|
@ -1464,20 +1464,36 @@ func TestShutdown(t *testing.T) {
|
||||||
t.Fatalf("HELO failed: %v", err)
|
t.Fatalf("HELO failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// While the client connection is open, shut down the server
|
// While the client connection is open, shut down the server (without
|
||||||
|
// waiting for it to finish)
|
||||||
|
err = server.Shutdown(false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Shutdown returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that Shutdown() worked by attempting to connect another client
|
||||||
|
_, err = smtp.Dial(ln.Addr().String())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Dial did not fail as expected")
|
||||||
|
}
|
||||||
|
if _, typok := err.(*net.OpError); !typok {
|
||||||
|
t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for shutdown to complete
|
||||||
shutres := make(chan error)
|
shutres := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
t.Log("Waiting for server shutdown")
|
t.Log("Waiting for server shutdown to finish")
|
||||||
shutres <- server.Shutdown()
|
shutres <- server.Wait()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Slight delay to ensure Shutdown() blocks
|
// Slight delay to ensure Shutdown() blocks
|
||||||
time.Sleep(250 * time.Millisecond)
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
|
||||||
// Shutdown() should not have returned yet due to open client conn
|
// Wait() should not have returned yet due to open client conn
|
||||||
select {
|
select {
|
||||||
case shuterr := <-shutres:
|
case shuterr := <-shutres:
|
||||||
t.Fatalf("Shutdown() returned early w/ error: %v", shuterr)
|
t.Fatalf("Wait() returned early w/ error: %v", shuterr)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1488,15 +1504,15 @@ func TestShutdown(t *testing.T) {
|
||||||
}
|
}
|
||||||
c.Close()
|
c.Close()
|
||||||
|
|
||||||
// Wait for Shutdown() to return
|
// Wait for Wait() to return
|
||||||
t.Log("Waiting for Shutdown() to return")
|
t.Log("Waiting for Wait() to return")
|
||||||
select {
|
select {
|
||||||
case shuterr := <-shutres:
|
case shuterr := <-shutres:
|
||||||
if shuterr != nil {
|
if shuterr != nil {
|
||||||
t.Fatalf("Shutdown() returned error: %v", shuterr)
|
t.Fatalf("Wait() returned error: %v", shuterr)
|
||||||
}
|
}
|
||||||
case <-time.After(15 * time.Second):
|
case <-time.After(15 * time.Second):
|
||||||
t.Fatalf("Timed out waiting for Shutdown() to return")
|
t.Fatalf("Timed out waiting for Wait() to return")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for Serve() to return
|
// Wait for Serve() to return
|
||||||
|
@ -1513,7 +1529,7 @@ func TestShutdown(t *testing.T) {
|
||||||
|
|
||||||
func TestServeFailsIfShutdown(t *testing.T) {
|
func TestServeFailsIfShutdown(t *testing.T) {
|
||||||
server := &smtpd.Server{}
|
server := &smtpd.Server{}
|
||||||
err := server.Shutdown()
|
err := server.Shutdown(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Shutdown() failed: %v", err)
|
t.Fatalf("Shutdown() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1522,3 +1538,11 @@ func TestServeFailsIfShutdown(t *testing.T) {
|
||||||
t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
|
t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWaitFailsIfNotShutdown(t *testing.T) {
|
||||||
|
server := &smtpd.Server{}
|
||||||
|
err := server.Wait()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Wait() did not fail as expected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue