From 50ed5fd57af436ae70181621ad6725ed46a4d7c1 Mon Sep 17 00:00:00 2001 From: Christian Joergensen Date: Fri, 18 Jul 2014 08:46:39 +0200 Subject: [PATCH] Send the HELO hostname to the HeloChecker. --- .hgignore | 3 +++ example_test.go | 2 +- protocol.go | 11 +++++------ smtpd.go | 2 +- smtpd_test.go | 7 +++++-- 5 files changed, 15 insertions(+), 10 deletions(-) create mode 100644 .hgignore diff --git a/.hgignore b/.hgignore new file mode 100644 index 0000000..758d190 --- /dev/null +++ b/.hgignore @@ -0,0 +1,3 @@ +syntax: glob + +*.orig diff --git a/example_test.go b/example_test.go index a143890..611a79b 100644 --- a/example_test.go +++ b/example_test.go @@ -19,7 +19,7 @@ func ExampleServer() { Addr: "0.0.0.0:10025", - HeloChecker: func(peer smtpd.Peer) error { + HeloChecker: func(peer smtpd.Peer, name string) error { if !strings.HasPrefix(peer.Addr.String(), "42.42.42.42:") { return errors.New("Denied") } diff --git a/protocol.go b/protocol.go index f77512c..ab5c217 100644 --- a/protocol.go +++ b/protocol.go @@ -98,16 +98,15 @@ func (session *session) handleHELO(cmd command) { session.envelope = nil } - session.peer.HeloName = cmd.fields[1] - if session.server.HeloChecker != nil { - err := session.server.HeloChecker(session.peer) + err := session.server.HeloChecker(session.peer, cmd.fields[1]) if err != nil { session.error(err) return } } + session.peer.HeloName = cmd.fields[1] session.reply(250, "Go ahead") return @@ -126,16 +125,16 @@ func (session *session) handleEHLO(cmd command) { session.envelope = nil } - session.peer.HeloName = cmd.fields[1] - if session.server.HeloChecker != nil { - err := session.server.HeloChecker(session.peer) + err := session.server.HeloChecker(session.peer, cmd.fields[1]) if err != nil { session.error(err) return } } + session.peer.HeloName = cmd.fields[1] + extensions := session.extensions() if len(extensions) > 1 { diff --git a/smtpd.go b/smtpd.go index aae2ddb..ac6585d 100644 --- a/smtpd.go +++ b/smtpd.go @@ -34,7 +34,7 @@ type Server struct { // If an error is returned, it will be reported in the SMTP session. // Use the Error struct for access to error codes. ConnectionChecker func(peer Peer) error // Called upon new connection. - HeloChecker func(peer Peer) error // Called after HELO/EHLO. + HeloChecker func(peer Peer, name string) error // Called after HELO/EHLO. SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM. RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO. diff --git a/smtpd_test.go b/smtpd_test.go index 17e0539..f13b200 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -378,7 +378,10 @@ func TestHELOCheck(t *testing.T) { defer ln.Close() server := &smtpd.Server{ - HeloChecker: func(peer smtpd.Peer) error { + HeloChecker: func(peer smtpd.Peer, name string) error { + if name != "foobar.local" { + t.Fatal("Wrong HELO name") + } return smtpd.Error{Code: 552, Message: "Denied"} }, } @@ -392,7 +395,7 @@ func TestHELOCheck(t *testing.T) { t.Fatalf("Dial failed: %v", err) } - if err := c.Hello("localhost"); err == nil { + if err := c.Hello("foobar.local"); err == nil { t.Fatal("Unexpected HELO success") }