feat: peer has metadata added + tests
This commit is contained in:
		
							parent
							
								
									7df5779a24
								
							
						
					
					
						commit
						92a3299f16
					
				
					 4 changed files with 168 additions and 31 deletions
				
			
		|  | @ -15,7 +15,7 @@ type Envelope struct { | |||
| } | ||||
| 
 | ||||
| // AddReceivedLine prepends a Received header to the Data | ||||
| func (env *Envelope) AddReceivedLine(peer Peer) { | ||||
| func (env *Envelope) AddReceivedLine(peer *Peer) { | ||||
| 
 | ||||
| 	tlsDetails := "" | ||||
| 
 | ||||
|  |  | |||
|  | @ -2,6 +2,8 @@ package smtpd_test | |||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/smtp" | ||||
| 	"strings" | ||||
| 
 | ||||
|  | @ -18,14 +20,34 @@ func ExampleServer() { | |||
| 	// Relay server. Accepts only from single IP address and forwards using the Gmail smtp | ||||
| 	server = &smtpd.Server{ | ||||
| 
 | ||||
| 		HeloChecker: func(peer smtpd.Peer, name string) error { | ||||
| 		HeloChecker: func(peer *smtpd.Peer, name string) error { | ||||
| 			if !strings.HasPrefix(peer.Addr.String(), "42.42.42.42:") { | ||||
| 				return errors.New("Denied") | ||||
| 				return errors.New("denied") | ||||
| 			} | ||||
| 			return nil | ||||
| 		}, | ||||
| 
 | ||||
| 		Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { | ||||
| 		RecipientChecker: func(peer *smtpd.Peer, addr string) error { | ||||
| 			peer.IncrInt("recipientsQueued", 1) | ||||
| 			if strings.HasPrefix("scubad1ver", addr) { | ||||
| 				// it is bad idea ;-) | ||||
| 				peer.SetString("never_send_to_scuba", addr) | ||||
| 			} | ||||
| 			return nil | ||||
| 		}, | ||||
| 
 | ||||
| 		Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { | ||||
| 			numberOfRecepients, ok := peer.Meta["recipientsQueued"] | ||||
| 			if ok { | ||||
| 				log.Printf("Sending email for %v recipients...", numberOfRecepients) | ||||
| 			} | ||||
| 			never, ok := peer.Meta["never_send_to_scuba"] | ||||
| 			if ok { | ||||
| 				return smtpd.Error{ | ||||
| 					Code:    521, | ||||
| 					Message: fmt.Sprintf("Sending messages to %s is a bad idea", never), | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			return smtp.SendMail( | ||||
| 				"smtp.gmail.com:587", | ||||
|  |  | |||
							
								
								
									
										67
									
								
								smtpd.go
									
										
									
									
									
								
							
							
						
						
									
										67
									
								
								smtpd.go
									
										
									
									
									
								
							|  | @ -30,20 +30,20 @@ type Server struct { | |||
| 	// New e-mails are handed off to this function. | ||||
| 	// Can be left empty for a NOOP server. | ||||
| 	// If an error is returned, it will be reported in the SMTP session. | ||||
| 	Handler func(peer Peer, env Envelope) error | ||||
| 	Handler func(peer *Peer, env Envelope) error | ||||
| 
 | ||||
| 	// Enable various checks during the SMTP session. | ||||
| 	// Can be left empty for no restrictions. | ||||
| 	// If an error is returned, it will be reported in the SMTP session. | ||||
| 	// Use the Error struct for access to error codes. | ||||
| 	ConnectionChecker func(peer Peer) error              // Called upon new connection. | ||||
| 	HeloChecker       func(peer Peer, name string) error // Called after HELO/EHLO. | ||||
| 	SenderChecker     func(peer Peer, addr string) error // Called after MAIL FROM. | ||||
| 	RecipientChecker  func(peer Peer, addr string) error // Called after each RCPT TO. | ||||
| 	ConnectionChecker func(peer *Peer) error              // Called upon new connection. | ||||
| 	HeloChecker       func(peer *Peer, name string) error // Called after HELO/EHLO. | ||||
| 	SenderChecker     func(peer *Peer, addr string) error // Called after MAIL FROM. | ||||
| 	RecipientChecker  func(peer *Peer, addr string) error // Called after each RCPT TO. | ||||
| 
 | ||||
| 	// Enable PLAIN/LOGIN authentication, only available after STARTTLS. | ||||
| 	// Can be left empty for no authentication support. | ||||
| 	Authenticator func(peer Peer, username, password string) error | ||||
| 	Authenticator func(peer *Peer, username, password string) error | ||||
| 
 | ||||
| 	EnableXCLIENT       bool // Enable XCLIENT support (default: false) | ||||
| 	EnableProxyProtocol bool // Enable proxy protocol support (default: false) | ||||
|  | @ -83,6 +83,53 @@ type Peer struct { | |||
| 	ServerName string               // A copy of Server.Hostname | ||||
| 	Addr       net.Addr             // Network address | ||||
| 	TLS        *tls.ConnectionState // TLS Connection details, if on TLS | ||||
| 	Meta       map[string]interface{} | ||||
| } | ||||
| 
 | ||||
| // IncrInt increments integer metadata by positive or negative delta provided | ||||
| func (p *Peer) IncrInt(name string, delta int64) (err error) { | ||||
| 	if len(p.Meta) == 0 { | ||||
| 		p.Meta = make(map[string]interface{}, 0) | ||||
| 	} | ||||
| 	val, found := p.Meta[name] | ||||
| 	if !found { | ||||
| 		p.Meta[name] = delta | ||||
| 		return nil | ||||
| 	} | ||||
| 	switch val.(type) { | ||||
| 	case int64: | ||||
| 		p.Meta[name] = val.(int64) + delta | ||||
| 	default: | ||||
| 		return fmt.Errorf("meta %s has wrong type instead of int64", name) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // IncrFloat64 increments integer metadata by positive or negative delta provided | ||||
| func (p *Peer) IncrFloat64(name string, delta float64) (err error) { | ||||
| 	if len(p.Meta) == 0 { | ||||
| 		p.Meta = make(map[string]interface{}, 0) | ||||
| 	} | ||||
| 	val, found := p.Meta[name] | ||||
| 	if !found { | ||||
| 		p.Meta[name] = delta | ||||
| 		return nil | ||||
| 	} | ||||
| 	switch val.(type) { | ||||
| 	case float64: | ||||
| 		p.Meta[name] = val.(float64) + delta | ||||
| 	default: | ||||
| 		return fmt.Errorf("meta %s has wrong type instead of float64", name) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // SetString sets string parameter in meta | ||||
| func (p *Peer) SetString(name, value string) { | ||||
| 	if len(p.Meta) == 0 { | ||||
| 		p.Meta = make(map[string]interface{}, 0) | ||||
| 	} | ||||
| 	p.Meta[name] = value | ||||
| } | ||||
| 
 | ||||
| // Error represents an Error reported in the SMTP session. | ||||
|  | @ -95,13 +142,13 @@ type Error struct { | |||
| func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } | ||||
| 
 | ||||
| // ErrServerClosed is returned by the Server's Serve and ListenAndServe, | ||||
| // methods after a call to Shutdown. | ||||
| // methods after a call to shut down. | ||||
| var ErrServerClosed = errors.New("smtp: Server closed") | ||||
| 
 | ||||
| type session struct { | ||||
| 	server *Server | ||||
| 
 | ||||
| 	peer     Peer | ||||
| 	peer     *Peer | ||||
| 	envelope *Envelope | ||||
| 
 | ||||
| 	conn net.Conn | ||||
|  | @ -120,7 +167,7 @@ func (srv *Server) newSession(c net.Conn) (s *session) { | |||
| 		conn:   c, | ||||
| 		reader: bufio.NewReader(c), | ||||
| 		writer: bufio.NewWriter(c), | ||||
| 		peer: Peer{ | ||||
| 		peer: &Peer{ | ||||
| 			Addr:       c.RemoteAddr(), | ||||
| 			ServerName: srv.Hostname, | ||||
| 		}, | ||||
|  | @ -219,7 +266,7 @@ func (srv *Server) Serve(l net.Listener) error { | |||
| 
 | ||||
| } | ||||
| 
 | ||||
| // Shutdown instructs the server to shutdown, starting by closing the | ||||
| // Shutdown instructs the server to shut down, starting by closing the | ||||
| // associated listener. If wait is true, it will wait for the shutdown | ||||
| // to complete. If wait is false, Wait must be called afterwards. | ||||
| func (srv *Server) Shutdown(wait bool) error { | ||||
|  |  | |||
							
								
								
									
										102
									
								
								smtpd_test.go
									
										
									
									
									
								
							
							
						
						
									
										102
									
								
								smtpd_test.go
									
										
									
									
									
								
							|  | @ -256,7 +256,7 @@ func TestListenAndServe(t *testing.T) { | |||
| func TestSTARTTLS(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runsslserver(t, &smtpd.Server{ | ||||
| 		Authenticator:  func(peer smtpd.Peer, username, password string) error { return nil }, | ||||
| 		Authenticator:  func(peer *smtpd.Peer, username, password string) error { return nil }, | ||||
| 		ForceTLS:       true, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
| 	}) | ||||
|  | @ -343,7 +343,7 @@ func TestSTARTTLS(t *testing.T) { | |||
| func TestAuthRejection(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runsslserver(t, &smtpd.Server{ | ||||
| 		Authenticator: func(peer smtpd.Peer, username, password string) error { | ||||
| 		Authenticator: func(peer *smtpd.Peer, username, password string) error { | ||||
| 			return smtpd.Error{Code: 550, Message: "Denied"} | ||||
| 		}, | ||||
| 		ForceTLS:       true, | ||||
|  | @ -394,7 +394,7 @@ func TestAuthNotSupported(t *testing.T) { | |||
| func TestAuthBypass(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runsslserver(t, &smtpd.Server{ | ||||
| 		Authenticator: func(peer smtpd.Peer, username, password string) error { | ||||
| 		Authenticator: func(peer *smtpd.Peer, username, password string) error { | ||||
| 			return smtpd.Error{Code: 550, Message: "Denied"} | ||||
| 		}, | ||||
| 		ForceTLS:       true, | ||||
|  | @ -421,7 +421,7 @@ func TestAuthBypass(t *testing.T) { | |||
| func TestConnectionCheck(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		ConnectionChecker: func(peer smtpd.Peer) error { | ||||
| 		ConnectionChecker: func(peer *smtpd.Peer) error { | ||||
| 			return smtpd.Error{Code: 552, Message: "Denied"} | ||||
| 		}, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
|  | @ -438,7 +438,7 @@ func TestConnectionCheck(t *testing.T) { | |||
| func TestConnectionCheckSimpleError(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		ConnectionChecker: func(peer smtpd.Peer) error { | ||||
| 		ConnectionChecker: func(peer *smtpd.Peer) error { | ||||
| 			return errors.New("Denied") | ||||
| 		}, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
|  | @ -455,7 +455,7 @@ func TestConnectionCheckSimpleError(t *testing.T) { | |||
| func TestHELOCheck(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		HeloChecker: func(peer smtpd.Peer, name string) error { | ||||
| 		HeloChecker: func(peer *smtpd.Peer, name string) error { | ||||
| 			if name != "foobar.local" { | ||||
| 				t.Fatal("Wrong HELO name") | ||||
| 			} | ||||
|  | @ -480,7 +480,7 @@ func TestHELOCheck(t *testing.T) { | |||
| func TestSenderCheck(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		SenderChecker: func(peer smtpd.Peer, addr string) error { | ||||
| 		SenderChecker: func(peer *smtpd.Peer, addr string) error { | ||||
| 			return smtpd.Error{Code: 552, Message: "Denied"} | ||||
| 		}, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
|  | @ -502,7 +502,7 @@ func TestSenderCheck(t *testing.T) { | |||
| func TestRecipientCheck(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		RecipientChecker: func(peer smtpd.Peer, addr string) error { | ||||
| 		RecipientChecker: func(peer *smtpd.Peer, addr string) error { | ||||
| 			return smtpd.Error{Code: 552, Message: "Denied"} | ||||
| 		}, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
|  | @ -571,7 +571,7 @@ func TestMaxMessageSize(t *testing.T) { | |||
| func TestHandler(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { | ||||
| 		Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { | ||||
| 			if env.Sender != "sender@example.org" { | ||||
| 				t.Fatalf("Unknown sender: %v", env.Sender) | ||||
| 			} | ||||
|  | @ -628,7 +628,7 @@ func TestHandler(t *testing.T) { | |||
| func TestRejectHandler(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { | ||||
| 		Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { | ||||
| 			return smtpd.Error{Code: 550, Message: "Rejected"} | ||||
| 		}, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
|  | @ -851,7 +851,7 @@ func TestDATAbeforeRCPT(t *testing.T) { | |||
| func TestInterruptedDATA(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { | ||||
| 		Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { | ||||
| 			t.Fatal("Accepted DATA despite disconnection") | ||||
| 			return nil | ||||
| 		}, | ||||
|  | @ -887,6 +887,74 @@ func TestInterruptedDATA(t *testing.T) { | |||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestMeta(t *testing.T) { | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		MaxConnections: 1, | ||||
| 		HeloChecker: func(peer *smtpd.Peer, name string) (err error) { | ||||
| 			peer.SetString("something", name) | ||||
| 			err = peer.IncrInt("int64", 1) | ||||
| 			if err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			err = peer.IncrFloat64("float64", 1.1) | ||||
| 			return err | ||||
| 		}, | ||||
| 		SenderChecker: func(peer *smtpd.Peer, addr string) (err error) { | ||||
| 			if peer.Meta["something"] != "localhost" { | ||||
| 				t.Errorf("wrong meta something %s instead of `localhost`", | ||||
| 					peer.Meta["something"]) | ||||
| 			} | ||||
| 			if peer.Meta["int64"] != int64(1) { | ||||
| 				t.Errorf("wrong meta int64 %v instead of `1`", peer.Meta["int64"]) | ||||
| 			} | ||||
| 			if peer.Meta["float64"] != float64(1.1) { | ||||
| 				t.Errorf("wrong meta int64 %v instead of `1`", peer.Meta["int64"]) | ||||
| 			} | ||||
| 			err = peer.IncrInt("int64", 1) | ||||
| 			if err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			err = peer.IncrFloat64("float64", 1.1) | ||||
| 			return err | ||||
| 		}, | ||||
| 		RecipientChecker: func(peer *smtpd.Peer, addr string) error { | ||||
| 			return smtpd.Error{ | ||||
| 				Code: 451, | ||||
| 				Message: fmt.Sprintf("%v %v %s", | ||||
| 					peer.Meta["int64"], | ||||
| 					peer.Meta["float64"], | ||||
| 					peer.Meta["something"], | ||||
| 				), | ||||
| 			} | ||||
| 		}, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
| 	}) | ||||
| 
 | ||||
| 	defer closer() | ||||
| 	cm, err := smtp.Dial(addr) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Dial failed: %v", err) | ||||
| 	} | ||||
| 	err = cm.Hello("localhost") | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	err = cm.Mail("somebody@localhost") | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	err = cm.Rcpt("scuba@example.org") | ||||
| 	if err != nil { | ||||
| 		if err.Error() != "451 2 2.2 localhost" { | ||||
| 			t.Errorf("wrong error `%s` instead `451 2 2.2 localhost`", err) | ||||
| 		} | ||||
| 	} | ||||
| 	err = cm.Close() | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestTimeoutClose(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
|  | @ -997,7 +1065,7 @@ func TestXCLIENT(t *testing.T) { | |||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		EnableXCLIENT: true, | ||||
| 		SenderChecker: func(peer smtpd.Peer, addr string) error { | ||||
| 		SenderChecker: func(peer *smtpd.Peer, addr string) error { | ||||
| 			if peer.HeloName != "new.example.net" { | ||||
| 				t.Fatalf("Didn't override HELO name: %v", peer.HeloName) | ||||
| 			} | ||||
|  | @ -1068,7 +1136,7 @@ func TestEnvelopeReceived(t *testing.T) { | |||
| 
 | ||||
| 	addr, closer := runsslserver(t, &smtpd.Server{ | ||||
| 		Hostname: "foobar.example.net", | ||||
| 		Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { | ||||
| 		Handler: func(peer *smtpd.Peer, env smtpd.Envelope) error { | ||||
| 			env.AddReceivedLine(peer) | ||||
| 			if !bytes.HasPrefix(env.Data, []byte("Received: from localhost ([127.0.0.1]) by foobar.example.net with ESMTP;")) { | ||||
| 				t.Fatal("Wrong received line.") | ||||
|  | @ -1157,7 +1225,7 @@ func TestHELO(t *testing.T) { | |||
| func TestLOGINAuth(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runsslserver(t, &smtpd.Server{ | ||||
| 		Authenticator:  func(peer smtpd.Peer, username, password string) error { return nil }, | ||||
| 		Authenticator:  func(peer *smtpd.Peer, username, password string) error { return nil }, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
| 	}) | ||||
| 
 | ||||
|  | @ -1272,7 +1340,7 @@ func TestErrors(t *testing.T) { | |||
| 	} | ||||
| 
 | ||||
| 	server := &smtpd.Server{ | ||||
| 		Authenticator:  func(peer smtpd.Peer, username, password string) error { return nil }, | ||||
| 		Authenticator:  func(peer *smtpd.Peer, username, password string) error { return nil }, | ||||
| 		ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile), | ||||
| 	} | ||||
| 
 | ||||
|  | @ -1350,7 +1418,7 @@ func TestErrors(t *testing.T) { | |||
| func TestMailformedMAILFROM(t *testing.T) { | ||||
| 
 | ||||
| 	addr, closer := runserver(t, &smtpd.Server{ | ||||
| 		SenderChecker: func(peer smtpd.Peer, addr string) error { | ||||
| 		SenderChecker: func(peer *smtpd.Peer, addr string) error { | ||||
| 			if addr != "test@example.org" { | ||||
| 				return smtpd.Error{Code: 502, Message: "Denied"} | ||||
| 			} | ||||
|  | @ -1396,7 +1464,7 @@ func TestTLSListener(t *testing.T) { | |||
| 	addr := ln.Addr().String() | ||||
| 
 | ||||
| 	server := &smtpd.Server{ | ||||
| 		Authenticator: func(peer smtpd.Peer, username, password string) error { | ||||
| 		Authenticator: func(peer *smtpd.Peer, username, password string) error { | ||||
| 			if peer.TLS == nil { | ||||
| 				t.Error("didn't correctly set connection state on TLS connection") | ||||
| 			} | ||||
|  |  | |||
		Loading…
	
	Add table
		
		Reference in a new issue