From afa9ece3d0bb3184bb15f105dac89d9206c1156b Mon Sep 17 00:00:00 2001 From: Christian Joergensen Date: Mon, 22 May 2017 19:45:42 +0200 Subject: [PATCH] Proxy protocol support, go vet. --- protocol.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++++---- smtpd.go | 31 +++++++++++++++++++++++--- 2 files changed, 88 insertions(+), 7 deletions(-) diff --git a/protocol.go b/protocol.go index 85ade92..6d70294 100644 --- a/protocol.go +++ b/protocol.go @@ -48,6 +48,10 @@ func (session *session) handle(line string) { switch cmd.action { + case "PROXY": + session.handlePROXY(cmd) + return + case "HELO": session.handleHELO(cmd) return @@ -258,6 +262,7 @@ func (session *session) handleSTARTTLS(cmd command) { session.reply(220, "Go ahead") if err := tlsConn.Handshake(); err != nil { + session.logError(err, "couldn't perform handshake") session.reply(550, "Handshake error") return } @@ -417,13 +422,19 @@ func (session *session) handleAUTH(cmd command) { case "LOGIN": - session.reply(334, "VXNlcm5hbWU6") + encodedUsername := "" - if !session.scanner.Scan() { - return + if len(cmd.fields) < 3 { + session.reply(334, "VXNlcm5hbWU6") + if !session.scanner.Scan() { + return + } + encodedUsername = session.scanner.Text() + } else { + encodedUsername = cmd.fields[2] } - byteUsername, err := base64.StdEncoding.DecodeString(session.scanner.Text()) + byteUsername, err := base64.StdEncoding.DecodeString(encodedUsername) if err != nil { session.reply(502, "Couldn't decode your credentials") @@ -448,6 +459,7 @@ func (session *session) handleAUTH(cmd command) { default: + session.logf("unknown authentication mechanism: %s", mechanism) session.reply(502, "Unknown authentication mechanism") return @@ -564,3 +576,47 @@ func (session *session) handleXCLIENT(cmd command) { session.welcome() } + +func (session *session) handlePROXY(cmd command) { + + if !session.server.EnableProxyProtocol { + session.reply(550, "Proxy Protocol not enabled") + return + } + + if len(cmd.fields) < 6 { + session.reply(502, "Couldn't decode the command.") + return + } + + var ( + newAddr net.IP = nil + newTCPPort uint64 = 0 + err error + ) + + newAddr = net.ParseIP(cmd.fields[2]) + + newTCPPort, err = strconv.ParseUint(cmd.fields[4], 10, 16) + if err != nil { + session.reply(502, "Couldn't decode the command.") + return + } + + tcpAddr, ok := session.peer.Addr.(*net.TCPAddr) + if !ok { + session.reply(502, "Unsupported network connection") + return + } + + if newAddr != nil { + tcpAddr.IP = newAddr + } + + if newTCPPort != 0 { + tcpAddr.Port = int(newTCPPort) + } + + session.welcome() + +} diff --git a/smtpd.go b/smtpd.go index 08f84e0..0c24d4c 100644 --- a/smtpd.go +++ b/smtpd.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "strings" "time" ) @@ -41,10 +42,13 @@ type Server struct { // Can be left empty for no authentication support. Authenticator func(peer Peer, username, password string) error - EnableXCLIENT bool // Enable XCLIENT support (default: false) + EnableXCLIENT bool // Enable XCLIENT support (default: false) + EnableProxyProtocol bool // Enable proxy protocol support (default: false) TLSConfig *tls.Config // Enable STARTTLS support. ForceTLS bool // Force STARTTLS usage. + + Logger *log.Logger } // Protocol represents the protocol used in the SMTP session @@ -212,12 +216,16 @@ func (session *session) serve() { defer session.close() - session.welcome() + if !session.server.EnableProxyProtocol { + session.welcome() + } for { for session.scanner.Scan() { - session.handle(session.scanner.Text()) + line := session.scanner.Text() + session.logf("received line: %s", strings.TrimSpace(line)) + session.handle(line) } err := session.scanner.Err() @@ -268,6 +276,7 @@ func (session *session) welcome() { } func (session *session) reply(code int, message string) { + session.logf("sending line: %d %s", code, message) fmt.Fprintf(session.writer, "%d %s\r\n", code, message) session.flush() } @@ -286,6 +295,22 @@ func (session *session) error(err error) { } } +func (session *session) logf(format string, v ...interface{}) { + if session.server.Logger == nil { + return + } + session.server.Logger.Output(2, fmt.Sprintf( + "%s [peer:%s]", + fmt.Sprintf(format, v...), + session.peer.Addr, + )) + +} + +func (session *session) logError(err error, desc string) { + session.logf("%s: %v ", desc, err) +} + func (session *session) extensions() []string { extensions := []string{