Different fixes.
Handle empty e-mail addresses. Implemented MaxRecipients check. Announce PIPELINING support. Corrected error code on Too Busy error. Implemented DataTimeout for timeouts on DATA. Use textproto.DotReader in handleDATA to implement dot-stuffing. Reset deadlines on old socket in STARTTLS, add new deadlines to new TLS socket.
This commit is contained in:
parent
3b49f5e70d
commit
a0c514f12d
5 changed files with 226 additions and 39 deletions
|
@ -6,8 +6,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseAddress(src string) (string, error) {
|
func parseAddress(src string) (string, error) {
|
||||||
if src[0] != '<' || src[len(src)-1] != '>' || strings.Count(src, "@") != 1 {
|
|
||||||
|
if src[0] != '<' || src[len(src)-1] != '>' {
|
||||||
return "", fmt.Errorf("Ill-formatted e-mail address: %s", src)
|
return "", fmt.Errorf("Ill-formatted e-mail address: %s", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.Count(src, "@") > 1 {
|
||||||
|
return "", fmt.Errorf("Ill-formatted e-mail address: %s", src)
|
||||||
|
}
|
||||||
|
|
||||||
return src[1 : len(src)-1], nil
|
return src[1 : len(src)-1], nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@ func ExampleServer() {
|
||||||
},
|
},
|
||||||
|
|
||||||
Handler: func(peer smtpd.Peer, env smtpd.Envelope) error {
|
Handler: func(peer smtpd.Peer, env smtpd.Envelope) error {
|
||||||
|
|
||||||
return smtp.SendMail(
|
return smtp.SendMail(
|
||||||
"smtp.gmail.com:587",
|
"smtp.gmail.com:587",
|
||||||
smtp.PlainAuth(
|
smtp.PlainAuth(
|
||||||
|
@ -39,6 +40,7 @@ func ExampleServer() {
|
||||||
env.Recipients,
|
env.Recipients,
|
||||||
env.Data,
|
env.Data,
|
||||||
)
|
)
|
||||||
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
77
protocol.go
77
protocol.go
|
@ -6,7 +6,11 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/textproto"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type command struct {
|
type command struct {
|
||||||
|
@ -178,6 +182,11 @@ func (session *session) handleRCPT(cmd command) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(session.envelope.Recipients) >= session.server.MaxRecipients {
|
||||||
|
session.reply(550, "Too many recipients")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
addr, err := parseAddress(cmd.params[1])
|
addr, err := parseAddress(cmd.params[1])
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -219,12 +228,23 @@ func (session *session) handleSTARTTLS(cmd command) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset HeloName as a new EHLO/HELO is required after STARTTLS
|
||||||
|
session.peer.HeloName = ""
|
||||||
|
|
||||||
|
// Reset deadlines on the underlying connection before I replace it
|
||||||
|
// with a TLS connection
|
||||||
|
session.conn.SetDeadline(time.Time{})
|
||||||
|
|
||||||
|
// Replace connection with a TLS connection
|
||||||
session.conn = tlsConn
|
session.conn = tlsConn
|
||||||
session.reader = bufio.NewReader(tlsConn)
|
session.reader = bufio.NewReader(tlsConn)
|
||||||
session.writer = bufio.NewWriter(tlsConn)
|
session.writer = bufio.NewWriter(tlsConn)
|
||||||
session.scanner = bufio.NewScanner(session.reader)
|
session.scanner = bufio.NewScanner(session.reader)
|
||||||
session.tls = true
|
session.tls = true
|
||||||
|
|
||||||
|
// Flush the connection to set new timeout deadlines
|
||||||
|
session.flush()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -237,41 +257,20 @@ func (session *session) handleDATA(cmd command) {
|
||||||
}
|
}
|
||||||
|
|
||||||
session.reply(354, "Go ahead. End your data with <CR><LF>.<CR><LF>")
|
session.reply(354, "Go ahead. End your data with <CR><LF>.<CR><LF>")
|
||||||
|
session.conn.SetDeadline(time.Now().Add(session.server.DataTimeout))
|
||||||
|
|
||||||
data := &bytes.Buffer{}
|
data := &bytes.Buffer{}
|
||||||
done := false
|
reader := textproto.NewReader(session.reader).DotReader()
|
||||||
|
|
||||||
for session.scanner.Scan() {
|
_, err := io.CopyN(data, reader, int64(session.server.MaxMessageSize))
|
||||||
|
|
||||||
line := session.scanner.Text()
|
if err == io.EOF {
|
||||||
|
|
||||||
if line == "." {
|
// EOF was reached before MaxMessageSize
|
||||||
done = true
|
// Accept and deliver message
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
data.Write([]byte(line))
|
|
||||||
data.Write([]byte("\r\n"))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
if !done {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.Len() > session.server.MaxMessageSize {
|
|
||||||
session.reply(550, fmt.Sprintf(
|
|
||||||
"Message exceeded max message size of %d bytes",
|
|
||||||
session.server.MaxMessageSize,
|
|
||||||
))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
session.envelope.Data = data.Bytes()
|
session.envelope.Data = data.Bytes()
|
||||||
|
if err := session.deliver(); err != nil {
|
||||||
err := session.deliver()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
session.error(err)
|
session.error(err)
|
||||||
} else {
|
} else {
|
||||||
session.reply(250, "Thank you.")
|
session.reply(250, "Thank you.")
|
||||||
|
@ -279,6 +278,28 @@ func (session *session) handleDATA(cmd command) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Network error, ignore
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discard the rest and report an error.
|
||||||
|
_, err = io.Copy(ioutil.Discard, reader)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Network error, ignore
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session.reply(552, fmt.Sprintf(
|
||||||
|
"Message exceeded max message size of %d bytes",
|
||||||
|
session.server.MaxMessageSize,
|
||||||
|
))
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (session *session) handleRSET(cmd command) {
|
func (session *session) handleRSET(cmd command) {
|
||||||
session.envelope = nil
|
session.envelope = nil
|
||||||
session.reply(250, "Go ahead")
|
session.reply(250, "Go ahead")
|
||||||
|
|
21
smtpd.go
21
smtpd.go
|
@ -18,9 +18,11 @@ type Server struct {
|
||||||
|
|
||||||
ReadTimeout time.Duration // Socket timeout for read operations. (default: 60s)
|
ReadTimeout time.Duration // Socket timeout for read operations. (default: 60s)
|
||||||
WriteTimeout time.Duration // Socket timeout for write operations. (default: 60s)
|
WriteTimeout time.Duration // Socket timeout for write operations. (default: 60s)
|
||||||
|
DataTimeout time.Duration // Socket timeout for DATA command (default: 5m)
|
||||||
|
|
||||||
MaxMessageSize int // Max message size in bytes. (default: 10240000)
|
|
||||||
MaxConnections int // Max concurrent connections, use -1 to disable. (default: 100)
|
MaxConnections int // Max concurrent connections, use -1 to disable. (default: 100)
|
||||||
|
MaxMessageSize int // Max message size in bytes. (default: 10240000)
|
||||||
|
MaxRecipients int // Max RCPT TO calls for each envelope. (default: 100)
|
||||||
|
|
||||||
// New e-mails are handed off to this function.
|
// New e-mails are handed off to this function.
|
||||||
// Can be left empty for a NOOP server.
|
// Can be left empty for a NOOP server.
|
||||||
|
@ -168,6 +170,10 @@ func (srv *Server) configureDefaults() {
|
||||||
srv.MaxConnections = 100
|
srv.MaxConnections = 100
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if srv.MaxRecipients == 0 {
|
||||||
|
srv.MaxRecipients = 100
|
||||||
|
}
|
||||||
|
|
||||||
if srv.ReadTimeout == 0 {
|
if srv.ReadTimeout == 0 {
|
||||||
srv.ReadTimeout = time.Second * 60
|
srv.ReadTimeout = time.Second * 60
|
||||||
}
|
}
|
||||||
|
@ -176,6 +182,10 @@ func (srv *Server) configureDefaults() {
|
||||||
srv.WriteTimeout = time.Second * 60
|
srv.WriteTimeout = time.Second * 60
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if srv.DataTimeout == 0 {
|
||||||
|
srv.DataTimeout = time.Minute * 5
|
||||||
|
}
|
||||||
|
|
||||||
if srv.ForceTLS && srv.TLSConfig == nil {
|
if srv.ForceTLS && srv.TLSConfig == nil {
|
||||||
log.Fatal("Cannot use ForceTLS with no TLSConfig")
|
log.Fatal("Cannot use ForceTLS with no TLSConfig")
|
||||||
}
|
}
|
||||||
|
@ -211,7 +221,7 @@ func (session *session) serve() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *session) reject() {
|
func (session *session) reject() {
|
||||||
session.reply(450, "Too busy. Try again later.")
|
session.reply(421, "Too busy. Try again later.")
|
||||||
session.close()
|
session.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -231,14 +241,14 @@ func (session *session) welcome() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *session) reply(code int, message string) {
|
func (session *session) reply(code int, message string) {
|
||||||
|
|
||||||
fmt.Fprintf(session.writer, "%d %s\r\n", code, message)
|
fmt.Fprintf(session.writer, "%d %s\r\n", code, message)
|
||||||
|
session.flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *session) flush() {
|
||||||
session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout))
|
session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout))
|
||||||
session.writer.Flush()
|
session.writer.Flush()
|
||||||
|
|
||||||
session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout))
|
session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *session) error(err error) {
|
func (session *session) error(err error) {
|
||||||
|
@ -254,6 +264,7 @@ func (session *session) extensions() []string {
|
||||||
extensions := []string{
|
extensions := []string{
|
||||||
fmt.Sprintf("SIZE %d", session.server.MaxMessageSize),
|
fmt.Sprintf("SIZE %d", session.server.MaxMessageSize),
|
||||||
"8BITMIME",
|
"8BITMIME",
|
||||||
|
"PIPELINING",
|
||||||
}
|
}
|
||||||
|
|
||||||
if session.server.TLSConfig != nil && !session.tls {
|
if session.server.TLSConfig != nil && !session.tls {
|
||||||
|
|
155
smtpd_test.go
155
smtpd_test.go
|
@ -533,8 +533,8 @@ func TestHandler(t *testing.T) {
|
||||||
if env.Recipients[0] != "recipient@example.net" {
|
if env.Recipients[0] != "recipient@example.net" {
|
||||||
t.Fatalf("Unknown recipient: %v", env.Recipients[0])
|
t.Fatalf("Unknown recipient: %v", env.Recipients[0])
|
||||||
}
|
}
|
||||||
if string(env.Data) != "This is the email body\r\n" {
|
if string(env.Data) != "This is the email body\n" {
|
||||||
t.Fatalf("Wrong message body: %v", env.Data)
|
t.Fatalf("Wrong message body: %v", string(env.Data))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -686,6 +686,46 @@ func TestNoMaxConnections(t *testing.T) {
|
||||||
c1.Close()
|
c1.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMaxRecipients(t *testing.T) {
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
server := &smtpd.Server{
|
||||||
|
MaxRecipients: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
server.Serve(ln)
|
||||||
|
}()
|
||||||
|
|
||||||
|
c, err := smtp.Dial(ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Mail("sender@example.org"); err != nil {
|
||||||
|
t.Fatalf("MAIL failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Rcpt("recipient@example.net"); err != nil {
|
||||||
|
t.Fatalf("RCPT failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Rcpt("recipient@example.net"); err == nil {
|
||||||
|
t.Fatal("RCPT succeeded despite MaxRecipients = 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.Quit(); err != nil {
|
||||||
|
t.Fatalf("QUIT failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestInvalidHelo(t *testing.T) {
|
func TestInvalidHelo(t *testing.T) {
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
@ -732,7 +772,7 @@ func TestInvalidSender(t *testing.T) {
|
||||||
t.Fatalf("Dial failed: %v", err)
|
t.Fatalf("Dial failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Mail("invalid"); err == nil {
|
if err := c.Mail("invalid@@example.org"); err == nil {
|
||||||
t.Fatal("Unexpected MAIL success")
|
t.Fatal("Unexpected MAIL success")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -762,7 +802,7 @@ func TestInvalidRecipient(t *testing.T) {
|
||||||
t.Fatalf("Mail failed: %v", err)
|
t.Fatalf("Mail failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.Rcpt("invalid"); err == nil {
|
if err := c.Rcpt("invalid@@example.org"); err == nil {
|
||||||
t.Fatal("Unexpected RCPT success")
|
t.Fatal("Unexpected RCPT success")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -878,3 +918,110 @@ func TestInterruptedDATA(t *testing.T) {
|
||||||
c.Close()
|
c.Close()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimeoutClose(t *testing.T) {
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
server := &smtpd.Server{
|
||||||
|
MaxConnections: 1,
|
||||||
|
ReadTimeout: time.Second,
|
||||||
|
WriteTimeout: time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
server.Serve(ln)
|
||||||
|
}()
|
||||||
|
|
||||||
|
c1, err := smtp.Dial(ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second * 2)
|
||||||
|
|
||||||
|
c2, err := smtp.Dial(ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c1.Mail("sender@example.org"); err == nil {
|
||||||
|
t.Fatal("MAIL succeeded despite being timed out.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c2.Mail("sender@example.org"); err != nil {
|
||||||
|
t.Fatalf("MAIL failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c2.Quit(); err != nil {
|
||||||
|
t.Fatalf("Quit failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c2.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSTimeout(t *testing.T) {
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Cert load failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &smtpd.Server{
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
},
|
||||||
|
ReadTimeout: time.Second * 2,
|
||||||
|
WriteTimeout: time.Second * 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
server.Serve(ln)
|
||||||
|
}()
|
||||||
|
|
||||||
|
c, err := smtp.Dial(ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil {
|
||||||
|
t.Fatalf("STARTTLS failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
if err := c.Mail("sender@example.org"); err != nil {
|
||||||
|
t.Fatalf("MAIL failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
if err := c.Rcpt("recipient@example.net"); err != nil {
|
||||||
|
t.Fatalf("RCPT failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
if err := c.Rcpt("recipient@example.net"); err != nil {
|
||||||
|
t.Fatalf("RCPT failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
if err := c.Quit(); err != nil {
|
||||||
|
t.Fatalf("Quit failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue