diff --git a/onceCloseListerner.go b/onceCloseListerner.go new file mode 100644 index 0000000..4f4926c --- /dev/null +++ b/onceCloseListerner.go @@ -0,0 +1,28 @@ +package smtpd + +import ( + "net" + "sync" + "sync/atomic" +) + +// onceCloseListener wraps a net.Listener, protecting it from +// multiple Close calls. +type onceCloseListener struct { + net.Listener + once sync.Once + closeErr error +} + +func (oc *onceCloseListener) Close() error { + oc.once.Do(oc.close) + return oc.closeErr +} + +func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } + +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } +func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } diff --git a/session.go b/session.go new file mode 100644 index 0000000..44d6550 --- /dev/null +++ b/session.go @@ -0,0 +1,163 @@ +package smtpd + +import ( + "net" + "bufio" + "strings" + "fmt" + "time" +) + +type session struct { + server *Server + + peer Peer + envelope *Envelope + + conn net.Conn + + reader *bufio.Reader + writer *bufio.Writer + scanner *bufio.Scanner + + tls bool +} + + +func (session *session) serve() { + + defer session.close() + + if !session.server.EnableProxyProtocol { + session.welcome() + } + + for { + + for session.scanner.Scan() { + line := session.scanner.Text() + session.logf("received: %s", strings.TrimSpace(line)) + session.handle(line) + } + + err := session.scanner.Err() + + if err == bufio.ErrTooLong { + + session.reply(500, "Line too long") + + // Advance reader to the next newline + + session.reader.ReadString('\n') + session.scanner = bufio.NewScanner(session.reader) + + // Reset and have the client start over. + + session.reset() + + continue + } + + break + } + +} + +func (session *session) reject() { + session.reply(421, "Too busy. Try again later.") + session.close() +} + +func (session *session) reset() { + session.envelope = nil +} + +func (session *session) welcome() { + + if session.server.ConnectionChecker != nil { + err := session.server.ConnectionChecker(session.peer) + if err != nil { + session.error(err) + session.close() + return + } + } + + session.reply(220, session.server.WelcomeMessage) + +} + +func (session *session) reply(code int, message string) { + session.logf("sending: %d %s", code, message) + fmt.Fprintf(session.writer, "%d %s\r\n", code, message) + session.flush() +} + +func (session *session) flush() { + session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout)) + session.writer.Flush() + session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout)) +} + +func (session *session) error(err error) { + if smtpdError, ok := err.(Error); ok { + session.reply(smtpdError.Code, smtpdError.Message) + } else { + session.reply(502, fmt.Sprintf("%s", err)) + } +} + +func (session *session) logf(format string, v ...interface{}) { + if session.server.ProtocolLogger == nil { + return + } + session.server.ProtocolLogger.Info(fmt.Sprintf( + "%s [peer:%s]", + fmt.Sprintf(format, v...), + session.peer.Addr, + )) + +} + +func (session *session) logError(err error, desc string) { + session.server.ProtocolLogger.Error(desc, "error", err) +} + +func (session *session) extensions() []string { + + extensions := []string{ + fmt.Sprintf("SIZE %d", session.server.MaxMessageSize), + "8BITMIME", + "PIPELINING", + } + + if session.server.EnableXCLIENT { + extensions = append(extensions, "XCLIENT") + } + + if session.server.TLSConfig != nil && !session.tls { + extensions = append(extensions, "STARTTLS") + } + + if session.server.Authenticator != nil && session.tls { + extensions = append(extensions, "AUTH PLAIN LOGIN") + } + + return extensions + +} + +func (session *session) deliver() error { + if session.server.Handler != nil { + return session.server.Handler(session.peer, *session.envelope) + } + return nil +} + +func (session *session) close() { + session.writer.Flush() + time.Sleep(200 * time.Millisecond) + session.conn.Close() +} + + diff --git a/smtpd.go b/smtpd.go index d31651f..a24729f 100644 --- a/smtpd.go +++ b/smtpd.go @@ -9,9 +9,7 @@ import ( "log/slog" "log" "net" - "strings" "sync" - "sync/atomic" "time" ) @@ -98,21 +96,6 @@ func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } // methods after a call to Shutdown. var ErrServerClosed = errors.New("smtp: Server closed") -type session struct { - server *Server - - peer Peer - envelope *Envelope - - conn net.Conn - - reader *bufio.Reader - writer *bufio.Writer - scanner *bufio.Scanner - - tls bool -} - func (srv *Server) newSession(c net.Conn) (s *session) { s = &session{ @@ -298,143 +281,6 @@ func (srv *Server) configureDefaults() { } -func (session *session) serve() { - - defer session.close() - - if !session.server.EnableProxyProtocol { - session.welcome() - } - - for { - - for session.scanner.Scan() { - line := session.scanner.Text() - session.logf("received: %s", strings.TrimSpace(line)) - session.handle(line) - } - - err := session.scanner.Err() - - if err == bufio.ErrTooLong { - - session.reply(500, "Line too long") - - // Advance reader to the next newline - - session.reader.ReadString('\n') - session.scanner = bufio.NewScanner(session.reader) - - // Reset and have the client start over. - - session.reset() - - continue - } - - break - } - -} - -func (session *session) reject() { - session.reply(421, "Too busy. Try again later.") - session.close() -} - -func (session *session) reset() { - session.envelope = nil -} - -func (session *session) welcome() { - - if session.server.ConnectionChecker != nil { - err := session.server.ConnectionChecker(session.peer) - if err != nil { - session.error(err) - session.close() - return - } - } - - session.reply(220, session.server.WelcomeMessage) - -} - -func (session *session) reply(code int, message string) { - session.logf("sending: %d %s", code, message) - fmt.Fprintf(session.writer, "%d %s\r\n", code, message) - session.flush() -} - -func (session *session) flush() { - session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout)) - session.writer.Flush() - session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout)) -} - -func (session *session) error(err error) { - if smtpdError, ok := err.(Error); ok { - session.reply(smtpdError.Code, smtpdError.Message) - } else { - session.reply(502, fmt.Sprintf("%s", err)) - } -} - -func (session *session) logf(format string, v ...interface{}) { - if session.server.ProtocolLogger == nil { - return - } - session.server.ProtocolLogger.Info(fmt.Sprintf( - "%s [peer:%s]", - fmt.Sprintf(format, v...), - session.peer.Addr, - )) - -} - -func (session *session) logError(err error, desc string) { - session.server.ProtocolLogger.Error(desc, "error", err) -} - -func (session *session) extensions() []string { - - extensions := []string{ - fmt.Sprintf("SIZE %d", session.server.MaxMessageSize), - "8BITMIME", - "PIPELINING", - } - - if session.server.EnableXCLIENT { - extensions = append(extensions, "XCLIENT") - } - - if session.server.TLSConfig != nil && !session.tls { - extensions = append(extensions, "STARTTLS") - } - - if session.server.Authenticator != nil && session.tls { - extensions = append(extensions, "AUTH PLAIN LOGIN") - } - - return extensions - -} - -func (session *session) deliver() error { - if session.server.Handler != nil { - return session.server.Handler(session.peer, *session.envelope) - } - return nil -} - -func (session *session) close() { - session.writer.Flush() - time.Sleep(200 * time.Millisecond) - session.conn.Close() -} - - // From net/http/server.go func (s *Server) shuttingDown() bool { @@ -466,23 +312,3 @@ func (s *Server) closeDoneChanLocked() { } } -// onceCloseListener wraps a net.Listener, protecting it from -// multiple Close calls. -type onceCloseListener struct { - net.Listener - once sync.Once - closeErr error -} - -func (oc *onceCloseListener) Close() error { - oc.once.Do(oc.close) - return oc.closeErr -} - -func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } - -type atomicBool int32 - -func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } -func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } -func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }