diff --git a/protocol.go b/protocol.go index c36dfa6..d0d68cd 100644 --- a/protocol.go +++ b/protocol.go @@ -39,9 +39,7 @@ func parseLine(line string) (cmd command) { } func (session *session) handle(line string) { - if session.server.ProtocolLogger != nil { - session.server.ProtocolLogger.Printf("%s < %s", session.conn.RemoteAddr(), line) - } + cmd := parseLine(line) // Commands are dispatched to the appropriate handler functions. @@ -50,6 +48,10 @@ func (session *session) handle(line string) { switch cmd.action { + case "PROXY": + session.handlePROXY(cmd) + return + case "HELO": session.handleHELO(cmd) return @@ -268,6 +270,7 @@ func (session *session) handleSTARTTLS(cmd command) { session.reply(220, "Go ahead") if err := tlsConn.Handshake(); err != nil { + session.logError(err, "couldn't perform handshake") session.reply(550, "Handshake error") return } @@ -431,13 +434,19 @@ func (session *session) handleAUTH(cmd command) { case "LOGIN": - session.reply(334, "VXNlcm5hbWU6") + encodedUsername := "" - if !session.scanner.Scan() { - return + if len(cmd.fields) < 3 { + session.reply(334, "VXNlcm5hbWU6") + if !session.scanner.Scan() { + return + } + encodedUsername = session.scanner.Text() + } else { + encodedUsername = cmd.fields[2] } - byteUsername, err := base64.StdEncoding.DecodeString(session.scanner.Text()) + byteUsername, err := base64.StdEncoding.DecodeString(encodedUsername) if err != nil { session.reply(502, "Couldn't decode your credentials") @@ -462,6 +471,7 @@ func (session *session) handleAUTH(cmd command) { default: + session.logf("unknown authentication mechanism: %s", mechanism) session.reply(502, "Unknown authentication mechanism") return @@ -582,3 +592,47 @@ func (session *session) handleXCLIENT(cmd command) { session.welcome() } + +func (session *session) handlePROXY(cmd command) { + + if !session.server.EnableProxyProtocol { + session.reply(550, "Proxy Protocol not enabled") + return + } + + if len(cmd.fields) < 6 { + session.reply(502, "Couldn't decode the command.") + return + } + + var ( + newAddr net.IP = nil + newTCPPort uint64 = 0 + err error + ) + + newAddr = net.ParseIP(cmd.fields[2]) + + newTCPPort, err = strconv.ParseUint(cmd.fields[4], 10, 16) + if err != nil { + session.reply(502, "Couldn't decode the command.") + return + } + + tcpAddr, ok := session.peer.Addr.(*net.TCPAddr) + if !ok { + session.reply(502, "Unsupported network connection") + return + } + + if newAddr != nil { + tcpAddr.IP = newAddr + } + + if newTCPPort != 0 { + tcpAddr.Port = int(newTCPPort) + } + + session.welcome() + +} diff --git a/smtpd.go b/smtpd.go index 27f47ba..ec2e7a3 100644 --- a/smtpd.go +++ b/smtpd.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "strings" "time" ) @@ -41,7 +42,8 @@ type Server struct { // Can be left empty for no authentication support. Authenticator func(peer Peer, username, password string) error - EnableXCLIENT bool // Enable XCLIENT support (default: false) + EnableXCLIENT bool // Enable XCLIENT support (default: false) + EnableProxyProtocol bool // Enable proxy protocol support (default: false) TLSConfig *tls.Config // Enable STARTTLS support. ForceTLS bool // Force STARTTLS usage. @@ -214,12 +216,16 @@ func (session *session) serve() { defer session.close() - session.welcome() + if !session.server.EnableProxyProtocol { + session.welcome() + } for { for session.scanner.Scan() { - session.handle(session.scanner.Text()) + line := session.scanner.Text() + session.logf("received: %s", strings.TrimSpace(line)) + session.handle(line) } err := session.scanner.Err() @@ -270,9 +276,7 @@ func (session *session) welcome() { } func (session *session) reply(code int, message string) { - if session.server.ProtocolLogger != nil { - session.server.ProtocolLogger.Printf("%s > %d %s", session.conn.RemoteAddr(), code, message) - } + session.logf("sending: %d %s", code, message) fmt.Fprintf(session.writer, "%d %s\r\n", code, message) session.flush() } @@ -291,6 +295,22 @@ func (session *session) error(err error) { } } +func (session *session) logf(format string, v ...interface{}) { + if session.server.ProtocolLogger == nil { + return + } + session.server.ProtocolLogger.Output(2, fmt.Sprintf( + "%s [peer:%s]", + fmt.Sprintf(format, v...), + session.peer.Addr, + )) + +} + +func (session *session) logError(err error, desc string) { + session.logf("%s: %v ", desc, err) +} + func (session *session) extensions() []string { extensions := []string{