From b2f59a653eb020676f755b2bd7ca1e18c790cab9 Mon Sep 17 00:00:00 2001 From: Christian Joergensen Date: Tue, 15 Jul 2014 11:16:34 +0200 Subject: [PATCH] Fixes. --- smtpd.go | 56 +++++++++++++++++++++++++++++++++++++++------------ smtpd_test.go | 38 ++++++++++++++++++++++++++++++---- 2 files changed, 77 insertions(+), 17 deletions(-) diff --git a/smtpd.go b/smtpd.go index 200cc21..f002e48 100644 --- a/smtpd.go +++ b/smtpd.go @@ -13,14 +13,14 @@ import ( // Server defines the parameters for running the SMTP server type Server struct { - Addr string // Address to listen on when using ListenAndServe (default: "127.0.0.1:10025") - WelcomeMessage string // Initial server banner (default: " ESMTP ready.") + Addr string // Address to listen on when using ListenAndServe. (default: "127.0.0.1:10025") + WelcomeMessage string // Initial server banner. (default: " ESMTP ready.") - ReadTimeout time.Duration // Socket timeout for read operations (default: 60s) - WriteTimeout time.Duration // Socket timeout for write operations (default: 60s) + ReadTimeout time.Duration // Socket timeout for read operations. (default: 60s) + WriteTimeout time.Duration // Socket timeout for write operations. (default: 60s) - MaxMessageSize int // Max message size in bytes (default: 10240000) - MaxConnections int // Max concurrent connections, use -1 to disable (default: 100) + MaxMessageSize int // Max message size in bytes. (default: 10240000) + MaxConnections int // Max concurrent connections, use -1 to disable. (default: 100) // New e-mails are handed off to this function. // Can be left empty for a NOOP server. @@ -30,16 +30,18 @@ type Server struct { // Enable various checks during the SMTP session. // Can be left empty for no restrictions. // If an error is returned, it will be reported in the SMTP session. - HeloChecker func(peer Peer) error // Called after HELO/EHLO. - SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM. - RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO. + // Use the Error struct for access to error codes. + ConnectionChecker func(peer Peer) error // Called upon new connection. + HeloChecker func(peer Peer) error // Called after HELO/EHLO. + SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM. + RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO. // Enable PLAIN/LOGIN authentication, only available after STARTTLS. // Can be left empty for no authentication support. Authenticator func(peer Peer, username, password string) error - TLSConfig *tls.Config // Enable STARTTLS support - ForceTLS bool // Force STARTTLS usage + TLSConfig *tls.Config // Enable STARTTLS support. + ForceTLS bool // Force STARTTLS usage. } // Peer represents the client connecting to the server @@ -57,6 +59,15 @@ type Envelope struct { Data []byte } +// Error represents an Error reported in the SMTP session. +type Error struct { + Code int // The integer error code + Message string // The error message +} + +// Error returns a string representation of the SMTP error +func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } + type session struct { server *Server @@ -191,7 +202,7 @@ func (session *session) serve() { defer session.close() - session.reply(220, session.server.WelcomeMessage) + session.welcome() for session.scanner.Scan() { session.handle(session.scanner.Text()) @@ -204,6 +215,21 @@ func (session *session) reject() { session.close() } +func (session *session) welcome() { + + if session.server.ConnectionChecker != nil { + err := session.server.ConnectionChecker(session.peer) + if err != nil { + session.error(err) + session.close() + return + } + } + + session.reply(220, session.server.WelcomeMessage) + +} + func (session *session) reply(code int, message string) { fmt.Fprintf(session.writer, "%d %s\r\n", code, message) @@ -216,7 +242,11 @@ func (session *session) reply(code int, message string) { } func (session *session) error(err error) { - session.reply(502, fmt.Sprintf("%s", err)) + if smtpdError, ok := err.(Error); ok { + session.reply(smtpdError.Code, smtpdError.Message) + } else { + session.reply(502, fmt.Sprintf("%s", err)) + } } func (session *session) extensions() []string { diff --git a/smtpd_test.go b/smtpd_test.go index ccc80d5..64d35d4 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -2,7 +2,6 @@ package smtpd import ( "crypto/tls" - "errors" "fmt" "net" "net/smtp" @@ -195,6 +194,31 @@ func TestSTARTTLS(t *testing.T) { } } +func TestConnectionCheck(t *testing.T) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + defer ln.Close() + + server := &Server{ + ConnectionChecker: func(peer Peer) error { + return Error{Code: 552, Message: "Denied"} + }, + } + + go func() { + server.Serve(ln) + }() + + if _, err := smtp.Dial(ln.Addr().String()); err == nil { + t.Fatal("Dial succeeded despite ConnectionCheck") + } + +} + func TestHELOCheck(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -205,7 +229,9 @@ func TestHELOCheck(t *testing.T) { defer ln.Close() server := &Server{ - HeloChecker: func(peer Peer) error { return errors.New("Denied") }, + HeloChecker: func(peer Peer) error { + return Error{Code: 552, Message: "Denied"} + }, } go func() { @@ -233,7 +259,9 @@ func TestSenderCheck(t *testing.T) { defer ln.Close() server := &Server{ - SenderChecker: func(peer Peer, addr string) error { return errors.New("Denied") }, + SenderChecker: func(peer Peer, addr string) error { + return Error{Code: 552, Message: "Denied"} + }, } go func() { @@ -261,7 +289,9 @@ func TestRecipientCheck(t *testing.T) { defer ln.Close() server := &Server{ - RecipientChecker: func(peer Peer, addr string) error { return errors.New("Denied") }, + RecipientChecker: func(peer Peer, addr string) error { + return Error{Code: 552, Message: "Denied"} + }, } go func() {