Implement graceful shutdown
This borrows heavily from the implementation of net/http Server.Shutdown(). Closes #10
This commit is contained in:
parent
32be721d71
commit
aee434025e
2 changed files with 194 additions and 10 deletions
112
smtpd.go
112
smtpd.go
|
@ -4,10 +4,13 @@ package smtpd
|
|||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -49,6 +52,14 @@ type Server struct {
|
|||
ForceTLS bool // Force STARTTLS usage.
|
||||
|
||||
ProtocolLogger *log.Logger
|
||||
|
||||
// mu guards doneChan and makes closing it and listener atomic from
|
||||
// perspective of Serve()
|
||||
mu sync.Mutex
|
||||
doneChan chan struct{}
|
||||
listener *net.Listener
|
||||
waitgrp sync.WaitGroup
|
||||
inShutdown atomicBool // true when server is in shutdown
|
||||
}
|
||||
|
||||
// Protocol represents the protocol used in the SMTP session
|
||||
|
@ -82,6 +93,10 @@ type Error struct {
|
|||
// Error returns a string representation of the SMTP error
|
||||
func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) }
|
||||
|
||||
// ErrServerClosed is returned by the Server's Serve and ListenAndServe,
|
||||
// methods after a call to Shutdown.
|
||||
var ErrServerClosed = errors.New("smtp: Server closed")
|
||||
|
||||
type session struct {
|
||||
server *Server
|
||||
|
||||
|
@ -134,6 +149,9 @@ func (srv *Server) newSession(c net.Conn) (s *session) {
|
|||
|
||||
// ListenAndServe starts the SMTP server and listens on the address provided
|
||||
func (srv *Server) ListenAndServe(addr string) error {
|
||||
if srv.shuttingDown() {
|
||||
return ErrServerClosed
|
||||
}
|
||||
|
||||
srv.configureDefaults()
|
||||
|
||||
|
@ -147,23 +165,31 @@ func (srv *Server) ListenAndServe(addr string) error {
|
|||
|
||||
// Serve starts the SMTP server and listens on the Listener provided
|
||||
func (srv *Server) Serve(l net.Listener) error {
|
||||
if srv.shuttingDown() {
|
||||
return ErrServerClosed
|
||||
}
|
||||
|
||||
srv.configureDefaults()
|
||||
|
||||
l = &onceCloseListener{Listener: l}
|
||||
defer l.Close()
|
||||
srv.listener = &l
|
||||
|
||||
var limiter chan struct{}
|
||||
|
||||
if srv.MaxConnections > 0 {
|
||||
limiter = make(chan struct{}, srv.MaxConnections)
|
||||
} else {
|
||||
limiter = nil
|
||||
}
|
||||
|
||||
for {
|
||||
|
||||
conn, e := l.Accept()
|
||||
if e != nil {
|
||||
select {
|
||||
case <-srv.getDoneChan():
|
||||
return ErrServerClosed
|
||||
default:
|
||||
}
|
||||
|
||||
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
|
@ -173,8 +199,10 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||
|
||||
session := srv.newSession(conn)
|
||||
|
||||
if limiter != nil {
|
||||
srv.waitgrp.Add(1)
|
||||
go func() {
|
||||
defer srv.waitgrp.Done()
|
||||
if limiter != nil {
|
||||
select {
|
||||
case limiter <- struct{}{}:
|
||||
session.serve()
|
||||
|
@ -182,13 +210,30 @@ func (srv *Server) Serve(l net.Listener) error {
|
|||
default:
|
||||
session.reject()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
go session.serve()
|
||||
session.serve()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) Shutdown() error {
|
||||
var lnerr error
|
||||
srv.inShutdown.setTrue()
|
||||
|
||||
// First close the listener
|
||||
srv.mu.Lock()
|
||||
if srv.listener != nil {
|
||||
lnerr = (*srv.listener).Close();
|
||||
}
|
||||
srv.closeDoneChanLocked()
|
||||
srv.mu.Unlock()
|
||||
|
||||
// Now wait for all client connections to close
|
||||
srv.waitgrp.Wait()
|
||||
|
||||
return lnerr
|
||||
}
|
||||
|
||||
func (srv *Server) configureDefaults() {
|
||||
|
@ -366,3 +411,56 @@ func (session *session) close() {
|
|||
time.Sleep(200 * time.Millisecond)
|
||||
session.conn.Close()
|
||||
}
|
||||
|
||||
|
||||
// From net/http/server.go
|
||||
|
||||
func (s *Server) shuttingDown() bool {
|
||||
return s.inShutdown.isSet()
|
||||
}
|
||||
|
||||
func (s *Server) getDoneChan() <-chan struct{} {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.getDoneChanLocked()
|
||||
}
|
||||
|
||||
func (s *Server) getDoneChanLocked() chan struct{} {
|
||||
if s.doneChan == nil {
|
||||
s.doneChan = make(chan struct{})
|
||||
}
|
||||
return s.doneChan
|
||||
}
|
||||
|
||||
func (s *Server) closeDoneChanLocked() {
|
||||
ch := s.getDoneChanLocked()
|
||||
select {
|
||||
case <-ch:
|
||||
// Already closed. Don't close again.
|
||||
default:
|
||||
// Safe to close here. We're the only closer, guarded
|
||||
// by s.mu.
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// onceCloseListener wraps a net.Listener, protecting it from
|
||||
// multiple Close calls.
|
||||
type onceCloseListener struct {
|
||||
net.Listener
|
||||
once sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (oc *onceCloseListener) Close() error {
|
||||
oc.once.Do(oc.close)
|
||||
return oc.closeErr
|
||||
}
|
||||
|
||||
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
||||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
||||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
||||
|
|
|
@ -1436,3 +1436,89 @@ func TestTLSListener(t *testing.T) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
fmt.Println("Starting test")
|
||||
server := &smtpd.Server{
|
||||
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen failed: %v", err)
|
||||
}
|
||||
|
||||
srvres := make(chan error)
|
||||
go func() {
|
||||
t.Log("Starting server")
|
||||
srvres <- server.Serve(ln)
|
||||
}()
|
||||
|
||||
// Connect a client
|
||||
c, err := smtp.Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial failed: %v", err)
|
||||
}
|
||||
|
||||
if err := c.Hello("localhost"); err != nil {
|
||||
t.Fatalf("HELO failed: %v", err)
|
||||
}
|
||||
|
||||
// While the client connection is open, shut down the server
|
||||
shutres := make(chan error)
|
||||
go func() {
|
||||
t.Log("Waiting for server shutdown")
|
||||
shutres <- server.Shutdown()
|
||||
}()
|
||||
|
||||
// Slight delay to ensure Shutdown() blocks
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Shutdown() should not have returned yet due to open client conn
|
||||
select {
|
||||
case shuterr := <-shutres:
|
||||
t.Fatalf("Shutdown() returned early w/ error: %v", shuterr)
|
||||
default:
|
||||
}
|
||||
|
||||
// Now close the client
|
||||
t.Log("Closing client connection")
|
||||
if err := c.Quit(); err != nil {
|
||||
t.Fatalf("QUIT failed: %v", err)
|
||||
}
|
||||
c.Close()
|
||||
|
||||
// Wait for Shutdown() to return
|
||||
t.Log("Waiting for Shutdown() to return")
|
||||
select {
|
||||
case shuterr := <-shutres:
|
||||
if shuterr != nil {
|
||||
t.Fatalf("Shutdown() returned error: %v", shuterr)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatalf("Timed out waiting for Shutdown() to return")
|
||||
}
|
||||
|
||||
// Wait for Serve() to return
|
||||
t.Log("Waiting for Serve() to return")
|
||||
select {
|
||||
case srverr := <-srvres:
|
||||
if srverr != smtpd.ErrServerClosed {
|
||||
t.Fatalf("Serve() returned error: %v", srverr)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatalf("Timed out waiting for Serve() to return")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeFailsIfShutdown(t *testing.T) {
|
||||
server := &smtpd.Server{}
|
||||
err := server.Shutdown()
|
||||
if err != nil {
|
||||
t.Fatalf("Shutdown() failed: %v", err)
|
||||
}
|
||||
err = server.Serve(nil)
|
||||
if err != smtpd.ErrServerClosed {
|
||||
t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue