diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9f11b75 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1201253 --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ +test: check + +check: lint + go test -v + +lint: + gofmt -w=true -s=true -l=true ./ + golint ./... + go vet ./... + + + diff --git a/envelope.go b/envelope.go index 0fa4fe1..55d72c4 100644 --- a/envelope.go +++ b/envelope.go @@ -15,7 +15,7 @@ type Envelope struct { } // AddReceivedLine prepends a Received header to the Data -func (env *Envelope) AddReceivedLine(peer Peer) { +func (env *Envelope) AddReceivedLine(peer *Peer) { tlsDetails := "" @@ -66,3 +66,16 @@ func (env *Envelope) AddReceivedLine(peer Peer) { copy(env.Data, line) } + +// AddHeader adds header, it should be called before AddReceivedLine, since it adds +// header to the top +func (env *Envelope) AddHeader(name, value string) { + line := wrap([]byte(fmt.Sprintf("%s: %s\r\n", name, value))) + + env.Data = append(env.Data, line...) + + // Move the new newly added header line up front + + copy(env.Data[len(line):], env.Data[0:len(env.Data)-len(line)]) + copy(env.Data, line) +} diff --git a/example_test.go b/example_test.go index a61241e..24ac8f5 100644 --- a/example_test.go +++ b/example_test.go @@ -2,6 +2,8 @@ package smtpd_test import ( "errors" + "fmt" + "log" "net/smtp" "strings" @@ -18,14 +20,34 @@ func ExampleServer() { // Relay server. Accepts only from single IP address and forwards using the Gmail smtp server = &smtpd.Server{ - HeloChecker: func(peer smtpd.Peer, name string) error { + HeloChecker: func(peer *smtpd.Peer, name string) error { if !strings.HasPrefix(peer.Addr.String(), "42.42.42.42:") { - return errors.New("Denied") + return errors.New("denied") } return nil }, - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + RecipientChecker: func(peer *smtpd.Peer, addr string) error { + peer.IncrInt("recipientsQueued", 1) + if strings.HasPrefix("scubad1ver", addr) { + // it is bad idea ;-) + peer.SetString("never_send_to_scuba", addr) + } + return nil + }, + + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { + numberOfRecepients, ok := peer.Meta["recipientsQueued"] + if ok { + log.Printf("Sending email for %v recipients...", numberOfRecepients) + } + never, ok := peer.Meta["never_send_to_scuba"] + if ok { + return smtpd.Error{ + Code: 521, + Message: fmt.Sprintf("Sending messages to %s is a bad idea", never), + } + } return smtp.SendMail( "smtp.gmail.com:587", diff --git a/go.mod b/go.mod index a883ba1..972bce1 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/chrj/smtpd -go 1.14 +go 1.19 diff --git a/protocol.go b/protocol.go index a047f02..16e1f3b 100644 --- a/protocol.go +++ b/protocol.go @@ -61,7 +61,8 @@ func parseLine(line string) (cmd command) { } func (session *session) handle(line string) { - + session.mutex.Lock() + defer session.mutex.Unlock() cmd := parseLine(line) // Commands are dispatched to the appropriate handler functions. @@ -470,7 +471,7 @@ func (session *session) handleAUTH(cmd command) { encodedUsername := "" if len(cmd.fields) < 3 { - session.reply(334, "VXNlcm5hbWU6") + session.reply(334, "VXNlcm5hbWU6") // `Username:` if !session.scanner.Scan() { return } @@ -486,7 +487,7 @@ func (session *session) handleAUTH(cmd command) { return } - session.reply(334, "UGFzc3dvcmQ6") + session.reply(334, "UGFzc3dvcmQ6") // `Password:` if !session.scanner.Scan() { return @@ -626,12 +627,13 @@ func (session *session) handleXCLIENT(cmd command) { } func (session *session) handlePROXY(cmd command) { - + session.logf("Proxy command: %s", cmd.line) + // http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt + // Example: `PROXY TCP4 8.8.8.8 127.0.0.1 443 25` if !session.server.EnableProxyProtocol { session.reply(550, "Proxy Protocol not enabled") return } - if len(cmd.fields) < 6 { session.reply(502, "Couldn't decode the command.") return @@ -642,7 +644,6 @@ func (session *session) handlePROXY(cmd command) { newTCPPort uint64 = 0 err error ) - newAddr = net.ParseIP(cmd.fields[2]) newTCPPort, err = strconv.ParseUint(cmd.fields[4], 10, 16) @@ -664,6 +665,9 @@ func (session *session) handlePROXY(cmd command) { if newTCPPort != 0 { tcpAddr.Port = int(newTCPPort) } + session.logf("Proxy processed: new address - %s:%v", + tcpAddr.IP, tcpAddr.Port, + ) session.welcome() diff --git a/smtpd.go b/smtpd.go index 852ff30..a232a25 100644 --- a/smtpd.go +++ b/smtpd.go @@ -30,20 +30,20 @@ type Server struct { // New e-mails are handed off to this function. // Can be left empty for a NOOP server. // If an error is returned, it will be reported in the SMTP session. - Handler func(peer Peer, env Envelope) error + Handler func(peer *Peer, env Envelope) error // Enable various checks during the SMTP session. // Can be left empty for no restrictions. // If an error is returned, it will be reported in the SMTP session. // Use the Error struct for access to error codes. - ConnectionChecker func(peer Peer) error // Called upon new connection. - HeloChecker func(peer Peer, name string) error // Called after HELO/EHLO. - SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM. - RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO. + ConnectionChecker func(peer *Peer) error // Called upon new connection. + HeloChecker func(peer *Peer, name string) error // Called after HELO/EHLO. + SenderChecker func(peer *Peer, addr string) error // Called after MAIL FROM. + RecipientChecker func(peer *Peer, addr string) error // Called after each RCPT TO. // Enable PLAIN/LOGIN authentication, only available after STARTTLS. // Can be left empty for no authentication support. - Authenticator func(peer Peer, username, password string) error + Authenticator func(peer *Peer, username, password string) error EnableXCLIENT bool // Enable XCLIENT support (default: false) EnableProxyProtocol bool // Enable proxy protocol support (default: false) @@ -55,10 +55,10 @@ type Server struct { // mu guards doneChan and makes closing it and listener atomic from // perspective of Serve() - mu sync.Mutex - doneChan chan struct{} - listener *net.Listener - waitgrp sync.WaitGroup + mu sync.Mutex + doneChan chan struct{} + listener *net.Listener + waitgrp sync.WaitGroup inShutdown atomicBool // true when server is in shutdown } @@ -66,10 +66,11 @@ type Server struct { type Protocol string const ( - // SMTP + // SMTP means Simple Mail Transfer Protocol SMTP Protocol = "SMTP" - // Extended SMTP + // ESMTP means Extended Simple Mail Transfer Protocol, because it has some extra features + // Simple Mail Transfer Protocol doesn't have ESMTP = "ESMTP" ) @@ -82,6 +83,53 @@ type Peer struct { ServerName string // A copy of Server.Hostname Addr net.Addr // Network address TLS *tls.ConnectionState // TLS Connection details, if on TLS + Meta map[string]interface{} +} + +// IncrInt increments integer metadata by positive or negative delta provided +func (p *Peer) IncrInt(name string, delta int64) (err error) { + if len(p.Meta) == 0 { + p.Meta = make(map[string]interface{}, 0) + } + val, found := p.Meta[name] + if !found { + p.Meta[name] = delta + return nil + } + switch val.(type) { + case int64: + p.Meta[name] = val.(int64) + delta + default: + return fmt.Errorf("meta %s has wrong type instead of int64", name) + } + return nil +} + +// IncrFloat64 increments integer metadata by positive or negative delta provided +func (p *Peer) IncrFloat64(name string, delta float64) (err error) { + if len(p.Meta) == 0 { + p.Meta = make(map[string]interface{}, 0) + } + val, found := p.Meta[name] + if !found { + p.Meta[name] = delta + return nil + } + switch val.(type) { + case float64: + p.Meta[name] = val.(float64) + delta + default: + return fmt.Errorf("meta %s has wrong type instead of float64", name) + } + return nil +} + +// SetString sets string parameter in meta +func (p *Peer) SetString(name, value string) { + if len(p.Meta) == 0 { + p.Meta = make(map[string]interface{}, 0) + } + p.Meta[name] = value } // Error represents an Error reported in the SMTP session. @@ -94,13 +142,13 @@ type Error struct { func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } // ErrServerClosed is returned by the Server's Serve and ListenAndServe, -// methods after a call to Shutdown. +// methods after a call to shut down. var ErrServerClosed = errors.New("smtp: Server closed") type session struct { server *Server - peer Peer + peer *Peer envelope *Envelope conn net.Conn @@ -110,6 +158,8 @@ type session struct { scanner *bufio.Scanner tls bool + // mutex is used to guard editing peer's Metadata + mutex sync.Mutex } func (srv *Server) newSession(c net.Conn) (s *session) { @@ -119,14 +169,15 @@ func (srv *Server) newSession(c net.Conn) (s *session) { conn: c, reader: bufio.NewReader(c), writer: bufio.NewWriter(c), - peer: Peer{ + peer: &Peer{ Addr: c.RemoteAddr(), ServerName: srv.Hostname, + Meta: make(map[string]interface{}, 0), }, } // Check if the underlying connection is already TLS. - // This will happen if the Listerner provided Serve() + // This will happen if the Listener provided Serve() // is from tls.Listen() var tlsConn *tls.Conn @@ -218,7 +269,7 @@ func (srv *Server) Serve(l net.Listener) error { } -// Shutdown instructs the server to shutdown, starting by closing the +// Shutdown instructs the server to shut down, starting by closing the // associated listener. If wait is true, it will wait for the shutdown // to complete. If wait is false, Wait must be called afterwards. func (srv *Server) Shutdown(wait bool) error { @@ -228,7 +279,7 @@ func (srv *Server) Shutdown(wait bool) error { // First close the listener srv.mu.Lock() if srv.listener != nil { - lnerr = (*srv.listener).Close(); + lnerr = (*srv.listener).Close() } srv.closeDoneChanLocked() srv.mu.Unlock() @@ -254,7 +305,7 @@ func (srv *Server) Wait() error { // Address returns the listening address of the server func (srv *Server) Address() net.Addr { - return (*srv.listener).Addr(); + return (*srv.listener).Addr() } func (srv *Server) configureDefaults() { @@ -433,28 +484,27 @@ func (session *session) close() { session.conn.Close() } - // From net/http/server.go -func (s *Server) shuttingDown() bool { - return s.inShutdown.isSet() +func (srv *Server) shuttingDown() bool { + return srv.inShutdown.isSet() } -func (s *Server) getDoneChan() <-chan struct{} { - s.mu.Lock() - defer s.mu.Unlock() - return s.getDoneChanLocked() +func (srv *Server) getDoneChan() <-chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.getDoneChanLocked() } -func (s *Server) getDoneChanLocked() chan struct{} { - if s.doneChan == nil { - s.doneChan = make(chan struct{}) +func (srv *Server) getDoneChanLocked() chan struct{} { + if srv.doneChan == nil { + srv.doneChan = make(chan struct{}) } - return s.doneChan + return srv.doneChan } -func (s *Server) closeDoneChanLocked() { - ch := s.getDoneChanLocked() +func (srv *Server) closeDoneChanLocked() { + ch := srv.getDoneChanLocked() select { case <-ch: // Already closed. Don't close again. diff --git a/smtpd_test.go b/smtpd_test.go index 3f8488a..6b946b0 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "net/mail" "net/smtp" "net/textproto" "os" @@ -242,6 +243,12 @@ func TestListenAndServe(t *testing.T) { time.Sleep(100 * time.Millisecond) + if server.Address().String() != addr { + t.Errorf("server is listening on `%s` instead of `%s", + server.Address(), addr, + ) + } + c, err := smtp.Dial(addr) if err != nil { t.Fatalf("Dial failed: %v", err) @@ -256,7 +263,7 @@ func TestListenAndServe(t *testing.T) { func TestSTARTTLS(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + Authenticator: func(peer *smtpd.Peer, username, password string) error { return nil }, ForceTLS: true, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), }) @@ -343,7 +350,7 @@ func TestSTARTTLS(t *testing.T) { func TestAuthRejection(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { + Authenticator: func(peer *smtpd.Peer, username, password string) error { return smtpd.Error{Code: 550, Message: "Denied"} }, ForceTLS: true, @@ -394,11 +401,11 @@ func TestAuthNotSupported(t *testing.T) { func TestAuthBypass(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { + Authenticator: func(peer *smtpd.Peer, username, password string) error { return smtpd.Error{Code: 550, Message: "Denied"} }, - ForceTLS: true, - ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), + ForceTLS: true, + ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), }) defer closer() @@ -421,7 +428,7 @@ func TestAuthBypass(t *testing.T) { func TestConnectionCheck(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - ConnectionChecker: func(peer smtpd.Peer) error { + ConnectionChecker: func(peer *smtpd.Peer) error { return smtpd.Error{Code: 552, Message: "Denied"} }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -438,7 +445,7 @@ func TestConnectionCheck(t *testing.T) { func TestConnectionCheckSimpleError(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - ConnectionChecker: func(peer smtpd.Peer) error { + ConnectionChecker: func(peer *smtpd.Peer) error { return errors.New("Denied") }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -455,7 +462,7 @@ func TestConnectionCheckSimpleError(t *testing.T) { func TestHELOCheck(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - HeloChecker: func(peer smtpd.Peer, name string) error { + HeloChecker: func(peer *smtpd.Peer, name string) error { if name != "foobar.local" { t.Fatal("Wrong HELO name") } @@ -480,7 +487,7 @@ func TestHELOCheck(t *testing.T) { func TestSenderCheck(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - SenderChecker: func(peer smtpd.Peer, addr string) error { + SenderChecker: func(peer *smtpd.Peer, addr string) error { return smtpd.Error{Code: 552, Message: "Denied"} }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -502,7 +509,7 @@ func TestSenderCheck(t *testing.T) { func TestRecipientCheck(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - RecipientChecker: func(peer smtpd.Peer, addr string) error { + RecipientChecker: func(peer *smtpd.Peer, addr string) error { return smtpd.Error{Code: 552, Message: "Denied"} }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -571,7 +578,7 @@ func TestMaxMessageSize(t *testing.T) { func TestHandler(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { if env.Sender != "sender@example.org" { t.Fatalf("Unknown sender: %v", env.Sender) } @@ -628,7 +635,7 @@ func TestHandler(t *testing.T) { func TestRejectHandler(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { return smtpd.Error{Code: 550, Message: "Rejected"} }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -851,7 +858,7 @@ func TestDATAbeforeRCPT(t *testing.T) { func TestInterruptedDATA(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { t.Fatal("Accepted DATA despite disconnection") return nil }, @@ -887,6 +894,74 @@ func TestInterruptedDATA(t *testing.T) { } +func TestMeta(t *testing.T) { + addr, closer := runserver(t, &smtpd.Server{ + MaxConnections: 1, + HeloChecker: func(peer *smtpd.Peer, name string) (err error) { + peer.SetString("something", name) + err = peer.IncrInt("int64", 1) + if err != nil { + return + } + err = peer.IncrFloat64("float64", 1.1) + return err + }, + SenderChecker: func(peer *smtpd.Peer, addr string) (err error) { + if peer.Meta["something"] != "localhost" { + t.Errorf("wrong meta something %s instead of `localhost`", + peer.Meta["something"]) + } + if peer.Meta["int64"] != int64(1) { + t.Errorf("wrong meta int64 %v instead of `1`", peer.Meta["int64"]) + } + if peer.Meta["float64"] != float64(1.1) { + t.Errorf("wrong meta int64 %v instead of `1`", peer.Meta["int64"]) + } + err = peer.IncrInt("int64", 1) + if err != nil { + return + } + err = peer.IncrFloat64("float64", 1.1) + return err + }, + RecipientChecker: func(peer *smtpd.Peer, addr string) error { + return smtpd.Error{ + Code: 451, + Message: fmt.Sprintf("%v %v %s", + peer.Meta["int64"], + peer.Meta["float64"], + peer.Meta["something"], + ), + } + }, + ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), + }) + + defer closer() + cm, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + err = cm.Hello("localhost") + if err != nil { + t.Error(err) + } + err = cm.Mail("somebody@localhost") + if err != nil { + t.Error(err) + } + err = cm.Rcpt("scuba@example.org") + if err != nil { + if err.Error() != "451 2 2.2 localhost" { + t.Errorf("wrong error `%s` instead `451 2 2.2 localhost`", err) + } + } + err = cm.Close() + if err != nil { + t.Error(err) + } +} + func TestTimeoutClose(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ @@ -997,7 +1072,7 @@ func TestXCLIENT(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ EnableXCLIENT: true, - SenderChecker: func(peer smtpd.Peer, addr string) error { + SenderChecker: func(peer *smtpd.Peer, addr string) error { if peer.HeloName != "new.example.net" { t.Fatalf("Didn't override HELO name: %v", peer.HeloName) } @@ -1068,7 +1143,7 @@ func TestEnvelopeReceived(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ Hostname: "foobar.example.net", - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { env.AddReceivedLine(peer) if !bytes.HasPrefix(env.Data, []byte("Received: from localhost ([127.0.0.1]) by foobar.example.net with ESMTP;")) { t.Fatal("Wrong received line.") @@ -1119,6 +1194,142 @@ func TestEnvelopeReceived(t *testing.T) { } +func TestExtraHeader(t *testing.T) { + + addr, closer := runsslserver(t, &smtpd.Server{ + Hostname: "foobar.example.net", + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { + env.AddHeader("Something", "interesting") + if !bytes.HasPrefix(env.Data, []byte("Something: interesting")) { + t.Fatal("Wrong extra header line.") + } + return nil + }, + ForceTLS: true, + ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + +func TestTwoExtraHeadersMakeMessageParsable(t *testing.T) { + + addr, closer := runsslserver(t, &smtpd.Server{ + Hostname: "foobar.example.net", + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) (err error) { + env.AddHeader("Something1", "interesting 1") + env.AddHeader("Something2", "interesting 2") + env.AddReceivedLine(peer) + if !bytes.HasPrefix(env.Data, []byte("Received: from localhost ([127.0.0.1]) by foobar.example.net with ESMTP;")) { + t.Fatal("Wrong received line.") + } + msg, err := mail.ReadMessage(bytes.NewReader(env.Data)) + if err != nil { + t.Errorf("%s : while parsing email message", err) + return err + } + if msg.Header.Get("Something1") != "interesting 1" { + t.Errorf("Header Something is wrong: `%s` instead of `interesting 1`", + msg.Header.Get("Something1")) + } + if msg.Header.Get("Something2") != "interesting 2" { + t.Errorf("Header Something is wrong: `%s` instead of `interesting 1`", + msg.Header.Get("Something1")) + } + return + }, + ForceTLS: true, + ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + body := ` +Date: Sun, 11 Jun 2023 19:49:29 +0300 +To: scuba@vodolaz095.ru +From: scuba@vodolaz095.ru +Subject: test Sun, 11 Jun 2023 19:49:29 +0300 +Message-Id: <20230611194929.017435@localhost> +X-Mailer: swaks v20190914.0 jetmore.org/john/code/swaks/ + +This is a test mailing +` + + _, err = fmt.Fprintf(wc, body) + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + func TestHELO(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ @@ -1157,7 +1368,7 @@ func TestHELO(t *testing.T) { func TestLOGINAuth(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + Authenticator: func(peer *smtpd.Peer, username, password string) error { return nil }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), }) @@ -1272,7 +1483,7 @@ func TestErrors(t *testing.T) { } server := &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + Authenticator: func(peer *smtpd.Peer, username, password string) error { return nil }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), } @@ -1350,7 +1561,7 @@ func TestErrors(t *testing.T) { func TestMailformedMAILFROM(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - SenderChecker: func(peer smtpd.Peer, addr string) error { + SenderChecker: func(peer *smtpd.Peer, addr string) error { if addr != "test@example.org" { return smtpd.Error{Code: 502, Message: "Denied"} } @@ -1379,6 +1590,35 @@ func TestMailformedMAILFROM(t *testing.T) { } } +func TestProxyNotEnabled(t *testing.T) { + addr, closer := runserver(t, &smtpd.Server{ + EnableProxyProtocol: false, // important + ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), + }) + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + where := strings.Split(addr, ":") + err = cmd(c.Text, 550, "PROXY TCP4 8.8.8.8 %s 443 %s", where[0], where[1]) + if err != nil { + t.Fatalf("sending proxy command enabled from the box - %s", err) + } + + err = c.Hello("nobody.example.org") + if err != nil { + t.Fatalf("sending helo command failed with %s", err) + } + + err = c.Quit() + if err != nil { + t.Fatalf("sending quit command failed with %s", err) + } +} + func TestTLSListener(t *testing.T) { cert, err := tls.X509KeyPair(localhostCert, localhostKey) @@ -1396,7 +1636,7 @@ func TestTLSListener(t *testing.T) { addr := ln.Addr().String() server := &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { + Authenticator: func(peer *smtpd.Peer, username, password string) error { if peer.TLS == nil { t.Error("didn't correctly set connection state on TLS connection") }