slight refractoring
This commit is contained in:
parent
fb16a3fc81
commit
71409bf7cd
3 changed files with 191 additions and 174 deletions
28
onceCloseListerner.go
Normal file
28
onceCloseListerner.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package smtpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// onceCloseListener wraps a net.Listener, protecting it from
|
||||
// multiple Close calls.
|
||||
type onceCloseListener struct {
|
||||
net.Listener
|
||||
once sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (oc *onceCloseListener) Close() error {
|
||||
oc.once.Do(oc.close)
|
||||
return oc.closeErr
|
||||
}
|
||||
|
||||
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
||||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
||||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
163
session.go
Normal file
163
session.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package smtpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"bufio"
|
||||
"strings"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
server *Server
|
||||
|
||||
peer Peer
|
||||
envelope *Envelope
|
||||
|
||||
conn net.Conn
|
||||
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
scanner *bufio.Scanner
|
||||
|
||||
tls bool
|
||||
}
|
||||
|
||||
|
||||
func (session *session) serve() {
|
||||
|
||||
defer session.close()
|
||||
|
||||
if !session.server.EnableProxyProtocol {
|
||||
session.welcome()
|
||||
}
|
||||
|
||||
for {
|
||||
|
||||
for session.scanner.Scan() {
|
||||
line := session.scanner.Text()
|
||||
session.logf("received: %s", strings.TrimSpace(line))
|
||||
session.handle(line)
|
||||
}
|
||||
|
||||
err := session.scanner.Err()
|
||||
|
||||
if err == bufio.ErrTooLong {
|
||||
|
||||
session.reply(500, "Line too long")
|
||||
|
||||
// Advance reader to the next newline
|
||||
|
||||
session.reader.ReadString('\n')
|
||||
session.scanner = bufio.NewScanner(session.reader)
|
||||
|
||||
// Reset and have the client start over.
|
||||
|
||||
session.reset()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (session *session) reject() {
|
||||
session.reply(421, "Too busy. Try again later.")
|
||||
session.close()
|
||||
}
|
||||
|
||||
func (session *session) reset() {
|
||||
session.envelope = nil
|
||||
}
|
||||
|
||||
func (session *session) welcome() {
|
||||
|
||||
if session.server.ConnectionChecker != nil {
|
||||
err := session.server.ConnectionChecker(session.peer)
|
||||
if err != nil {
|
||||
session.error(err)
|
||||
session.close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
session.reply(220, session.server.WelcomeMessage)
|
||||
|
||||
}
|
||||
|
||||
func (session *session) reply(code int, message string) {
|
||||
session.logf("sending: %d %s", code, message)
|
||||
fmt.Fprintf(session.writer, "%d %s\r\n", code, message)
|
||||
session.flush()
|
||||
}
|
||||
|
||||
func (session *session) flush() {
|
||||
session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout))
|
||||
session.writer.Flush()
|
||||
session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout))
|
||||
}
|
||||
|
||||
func (session *session) error(err error) {
|
||||
if smtpdError, ok := err.(Error); ok {
|
||||
session.reply(smtpdError.Code, smtpdError.Message)
|
||||
} else {
|
||||
session.reply(502, fmt.Sprintf("%s", err))
|
||||
}
|
||||
}
|
||||
|
||||
func (session *session) logf(format string, v ...interface{}) {
|
||||
if session.server.ProtocolLogger == nil {
|
||||
return
|
||||
}
|
||||
session.server.ProtocolLogger.Info(fmt.Sprintf(
|
||||
"%s [peer:%s]",
|
||||
fmt.Sprintf(format, v...),
|
||||
session.peer.Addr,
|
||||
))
|
||||
|
||||
}
|
||||
|
||||
func (session *session) logError(err error, desc string) {
|
||||
session.server.ProtocolLogger.Error(desc, "error", err)
|
||||
}
|
||||
|
||||
func (session *session) extensions() []string {
|
||||
|
||||
extensions := []string{
|
||||
fmt.Sprintf("SIZE %d", session.server.MaxMessageSize),
|
||||
"8BITMIME",
|
||||
"PIPELINING",
|
||||
}
|
||||
|
||||
if session.server.EnableXCLIENT {
|
||||
extensions = append(extensions, "XCLIENT")
|
||||
}
|
||||
|
||||
if session.server.TLSConfig != nil && !session.tls {
|
||||
extensions = append(extensions, "STARTTLS")
|
||||
}
|
||||
|
||||
if session.server.Authenticator != nil && session.tls {
|
||||
extensions = append(extensions, "AUTH PLAIN LOGIN")
|
||||
}
|
||||
|
||||
return extensions
|
||||
|
||||
}
|
||||
|
||||
func (session *session) deliver() error {
|
||||
if session.server.Handler != nil {
|
||||
return session.server.Handler(session.peer, *session.envelope)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (session *session) close() {
|
||||
session.writer.Flush()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
session.conn.Close()
|
||||
}
|
||||
|
||||
|
174
smtpd.go
174
smtpd.go
|
@ -9,9 +9,7 @@ import (
|
|||
"log/slog"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -98,21 +96,6 @@ func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) }
|
|||
// methods after a call to Shutdown.
|
||||
var ErrServerClosed = errors.New("smtp: Server closed")
|
||||
|
||||
type session struct {
|
||||
server *Server
|
||||
|
||||
peer Peer
|
||||
envelope *Envelope
|
||||
|
||||
conn net.Conn
|
||||
|
||||
reader *bufio.Reader
|
||||
writer *bufio.Writer
|
||||
scanner *bufio.Scanner
|
||||
|
||||
tls bool
|
||||
}
|
||||
|
||||
func (srv *Server) newSession(c net.Conn) (s *session) {
|
||||
|
||||
s = &session{
|
||||
|
@ -298,143 +281,6 @@ func (srv *Server) configureDefaults() {
|
|||
|
||||
}
|
||||
|
||||
func (session *session) serve() {
|
||||
|
||||
defer session.close()
|
||||
|
||||
if !session.server.EnableProxyProtocol {
|
||||
session.welcome()
|
||||
}
|
||||
|
||||
for {
|
||||
|
||||
for session.scanner.Scan() {
|
||||
line := session.scanner.Text()
|
||||
session.logf("received: %s", strings.TrimSpace(line))
|
||||
session.handle(line)
|
||||
}
|
||||
|
||||
err := session.scanner.Err()
|
||||
|
||||
if err == bufio.ErrTooLong {
|
||||
|
||||
session.reply(500, "Line too long")
|
||||
|
||||
// Advance reader to the next newline
|
||||
|
||||
session.reader.ReadString('\n')
|
||||
session.scanner = bufio.NewScanner(session.reader)
|
||||
|
||||
// Reset and have the client start over.
|
||||
|
||||
session.reset()
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (session *session) reject() {
|
||||
session.reply(421, "Too busy. Try again later.")
|
||||
session.close()
|
||||
}
|
||||
|
||||
func (session *session) reset() {
|
||||
session.envelope = nil
|
||||
}
|
||||
|
||||
func (session *session) welcome() {
|
||||
|
||||
if session.server.ConnectionChecker != nil {
|
||||
err := session.server.ConnectionChecker(session.peer)
|
||||
if err != nil {
|
||||
session.error(err)
|
||||
session.close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
session.reply(220, session.server.WelcomeMessage)
|
||||
|
||||
}
|
||||
|
||||
func (session *session) reply(code int, message string) {
|
||||
session.logf("sending: %d %s", code, message)
|
||||
fmt.Fprintf(session.writer, "%d %s\r\n", code, message)
|
||||
session.flush()
|
||||
}
|
||||
|
||||
func (session *session) flush() {
|
||||
session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout))
|
||||
session.writer.Flush()
|
||||
session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout))
|
||||
}
|
||||
|
||||
func (session *session) error(err error) {
|
||||
if smtpdError, ok := err.(Error); ok {
|
||||
session.reply(smtpdError.Code, smtpdError.Message)
|
||||
} else {
|
||||
session.reply(502, fmt.Sprintf("%s", err))
|
||||
}
|
||||
}
|
||||
|
||||
func (session *session) logf(format string, v ...interface{}) {
|
||||
if session.server.ProtocolLogger == nil {
|
||||
return
|
||||
}
|
||||
session.server.ProtocolLogger.Info(fmt.Sprintf(
|
||||
"%s [peer:%s]",
|
||||
fmt.Sprintf(format, v...),
|
||||
session.peer.Addr,
|
||||
))
|
||||
|
||||
}
|
||||
|
||||
func (session *session) logError(err error, desc string) {
|
||||
session.server.ProtocolLogger.Error(desc, "error", err)
|
||||
}
|
||||
|
||||
func (session *session) extensions() []string {
|
||||
|
||||
extensions := []string{
|
||||
fmt.Sprintf("SIZE %d", session.server.MaxMessageSize),
|
||||
"8BITMIME",
|
||||
"PIPELINING",
|
||||
}
|
||||
|
||||
if session.server.EnableXCLIENT {
|
||||
extensions = append(extensions, "XCLIENT")
|
||||
}
|
||||
|
||||
if session.server.TLSConfig != nil && !session.tls {
|
||||
extensions = append(extensions, "STARTTLS")
|
||||
}
|
||||
|
||||
if session.server.Authenticator != nil && session.tls {
|
||||
extensions = append(extensions, "AUTH PLAIN LOGIN")
|
||||
}
|
||||
|
||||
return extensions
|
||||
|
||||
}
|
||||
|
||||
func (session *session) deliver() error {
|
||||
if session.server.Handler != nil {
|
||||
return session.server.Handler(session.peer, *session.envelope)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (session *session) close() {
|
||||
session.writer.Flush()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
session.conn.Close()
|
||||
}
|
||||
|
||||
|
||||
// From net/http/server.go
|
||||
|
||||
func (s *Server) shuttingDown() bool {
|
||||
|
@ -466,23 +312,3 @@ func (s *Server) closeDoneChanLocked() {
|
|||
}
|
||||
}
|
||||
|
||||
// onceCloseListener wraps a net.Listener, protecting it from
|
||||
// multiple Close calls.
|
||||
type onceCloseListener struct {
|
||||
net.Listener
|
||||
once sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (oc *onceCloseListener) Close() error {
|
||||
oc.once.Do(oc.close)
|
||||
return oc.closeErr
|
||||
}
|
||||
|
||||
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
||||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
||||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
||||
|
|
Loading…
Add table
Reference in a new issue