diff --git a/envelope.go b/envelope.go index 0fa4fe1..5a98242 100644 --- a/envelope.go +++ b/envelope.go @@ -15,7 +15,7 @@ type Envelope struct { } // AddReceivedLine prepends a Received header to the Data -func (env *Envelope) AddReceivedLine(peer Peer) { +func (env *Envelope) AddReceivedLine(peer *Peer) { tlsDetails := "" diff --git a/example_test.go b/example_test.go index a61241e..24ac8f5 100644 --- a/example_test.go +++ b/example_test.go @@ -2,6 +2,8 @@ package smtpd_test import ( "errors" + "fmt" + "log" "net/smtp" "strings" @@ -18,14 +20,34 @@ func ExampleServer() { // Relay server. Accepts only from single IP address and forwards using the Gmail smtp server = &smtpd.Server{ - HeloChecker: func(peer smtpd.Peer, name string) error { + HeloChecker: func(peer *smtpd.Peer, name string) error { if !strings.HasPrefix(peer.Addr.String(), "42.42.42.42:") { - return errors.New("Denied") + return errors.New("denied") } return nil }, - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + RecipientChecker: func(peer *smtpd.Peer, addr string) error { + peer.IncrInt("recipientsQueued", 1) + if strings.HasPrefix("scubad1ver", addr) { + // it is bad idea ;-) + peer.SetString("never_send_to_scuba", addr) + } + return nil + }, + + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { + numberOfRecepients, ok := peer.Meta["recipientsQueued"] + if ok { + log.Printf("Sending email for %v recipients...", numberOfRecepients) + } + never, ok := peer.Meta["never_send_to_scuba"] + if ok { + return smtpd.Error{ + Code: 521, + Message: fmt.Sprintf("Sending messages to %s is a bad idea", never), + } + } return smtp.SendMail( "smtp.gmail.com:587", diff --git a/smtpd.go b/smtpd.go index 2ee1ca2..51fd057 100644 --- a/smtpd.go +++ b/smtpd.go @@ -30,20 +30,20 @@ type Server struct { // New e-mails are handed off to this function. // Can be left empty for a NOOP server. // If an error is returned, it will be reported in the SMTP session. - Handler func(peer Peer, env Envelope) error + Handler func(peer *Peer, env Envelope) error // Enable various checks during the SMTP session. // Can be left empty for no restrictions. // If an error is returned, it will be reported in the SMTP session. // Use the Error struct for access to error codes. - ConnectionChecker func(peer Peer) error // Called upon new connection. - HeloChecker func(peer Peer, name string) error // Called after HELO/EHLO. - SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM. - RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO. + ConnectionChecker func(peer *Peer) error // Called upon new connection. + HeloChecker func(peer *Peer, name string) error // Called after HELO/EHLO. + SenderChecker func(peer *Peer, addr string) error // Called after MAIL FROM. + RecipientChecker func(peer *Peer, addr string) error // Called after each RCPT TO. // Enable PLAIN/LOGIN authentication, only available after STARTTLS. // Can be left empty for no authentication support. - Authenticator func(peer Peer, username, password string) error + Authenticator func(peer *Peer, username, password string) error EnableXCLIENT bool // Enable XCLIENT support (default: false) EnableProxyProtocol bool // Enable proxy protocol support (default: false) @@ -83,6 +83,53 @@ type Peer struct { ServerName string // A copy of Server.Hostname Addr net.Addr // Network address TLS *tls.ConnectionState // TLS Connection details, if on TLS + Meta map[string]interface{} +} + +// IncrInt increments integer metadata by positive or negative delta provided +func (p *Peer) IncrInt(name string, delta int64) (err error) { + if len(p.Meta) == 0 { + p.Meta = make(map[string]interface{}, 0) + } + val, found := p.Meta[name] + if !found { + p.Meta[name] = delta + return nil + } + switch val.(type) { + case int64: + p.Meta[name] = val.(int64) + delta + default: + return fmt.Errorf("meta %s has wrong type instead of int64", name) + } + return nil +} + +// IncrFloat64 increments integer metadata by positive or negative delta provided +func (p *Peer) IncrFloat64(name string, delta float64) (err error) { + if len(p.Meta) == 0 { + p.Meta = make(map[string]interface{}, 0) + } + val, found := p.Meta[name] + if !found { + p.Meta[name] = delta + return nil + } + switch val.(type) { + case float64: + p.Meta[name] = val.(float64) + delta + default: + return fmt.Errorf("meta %s has wrong type instead of float64", name) + } + return nil +} + +// SetString sets string parameter in meta +func (p *Peer) SetString(name, value string) { + if len(p.Meta) == 0 { + p.Meta = make(map[string]interface{}, 0) + } + p.Meta[name] = value } // Error represents an Error reported in the SMTP session. @@ -95,13 +142,13 @@ type Error struct { func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } // ErrServerClosed is returned by the Server's Serve and ListenAndServe, -// methods after a call to Shutdown. +// methods after a call to shut down. var ErrServerClosed = errors.New("smtp: Server closed") type session struct { server *Server - peer Peer + peer *Peer envelope *Envelope conn net.Conn @@ -120,7 +167,7 @@ func (srv *Server) newSession(c net.Conn) (s *session) { conn: c, reader: bufio.NewReader(c), writer: bufio.NewWriter(c), - peer: Peer{ + peer: &Peer{ Addr: c.RemoteAddr(), ServerName: srv.Hostname, }, @@ -219,7 +266,7 @@ func (srv *Server) Serve(l net.Listener) error { } -// Shutdown instructs the server to shutdown, starting by closing the +// Shutdown instructs the server to shut down, starting by closing the // associated listener. If wait is true, it will wait for the shutdown // to complete. If wait is false, Wait must be called afterwards. func (srv *Server) Shutdown(wait bool) error { diff --git a/smtpd_test.go b/smtpd_test.go index 0913a24..62903a3 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -256,7 +256,7 @@ func TestListenAndServe(t *testing.T) { func TestSTARTTLS(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + Authenticator: func(peer *smtpd.Peer, username, password string) error { return nil }, ForceTLS: true, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), }) @@ -343,7 +343,7 @@ func TestSTARTTLS(t *testing.T) { func TestAuthRejection(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { + Authenticator: func(peer *smtpd.Peer, username, password string) error { return smtpd.Error{Code: 550, Message: "Denied"} }, ForceTLS: true, @@ -394,7 +394,7 @@ func TestAuthNotSupported(t *testing.T) { func TestAuthBypass(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { + Authenticator: func(peer *smtpd.Peer, username, password string) error { return smtpd.Error{Code: 550, Message: "Denied"} }, ForceTLS: true, @@ -421,7 +421,7 @@ func TestAuthBypass(t *testing.T) { func TestConnectionCheck(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - ConnectionChecker: func(peer smtpd.Peer) error { + ConnectionChecker: func(peer *smtpd.Peer) error { return smtpd.Error{Code: 552, Message: "Denied"} }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -438,7 +438,7 @@ func TestConnectionCheck(t *testing.T) { func TestConnectionCheckSimpleError(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - ConnectionChecker: func(peer smtpd.Peer) error { + ConnectionChecker: func(peer *smtpd.Peer) error { return errors.New("Denied") }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -455,7 +455,7 @@ func TestConnectionCheckSimpleError(t *testing.T) { func TestHELOCheck(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - HeloChecker: func(peer smtpd.Peer, name string) error { + HeloChecker: func(peer *smtpd.Peer, name string) error { if name != "foobar.local" { t.Fatal("Wrong HELO name") } @@ -480,7 +480,7 @@ func TestHELOCheck(t *testing.T) { func TestSenderCheck(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - SenderChecker: func(peer smtpd.Peer, addr string) error { + SenderChecker: func(peer *smtpd.Peer, addr string) error { return smtpd.Error{Code: 552, Message: "Denied"} }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -502,7 +502,7 @@ func TestSenderCheck(t *testing.T) { func TestRecipientCheck(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - RecipientChecker: func(peer smtpd.Peer, addr string) error { + RecipientChecker: func(peer *smtpd.Peer, addr string) error { return smtpd.Error{Code: 552, Message: "Denied"} }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -571,7 +571,7 @@ func TestMaxMessageSize(t *testing.T) { func TestHandler(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { if env.Sender != "sender@example.org" { t.Fatalf("Unknown sender: %v", env.Sender) } @@ -628,7 +628,7 @@ func TestHandler(t *testing.T) { func TestRejectHandler(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { return smtpd.Error{Code: 550, Message: "Rejected"} }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), @@ -851,7 +851,7 @@ func TestDATAbeforeRCPT(t *testing.T) { func TestInterruptedDATA(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { t.Fatal("Accepted DATA despite disconnection") return nil }, @@ -887,6 +887,74 @@ func TestInterruptedDATA(t *testing.T) { } +func TestMeta(t *testing.T) { + addr, closer := runserver(t, &smtpd.Server{ + MaxConnections: 1, + HeloChecker: func(peer *smtpd.Peer, name string) (err error) { + peer.SetString("something", name) + err = peer.IncrInt("int64", 1) + if err != nil { + return + } + err = peer.IncrFloat64("float64", 1.1) + return err + }, + SenderChecker: func(peer *smtpd.Peer, addr string) (err error) { + if peer.Meta["something"] != "localhost" { + t.Errorf("wrong meta something %s instead of `localhost`", + peer.Meta["something"]) + } + if peer.Meta["int64"] != int64(1) { + t.Errorf("wrong meta int64 %v instead of `1`", peer.Meta["int64"]) + } + if peer.Meta["float64"] != float64(1.1) { + t.Errorf("wrong meta int64 %v instead of `1`", peer.Meta["int64"]) + } + err = peer.IncrInt("int64", 1) + if err != nil { + return + } + err = peer.IncrFloat64("float64", 1.1) + return err + }, + RecipientChecker: func(peer *smtpd.Peer, addr string) error { + return smtpd.Error{ + Code: 451, + Message: fmt.Sprintf("%v %v %s", + peer.Meta["int64"], + peer.Meta["float64"], + peer.Meta["something"], + ), + } + }, + ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), + }) + + defer closer() + cm, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + err = cm.Hello("localhost") + if err != nil { + t.Error(err) + } + err = cm.Mail("somebody@localhost") + if err != nil { + t.Error(err) + } + err = cm.Rcpt("scuba@example.org") + if err != nil { + if err.Error() != "451 2 2.2 localhost" { + t.Errorf("wrong error `%s` instead `451 2 2.2 localhost`", err) + } + } + err = cm.Close() + if err != nil { + t.Error(err) + } +} + func TestTimeoutClose(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ @@ -997,7 +1065,7 @@ func TestXCLIENT(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ EnableXCLIENT: true, - SenderChecker: func(peer smtpd.Peer, addr string) error { + SenderChecker: func(peer *smtpd.Peer, addr string) error { if peer.HeloName != "new.example.net" { t.Fatalf("Didn't override HELO name: %v", peer.HeloName) } @@ -1068,7 +1136,7 @@ func TestEnvelopeReceived(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ Hostname: "foobar.example.net", - Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { env.AddReceivedLine(peer) if !bytes.HasPrefix(env.Data, []byte("Received: from localhost ([127.0.0.1]) by foobar.example.net with ESMTP;")) { t.Fatal("Wrong received line.") @@ -1157,7 +1225,7 @@ func TestHELO(t *testing.T) { func TestLOGINAuth(t *testing.T) { addr, closer := runsslserver(t, &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + Authenticator: func(peer *smtpd.Peer, username, password string) error { return nil }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), }) @@ -1272,7 +1340,7 @@ func TestErrors(t *testing.T) { } server := &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + Authenticator: func(peer *smtpd.Peer, username, password string) error { return nil }, ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), } @@ -1350,7 +1418,7 @@ func TestErrors(t *testing.T) { func TestMailformedMAILFROM(t *testing.T) { addr, closer := runserver(t, &smtpd.Server{ - SenderChecker: func(peer smtpd.Peer, addr string) error { + SenderChecker: func(peer *smtpd.Peer, addr string) error { if addr != "test@example.org" { return smtpd.Error{Code: 502, Message: "Denied"} } @@ -1396,7 +1464,7 @@ func TestTLSListener(t *testing.T) { addr := ln.Addr().String() server := &smtpd.Server{ - Authenticator: func(peer smtpd.Peer, username, password string) error { + Authenticator: func(peer *smtpd.Peer, username, password string) error { if peer.TLS == nil { t.Error("didn't correctly set connection state on TLS connection") }