diff --git a/protocol.go b/protocol.go index ab5c217..bf783a3 100644 --- a/protocol.go +++ b/protocol.go @@ -38,6 +38,10 @@ func (session *session) handle(line string) { cmd := parseLine(line) + // Commands are dispatched to the appropriate handler functions. + // If a network error occurs during handling, the handler should + // just return and let the error be handled on the next read. + switch cmd.action { case "HELO": @@ -95,7 +99,7 @@ func (session *session) handleHELO(cmd command) { if session.peer.HeloName != "" { // Reset envelope in case of duplicate HELO - session.envelope = nil + session.reset() } if session.server.HeloChecker != nil { @@ -122,7 +126,7 @@ func (session *session) handleEHLO(cmd command) { if session.peer.HeloName != "" { // Reset envelope in case of duplicate EHLO - session.envelope = nil + session.reset() } if session.server.HeloChecker != nil { @@ -197,7 +201,7 @@ func (session *session) handleRCPT(cmd command) { } if len(session.envelope.Recipients) >= session.server.MaxRecipients { - session.reply(550, "Too many recipients") + session.reply(452, "Too many recipients") return } @@ -243,7 +247,7 @@ func (session *session) handleSTARTTLS(cmd command) { } // Reset envelope as a new EHLO/HELO is required after STARTTLS - session.envelope = nil + session.reset() // Reset deadlines on the underlying connection before I replace it // with a TLS connection @@ -291,7 +295,7 @@ func (session *session) handleDATA(cmd command) { session.reply(250, "Thank you.") } - session.envelope = nil + session.reset() } @@ -313,14 +317,14 @@ func (session *session) handleDATA(cmd command) { session.server.MaxMessageSize, )) - session.envelope = nil + session.reset() return } func (session *session) handleRSET(cmd command) { - session.envelope = nil + session.reset() session.reply(250, "Go ahead") return } diff --git a/smtpd.go b/smtpd.go index ac6585d..490a546 100644 --- a/smtpd.go +++ b/smtpd.go @@ -214,8 +214,31 @@ func (session *session) serve() { session.welcome() - for session.scanner.Scan() { - session.handle(session.scanner.Text()) + for { + + for session.scanner.Scan() { + session.handle(session.scanner.Text()) + } + + err := session.scanner.Err() + + if err == bufio.ErrTooLong { + + session.reply(500, "Line too long") + + // Advance reader to the next newline + + session.reader.ReadString('\n') + session.scanner = bufio.NewScanner(session.reader) + + // Reset and have the client start over. + + session.reset() + + continue + } + + break } } @@ -225,6 +248,10 @@ func (session *session) reject() { session.close() } +func (session *session) reset() { + session.envelope = nil +} + func (session *session) welcome() { if session.server.ConnectionChecker != nil { diff --git a/smtpd_test.go b/smtpd_test.go index f13b200..a8b47a6 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -1028,3 +1028,33 @@ func TestTLSTimeout(t *testing.T) { } } + +func TestLongLine(t *testing.T) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + defer ln.Close() + + server := &smtpd.Server{} + + go func() { + server.Serve(ln) + }() + + c, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail(fmt.Sprintf("%s@example.org", strings.Repeat("x", 65*1024))); err == nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +}