Implement graceful shutdown (#11)
* Implement graceful shutdown. This borrows heavily from the implementation of net/http Server.Shutdown(). Closes #10.
* Add Server.Address() function to return the listener address of a Server. This is useful for logging, etc.
* Refactor out Server.Wait() and add a 'wait' param to Shutdown(). This allows a consumer which runs multiple Servers (decke/smtprelay) to first shut down all servers (passing wait=false) so no new connections are accepted, and then wait for outstanding client connections to finish.
This commit is contained in:
		
							parent
							
								
									32be721d71
								
							
						
					
					
						commit
						b5f17a69f6
					
				
					 2 changed files with 239 additions and 10 deletions
				
			
		
							
								
								
									
										131
									
								
								smtpd.go
									
										
									
									
									
								
							
							
						
						
									
										131
									
								
								smtpd.go
									
										
									
									
									
								
							|  | @ -4,10 +4,13 @@ package smtpd | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"log" | 	"log" | ||||||
| 	"net" | 	"net" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -49,6 +52,14 @@ type Server struct { | ||||||
| 	ForceTLS  bool        // Force STARTTLS usage. | 	ForceTLS  bool        // Force STARTTLS usage. | ||||||
| 
 | 
 | ||||||
| 	ProtocolLogger *log.Logger | 	ProtocolLogger *log.Logger | ||||||
|  | 
 | ||||||
|  | 	// mu guards doneChan and makes closing it and listener atomic from | ||||||
|  | 	// perspective of Serve() | ||||||
|  | 	mu sync.Mutex | ||||||
|  | 	doneChan chan struct{} | ||||||
|  | 	listener *net.Listener | ||||||
|  | 	waitgrp sync.WaitGroup | ||||||
|  | 	inShutdown atomicBool // true when server is in shutdown | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Protocol represents the protocol used in the SMTP session | // Protocol represents the protocol used in the SMTP session | ||||||
|  | @ -82,6 +93,10 @@ type Error struct { | ||||||
| // Error returns a string representation of the SMTP error | // Error returns a string representation of the SMTP error | ||||||
| func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } | func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } | ||||||
| 
 | 
 | ||||||
|  | // ErrServerClosed is returned by the Server's Serve and ListenAndServe | ||||||
|  | // methods after a call to Shutdown. | ||||||
|  | var ErrServerClosed = errors.New("smtp: Server closed") | ||||||
|  | 
 | ||||||
| type session struct { | type session struct { | ||||||
| 	server *Server | 	server *Server | ||||||
| 
 | 
 | ||||||
|  | @ -134,6 +149,9 @@ func (srv *Server) newSession(c net.Conn) (s *session) { | ||||||
| 
 | 
 | ||||||
| // ListenAndServe starts the SMTP server and listens on the address provided | // ListenAndServe starts the SMTP server and listens on the address provided | ||||||
| func (srv *Server) ListenAndServe(addr string) error { | func (srv *Server) ListenAndServe(addr string) error { | ||||||
|  | 	if srv.shuttingDown() { | ||||||
|  | 		return ErrServerClosed | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	srv.configureDefaults() | 	srv.configureDefaults() | ||||||
| 
 | 
 | ||||||
|  | @ -147,23 +165,31 @@ func (srv *Server) ListenAndServe(addr string) error { | ||||||
| 
 | 
 | ||||||
| // Serve starts the SMTP server and listens on the Listener provided | // Serve starts the SMTP server and listens on the Listener provided | ||||||
| func (srv *Server) Serve(l net.Listener) error { | func (srv *Server) Serve(l net.Listener) error { | ||||||
|  | 	if srv.shuttingDown() { | ||||||
|  | 		return ErrServerClosed | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	srv.configureDefaults() | 	srv.configureDefaults() | ||||||
| 
 | 
 | ||||||
|  | 	l = &onceCloseListener{Listener: l} | ||||||
| 	defer l.Close() | 	defer l.Close() | ||||||
|  | 	srv.listener = &l | ||||||
| 
 | 
 | ||||||
| 	var limiter chan struct{} | 	var limiter chan struct{} | ||||||
| 
 | 
 | ||||||
| 	if srv.MaxConnections > 0 { | 	if srv.MaxConnections > 0 { | ||||||
| 		limiter = make(chan struct{}, srv.MaxConnections) | 		limiter = make(chan struct{}, srv.MaxConnections) | ||||||
| 	} else { |  | ||||||
| 		limiter = nil |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for { | 	for { | ||||||
| 
 |  | ||||||
| 		conn, e := l.Accept() | 		conn, e := l.Accept() | ||||||
| 		if e != nil { | 		if e != nil { | ||||||
|  | 			select { | ||||||
|  | 			case <-srv.getDoneChan(): | ||||||
|  | 				return ErrServerClosed | ||||||
|  | 			default: | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
| 			if ne, ok := e.(net.Error); ok && ne.Temporary() { | 			if ne, ok := e.(net.Error); ok && ne.Temporary() { | ||||||
| 				time.Sleep(time.Second) | 				time.Sleep(time.Second) | ||||||
| 				continue | 				continue | ||||||
|  | @ -173,8 +199,10 @@ func (srv *Server) Serve(l net.Listener) error { | ||||||
| 
 | 
 | ||||||
| 		session := srv.newSession(conn) | 		session := srv.newSession(conn) | ||||||
| 
 | 
 | ||||||
| 		if limiter != nil { | 		srv.waitgrp.Add(1) | ||||||
| 		go func() { | 		go func() { | ||||||
|  | 			defer srv.waitgrp.Done() | ||||||
|  | 			if limiter != nil { | ||||||
| 				select { | 				select { | ||||||
| 				case limiter <- struct{}{}: | 				case limiter <- struct{}{}: | ||||||
| 					session.serve() | 					session.serve() | ||||||
|  | @ -182,13 +210,51 @@ func (srv *Server) Serve(l net.Listener) error { | ||||||
| 				default: | 				default: | ||||||
| 					session.reject() | 					session.reject() | ||||||
| 				} | 				} | ||||||
| 			}() |  | ||||||
| 			} else { | 			} else { | ||||||
| 			go session.serve() | 				session.serve() | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Shutdown instructs the server to shutdown, starting by closing the | ||||||
|  | // associated listener. If wait is true, it will wait for the shutdown | ||||||
|  | // to complete. If wait is false, Wait must be called afterwards. | ||||||
|  | func (srv *Server) Shutdown(wait bool) error { | ||||||
|  | 	var lnerr error | ||||||
|  | 	srv.inShutdown.setTrue() | ||||||
|  | 
 | ||||||
|  | 	// First close the listener | ||||||
|  | 	srv.mu.Lock() | ||||||
|  | 	if srv.listener != nil { | ||||||
|  | 		lnerr = (*srv.listener).Close(); | ||||||
|  | 	} | ||||||
|  | 	srv.closeDoneChanLocked() | ||||||
|  | 	srv.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	// Now wait for all client connections to close | ||||||
|  | 	if wait { | ||||||
|  | 		srv.Wait() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return lnerr | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Wait waits for all client connections to close and the server to finish | ||||||
|  | // shutting down. | ||||||
|  | func (srv *Server) Wait() error { | ||||||
|  | 	if !srv.shuttingDown() { | ||||||
|  | 		return errors.New("Server has not been Shutdown") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	srv.waitgrp.Wait() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Address returns the listening address of the server | ||||||
|  | func (srv *Server) Address() net.Addr { | ||||||
|  | 	return (*srv.listener).Addr(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (srv *Server) configureDefaults() { | func (srv *Server) configureDefaults() { | ||||||
|  | @ -366,3 +432,56 @@ func (session *session) close() { | ||||||
| 	time.Sleep(200 * time.Millisecond) | 	time.Sleep(200 * time.Millisecond) | ||||||
| 	session.conn.Close() | 	session.conn.Close() | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | // From net/http/server.go | ||||||
|  | 
 | ||||||
|  | func (s *Server) shuttingDown() bool { | ||||||
|  | 	return s.inShutdown.isSet() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) getDoneChan() <-chan struct{} { | ||||||
|  | 	s.mu.Lock() | ||||||
|  | 	defer s.mu.Unlock() | ||||||
|  | 	return s.getDoneChanLocked() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) getDoneChanLocked() chan struct{} { | ||||||
|  | 	if s.doneChan == nil { | ||||||
|  | 		s.doneChan = make(chan struct{}) | ||||||
|  | 	} | ||||||
|  | 	return s.doneChan | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Server) closeDoneChanLocked() { | ||||||
|  | 	ch := s.getDoneChanLocked() | ||||||
|  | 	select { | ||||||
|  | 	case <-ch: | ||||||
|  | 		// Already closed. Don't close again. | ||||||
|  | 	default: | ||||||
|  | 		// Safe to close here. We're the only closer, guarded | ||||||
|  | 		// by s.mu. | ||||||
|  | 		close(ch) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // onceCloseListener wraps a net.Listener, protecting it from | ||||||
|  | // multiple Close calls. | ||||||
|  | type onceCloseListener struct { | ||||||
|  | 	net.Listener | ||||||
|  | 	once     sync.Once | ||||||
|  | 	closeErr error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (oc *onceCloseListener) Close() error { | ||||||
|  | 	oc.once.Do(oc.close) | ||||||
|  | 	return oc.closeErr | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } | ||||||
|  | 
 | ||||||
|  | type atomicBool int32 | ||||||
|  | 
 | ||||||
|  | func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } | ||||||
|  | func (b *atomicBool) setTrue()    { atomic.StoreInt32((*int32)(b), 1) } | ||||||
|  | func (b *atomicBool) setFalse()   { atomic.StoreInt32((*int32)(b), 0) } | ||||||
|  |  | ||||||
							
								
								
									
										110
									
								
								smtpd_test.go
									
										
									
									
									
								
							
							
						
						
									
										110
									
								
								smtpd_test.go
									
										
									
									
									
								
							|  | @ -1436,3 +1436,113 @@ func TestTLSListener(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestShutdown(t *testing.T) { | ||||||
|  | 	fmt.Println("Starting test") | ||||||
|  | 	server := &smtpd.Server{ | ||||||
|  | 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ln, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Listen failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	srvres := make(chan error) | ||||||
|  | 	go func() { | ||||||
|  | 		t.Log("Starting server") | ||||||
|  | 		srvres <- server.Serve(ln) | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	// Connect a client | ||||||
|  | 	c, err := smtp.Dial(ln.Addr().String()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Dial failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := c.Hello("localhost"); err != nil { | ||||||
|  | 		t.Fatalf("HELO failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// While the client connection is open, shut down the server (without | ||||||
|  | 	// waiting for it to finish) | ||||||
|  | 	err = server.Shutdown(false) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Shutdown returned error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Verify that Shutdown() worked by attempting to connect another client | ||||||
|  | 	_, err = smtp.Dial(ln.Addr().String()) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatalf("Dial did not fail as expected") | ||||||
|  | 	} | ||||||
|  | 	if _, typok := err.(*net.OpError); !typok { | ||||||
|  | 		t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Wait for shutdown to complete | ||||||
|  | 	shutres := make(chan error) | ||||||
|  | 	go func() { | ||||||
|  | 		t.Log("Waiting for server shutdown to finish") | ||||||
|  | 		shutres <- server.Wait() | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	// Slight delay to ensure Wait() blocks | ||||||
|  | 	time.Sleep(250 * time.Millisecond) | ||||||
|  | 
 | ||||||
|  | 	// Wait() should not have returned yet due to open client conn | ||||||
|  | 	select { | ||||||
|  | 	case shuterr := <-shutres: | ||||||
|  | 		t.Fatalf("Wait() returned early w/ error: %v", shuterr) | ||||||
|  | 	default: | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Now close the client | ||||||
|  | 	t.Log("Closing client connection") | ||||||
|  | 	if err := c.Quit(); err != nil { | ||||||
|  | 		t.Fatalf("QUIT failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	c.Close() | ||||||
|  | 
 | ||||||
|  | 	// Wait for Wait() to return | ||||||
|  | 	t.Log("Waiting for Wait() to return") | ||||||
|  | 	select { | ||||||
|  | 	case shuterr := <-shutres: | ||||||
|  | 		if shuterr != nil { | ||||||
|  | 			t.Fatalf("Wait() returned error: %v", shuterr) | ||||||
|  | 		} | ||||||
|  | 	case <-time.After(15 * time.Second): | ||||||
|  | 		t.Fatalf("Timed out waiting for Wait() to return") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Wait for Serve() to return | ||||||
|  | 	t.Log("Waiting for Serve() to return") | ||||||
|  | 	select { | ||||||
|  | 	case srverr := <-srvres: | ||||||
|  | 		if srverr != smtpd.ErrServerClosed { | ||||||
|  | 			t.Fatalf("Serve() returned error: %v", srverr) | ||||||
|  | 		} | ||||||
|  | 	case <-time.After(15 * time.Second): | ||||||
|  | 		t.Fatalf("Timed out waiting for Serve() to return") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestServeFailsIfShutdown(t *testing.T) { | ||||||
|  | 	server := &smtpd.Server{} | ||||||
|  | 	err := server.Shutdown(true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Shutdown() failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	err = server.Serve(nil) | ||||||
|  | 	if err != smtpd.ErrServerClosed { | ||||||
|  | 		t.Fatalf("Serve() did not return ErrServerClosed: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestWaitFailsIfNotShutdown(t *testing.T) { | ||||||
|  | 	server := &smtpd.Server{} | ||||||
|  | 	err := server.Wait() | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatalf("Wait() did not fail as expected") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue