This commit is contained in:
Christian Joergensen 2017-05-22 19:54:10 +02:00
commit 83640db51d
2 changed files with 87 additions and 13 deletions

View file

@ -39,9 +39,7 @@ func parseLine(line string) (cmd command) {
} }
func (session *session) handle(line string) { func (session *session) handle(line string) {
if session.server.ProtocolLogger != nil {
session.server.ProtocolLogger.Printf("%s < %s", session.conn.RemoteAddr(), line)
}
cmd := parseLine(line) cmd := parseLine(line)
// Commands are dispatched to the appropriate handler functions. // Commands are dispatched to the appropriate handler functions.
@ -50,6 +48,10 @@ func (session *session) handle(line string) {
switch cmd.action { switch cmd.action {
case "PROXY":
session.handlePROXY(cmd)
return
case "HELO": case "HELO":
session.handleHELO(cmd) session.handleHELO(cmd)
return return
@ -268,6 +270,7 @@ func (session *session) handleSTARTTLS(cmd command) {
session.reply(220, "Go ahead") session.reply(220, "Go ahead")
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
session.logError(err, "couldn't perform handshake")
session.reply(550, "Handshake error") session.reply(550, "Handshake error")
return return
} }
@ -431,13 +434,19 @@ func (session *session) handleAUTH(cmd command) {
case "LOGIN": case "LOGIN":
session.reply(334, "VXNlcm5hbWU6") encodedUsername := ""
if len(cmd.fields) < 3 {
session.reply(334, "VXNlcm5hbWU6")
if !session.scanner.Scan() { if !session.scanner.Scan() {
return return
} }
encodedUsername = session.scanner.Text()
} else {
encodedUsername = cmd.fields[2]
}
byteUsername, err := base64.StdEncoding.DecodeString(session.scanner.Text()) byteUsername, err := base64.StdEncoding.DecodeString(encodedUsername)
if err != nil { if err != nil {
session.reply(502, "Couldn't decode your credentials") session.reply(502, "Couldn't decode your credentials")
@ -462,6 +471,7 @@ func (session *session) handleAUTH(cmd command) {
default: default:
session.logf("unknown authentication mechanism: %s", mechanism)
session.reply(502, "Unknown authentication mechanism") session.reply(502, "Unknown authentication mechanism")
return return
@ -582,3 +592,47 @@ func (session *session) handleXCLIENT(cmd command) {
session.welcome() session.welcome()
} }
func (session *session) handlePROXY(cmd command) {
if !session.server.EnableProxyProtocol {
session.reply(550, "Proxy Protocol not enabled")
return
}
if len(cmd.fields) < 6 {
session.reply(502, "Couldn't decode the command.")
return
}
var (
newAddr net.IP = nil
newTCPPort uint64 = 0
err error
)
newAddr = net.ParseIP(cmd.fields[2])
newTCPPort, err = strconv.ParseUint(cmd.fields[4], 10, 16)
if err != nil {
session.reply(502, "Couldn't decode the command.")
return
}
tcpAddr, ok := session.peer.Addr.(*net.TCPAddr)
if !ok {
session.reply(502, "Unsupported network connection")
return
}
if newAddr != nil {
tcpAddr.IP = newAddr
}
if newTCPPort != 0 {
tcpAddr.Port = int(newTCPPort)
}
session.welcome()
}

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"strings"
"time" "time"
) )
@ -42,6 +43,7 @@ type Server struct {
Authenticator func(peer Peer, username, password string) error Authenticator func(peer Peer, username, password string) error
EnableXCLIENT bool // Enable XCLIENT support (default: false) EnableXCLIENT bool // Enable XCLIENT support (default: false)
EnableProxyProtocol bool // Enable proxy protocol support (default: false)
TLSConfig *tls.Config // Enable STARTTLS support. TLSConfig *tls.Config // Enable STARTTLS support.
ForceTLS bool // Force STARTTLS usage. ForceTLS bool // Force STARTTLS usage.
@ -214,12 +216,16 @@ func (session *session) serve() {
defer session.close() defer session.close()
if !session.server.EnableProxyProtocol {
session.welcome() session.welcome()
}
for { for {
for session.scanner.Scan() { for session.scanner.Scan() {
session.handle(session.scanner.Text()) line := session.scanner.Text()
session.logf("received: %s", strings.TrimSpace(line))
session.handle(line)
} }
err := session.scanner.Err() err := session.scanner.Err()
@ -270,9 +276,7 @@ func (session *session) welcome() {
} }
func (session *session) reply(code int, message string) { func (session *session) reply(code int, message string) {
if session.server.ProtocolLogger != nil { session.logf("sending: %d %s", code, message)
session.server.ProtocolLogger.Printf("%s > %d %s", session.conn.RemoteAddr(), code, message)
}
fmt.Fprintf(session.writer, "%d %s\r\n", code, message) fmt.Fprintf(session.writer, "%d %s\r\n", code, message)
session.flush() session.flush()
} }
@ -291,6 +295,22 @@ func (session *session) error(err error) {
} }
} }
func (session *session) logf(format string, v ...interface{}) {
if session.server.ProtocolLogger == nil {
return
}
session.server.ProtocolLogger.Output(2, fmt.Sprintf(
"%s [peer:%s]",
fmt.Sprintf(format, v...),
session.peer.Addr,
))
}
func (session *session) logError(err error, desc string) {
session.logf("%s: %v ", desc, err)
}
func (session *session) extensions() []string { func (session *session) extensions() []string {
extensions := []string{ extensions := []string{