diff --git a/address.go b/address.go index 1ffce35..76b7507 100644 --- a/address.go +++ b/address.go @@ -5,12 +5,9 @@ import ( "strings" ) -// MailAddress holds an e-mail address -type MailAddress string - -func parseMailAddress(src string) (MailAddress, error) { +func parseAddress(src string) (string, error) { if src[0] != '<' || src[len(src)-1] != '>' || strings.Count(src, "@") != 1 { - return MailAddress(""), fmt.Errorf("Ill-formatted e-mail address: %s", src) + return "", fmt.Errorf("Ill-formatted e-mail address: %s", src) } - return MailAddress(src[1 : len(src)-1]), nil + return src[1 : len(src)-1], nil } diff --git a/example_test.go b/example_test.go index 4e784d3..ea2c23e 100644 --- a/example_test.go +++ b/example_test.go @@ -1,32 +1,32 @@ package smtpd import ( - "bitbucket.org/chrj/smtpd" "errors" - "net" "net/smtp" "strings" ) func ExampleServer() { + var server *Server + // No-op server. Accepts and discards - server := &smtpd.Server{} - server.serve() + server = &Server{} + server.ListenAndServe() // Relay server. Accepts only from single IP address and forwards using the Gmail smtp - server := &smtpd.Server{ + server = &Server{ Addr: "0.0.0.0:10025", - HeloChecker: func(peer smtpd.Peer) error { + HeloChecker: func(peer Peer) error { if !strings.HasPrefix(peer.Addr.String(), "42.42.42.42:") { return errors.New("Denied") } return nil }, - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer Peer, env Envelope) error { return smtp.SendMail( "smtp.gmail.com:587", smtp.PlainAuth( @@ -35,14 +35,14 @@ func ExampleServer() { "password", "smtp.gmail.com", ), - env.Sender, - env.Recipients, + string(env.Sender), + []string(env.Recipients), env.Data, ) }, } - server.serve() + server.ListenAndServe() } diff --git a/protocol.go b/protocol.go index 359aa98..1b6304c 100644 --- a/protocol.go +++ b/protocol.go @@ -148,7 +148,7 @@ func (session *session) handleMAIL(cmd command) { return } - addr, err := parseMailAddress(cmd.params[1]) + addr, err := parseAddress(cmd.params[1]) if err != nil { session.reply(502, "Ill-formatted e-mail address") @@ -178,7 +178,7 @@ func (session *session) handleRCPT(cmd command) { return } - addr, err := parseMailAddress(cmd.params[1]) + addr, err := parseAddress(cmd.params[1]) if err != nil { session.reply(502, "Ill-formatted e-mail address") diff --git a/smtpd.go b/smtpd.go index ef463df..85fcdc2 100644 --- a/smtpd.go +++ b/smtpd.go @@ -19,6 +19,9 @@ type Server struct { ReadTimeout time.Duration // Socket timeout for read operations (default: 60s) WriteTimeout time.Duration // Socket timeout for write operations (default: 60s) + MaxMessageSize int // Max message size in bytes (default: 10240000) + MaxConnections int // Max concurrent connections, use -1 to disable (default: 100) + // New e-mails are handed off to this function. // Can be left empty for a NOOP server. // If an error is returned, it will be reported in the SMTP session. @@ -28,8 +31,8 @@ type Server struct { // Can be left empty for no restrictions. // If an error is returned, it will be reported in the SMTP session. HeloChecker func(peer Peer) error // Called after HELO/EHLO. - SenderChecker func(peer Peer, addr MailAddress) error // Called after MAIL FROM. - RecipientChecker func(peer Peer, addr MailAddress) error // Called after each RCPT TO. + SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM. + RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO. // Enable PLAIN/LOGIN authentication, only available after STARTTLS. // Can be left empty for no authentication support. @@ -37,8 +40,6 @@ type Server struct { TLSConfig *tls.Config // Enable STARTTLS support ForceTLS bool // Force STARTTLS usage - - MaxMessageSize int // Max message size in bytes (default: 10240000) } // Peer represents the client connecting to the server @@ -51,8 +52,8 @@ type Peer struct { // Envelope holds a message type Envelope struct { - Sender MailAddress - Recipients []MailAddress + Sender string + Recipients []string Data []byte } @@ -107,6 +108,14 @@ func (srv *Server) Serve(l net.Listener) error { defer l.Close() + var limiter chan struct{} + + if srv.MaxConnections > 0 { + limiter = make(chan struct{}, srv.MaxConnections) + } else { + limiter = nil + } + for { conn, e := l.Accept() @@ -123,7 +132,19 @@ func (srv *Server) Serve(l net.Listener) error { continue } - go session.serve() + if limiter != nil { + go func() { + select { + case limiter <- struct{}{}: + session.serve() + <-limiter + default: + session.reject() + } + }() + } else { + go session.serve() + } } @@ -135,6 +156,10 @@ func (srv *Server) configureDefaults() { srv.MaxMessageSize = 10240000 } + if srv.MaxConnections == 0 { + srv.MaxConnections = 100 + } + if srv.ReadTimeout == 0 { srv.ReadTimeout = time.Second * 60 } @@ -177,6 +202,11 @@ func (session *session) serve() { } +func (session *session) reject() { + session.reply(450, "Too busy. Try again later.") + session.close() +} + func (session *session) reply(code int, message string) { fmt.Fprintf(session.writer, "%d %s\r\n", code, message) diff --git a/smtpd_test.go b/smtpd_test.go index 51b664f..ccc80d5 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -233,7 +233,7 @@ func TestSenderCheck(t *testing.T) { defer ln.Close() server := &Server{ - SenderChecker: func(peer Peer, addr MailAddress) error { return errors.New("Denied") }, + SenderChecker: func(peer Peer, addr string) error { return errors.New("Denied") }, } go func() { @@ -261,7 +261,7 @@ func TestRecipientCheck(t *testing.T) { defer ln.Close() server := &Server{ - RecipientChecker: func(peer Peer, addr MailAddress) error { return errors.New("Denied") }, + RecipientChecker: func(peer Peer, addr string) error { return errors.New("Denied") }, } go func() { @@ -398,3 +398,33 @@ func TestHandler(t *testing.T) { } } + +func TestMaxConnections(t *testing.T) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + defer ln.Close() + + server := &Server{ + MaxConnections: 1, + } + + go func() { + server.Serve(ln) + }() + + c1, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + _, err = smtp.Dial(ln.Addr().String()) + if err == nil { + t.Fatal("Dial succeeded despite MaxConnections = 1") + } + + c1.Close() +}