Compare commits
3 commits
master
...
chrj/ci-li
Author | SHA1 | Date | |
---|---|---|---|
|
15c75ccffd | ||
|
6fde3ceec6 | ||
|
46e33c52ab |
16 changed files with 266 additions and 492 deletions
15
.github/workflows/go.yml
vendored
15
.github/workflows/go.yml
vendored
|
@ -1,17 +1,15 @@
|
||||||
name: smtpd
|
name: Go
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ master ]
|
branches: [ master ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ master ]
|
branches: [ master ]
|
||||||
schedule:
|
|
||||||
- cron: '12 0 * * *'
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
||||||
build:
|
build:
|
||||||
name: Build and Test
|
name: Build
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
|
@ -21,11 +19,16 @@ jobs:
|
||||||
go-version: ^1.13
|
go-version: ^1.13
|
||||||
id: go
|
id: go
|
||||||
|
|
||||||
- name: Check out the code
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Get dependencies
|
- name: Get dependencies
|
||||||
run: go get -v -t -d ./...
|
run: |
|
||||||
|
go get -v -t -d ./...
|
||||||
|
if [ -f Gopkg.toml ]; then
|
||||||
|
curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
|
||||||
|
dep ensure
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: go build -v .
|
run: go build -v .
|
||||||
|
|
45
.github/workflows/linter.yml
vendored
Normal file
45
.github/workflows/linter.yml
vendored
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
name: Lint Code Base
|
||||||
|
|
||||||
|
#
|
||||||
|
# Documentation:
|
||||||
|
# https://help.github.com/en/articles/workflow-syntax-for-github-actions
|
||||||
|
#
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# Start the job on all push #
|
||||||
|
#############################
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches-ignore: [master]
|
||||||
|
# Remove the line above to run when pushing to master
|
||||||
|
pull_request:
|
||||||
|
branches: [master]
|
||||||
|
|
||||||
|
###############
|
||||||
|
# Set the Job #
|
||||||
|
###############
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
# Name the Job
|
||||||
|
name: Lint Code Base
|
||||||
|
# Set the agent to run on
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
##################
|
||||||
|
# Load all steps #
|
||||||
|
##################
|
||||||
|
steps:
|
||||||
|
##########################
|
||||||
|
# Checkout the code base #
|
||||||
|
##########################
|
||||||
|
- name: Checkout Code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
################################
|
||||||
|
# Run Linter against code base #
|
||||||
|
################################
|
||||||
|
- name: Lint Code Base
|
||||||
|
uses: docker://github/super-linter:v3
|
||||||
|
env:
|
||||||
|
VALIDATE_ALL_CODEBASE: true
|
||||||
|
DEFAULT_BRANCH: master
|
8
Makefile
8
Makefile
|
@ -1,8 +0,0 @@
|
||||||
test:
|
|
||||||
go test .
|
|
||||||
|
|
||||||
dkim-proxy:
|
|
||||||
cd _examples && go get . && go build .
|
|
||||||
|
|
||||||
testsum:
|
|
||||||
gotestsum --format testname
|
|
|
@ -1,5 +0,0 @@
|
||||||
# smtpd dkim-proxy
|
|
||||||
|
|
||||||
## Important Note
|
|
||||||
|
|
||||||
The dependency `github.com/eaigner/dkim` is no longer available thus the example can not be built.
|
|
|
@ -1,5 +0,0 @@
|
||||||
module github.com/chrj/smtpd/_examples/dkim-proxy
|
|
||||||
|
|
||||||
go 1.14
|
|
||||||
|
|
||||||
require github.com/eaigner/dkim v0.0.0-20150301120808-6fe4a7ee9cfb
|
|
|
@ -20,7 +20,7 @@ func (env *Envelope) AddReceivedLine(peer Peer) {
|
||||||
tlsDetails := ""
|
tlsDetails := ""
|
||||||
|
|
||||||
tlsVersions := map[uint16]string{
|
tlsVersions := map[uint16]string{
|
||||||
0x300: "SSL3.0",
|
tls.VersionSSL30: "SSL3.0",
|
||||||
tls.VersionTLS10: "TLS1.0",
|
tls.VersionTLS10: "TLS1.0",
|
||||||
tls.VersionTLS11: "TLS1.1",
|
tls.VersionTLS11: "TLS1.1",
|
||||||
tls.VersionTLS12: "TLS1.2",
|
tls.VersionTLS12: "TLS1.2",
|
||||||
|
|
22
error.go
22
error.go
|
@ -1,22 +0,0 @@
|
||||||
package smtpd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
// Error represents an Error reported in the SMTP session.
|
|
||||||
type Error struct {
|
|
||||||
Code int // The integer error code
|
|
||||||
Message string // The error message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error returns a string representation of the SMTP error
|
|
||||||
func (e Error) Error() string {
|
|
||||||
return fmt.Sprintf("%d %s", e.Code, e.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrServerClosed is returned by the Server's Serve and ListenAndServe,
|
|
||||||
// methods after a call to Shutdown.
|
|
||||||
var ErrServerClosed = errors.New("smtp: Server closed")
|
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"net/smtp"
|
"net/smtp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.jmbit.de/jmb/smtpd"
|
"github.com/chrj/smtpd"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ExampleServer() {
|
func ExampleServer() {
|
||||||
|
|
|
@ -4,6 +4,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"flag"
|
"flag"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net/smtp"
|
"net/smtp"
|
||||||
|
|
||||||
|
@ -59,7 +60,7 @@ func main() {
|
||||||
log.Fatalf("DKIM configuration error: %v", err)
|
log.Fatalf("DKIM configuration error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
privKey, err = io.ReadFile(*privKeyFile)
|
privKey, err = ioutil.ReadFile(*privKeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Couldn't read private key: %v", err)
|
log.Fatalf("Couldn't read private key: %v", err)
|
||||||
}
|
}
|
6
go.mod
6
go.mod
|
@ -1,3 +1,5 @@
|
||||||
module git.jmbit.de/jmb/smtpd
|
module github.com/chrj/smtpd
|
||||||
|
|
||||||
go 1.24
|
go 1.14
|
||||||
|
|
||||||
|
require github.com/eaigner/dkim v0.0.0-20150301120808-6fe4a7ee9cfb
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
package smtpd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// onceCloseListener wraps a net.Listener, protecting it from
|
|
||||||
// multiple Close calls.
|
|
||||||
type onceCloseListener struct {
|
|
||||||
net.Listener
|
|
||||||
once sync.Once
|
|
||||||
closeErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (oc *onceCloseListener) Close() error {
|
|
||||||
oc.once.Do(oc.close)
|
|
||||||
return oc.closeErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
|
|
||||||
|
|
||||||
type atomicBool int32
|
|
||||||
|
|
||||||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
|
||||||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
|
||||||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -287,7 +288,6 @@ func (session *session) handleRCPT(cmd command) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *session) handleSTARTTLS(cmd command) {
|
func (session *session) handleSTARTTLS(cmd command) {
|
||||||
_ = cmd
|
|
||||||
|
|
||||||
if session.tls {
|
if session.tls {
|
||||||
session.reply(502, "Already running in TLS")
|
session.reply(502, "Already running in TLS")
|
||||||
|
@ -334,7 +334,6 @@ func (session *session) handleSTARTTLS(cmd command) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *session) handleDATA(cmd command) {
|
func (session *session) handleDATA(cmd command) {
|
||||||
_ = cmd
|
|
||||||
|
|
||||||
if session.envelope == nil || len(session.envelope.Recipients) == 0 {
|
if session.envelope == nil || len(session.envelope.Recipients) == 0 {
|
||||||
session.reply(502, "Missing RCPT TO command.")
|
session.reply(502, "Missing RCPT TO command.")
|
||||||
|
@ -372,7 +371,7 @@ func (session *session) handleDATA(cmd command) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Discard the rest and report an error.
|
// Discard the rest and report an error.
|
||||||
_, err = io.Copy(io.Discard, reader)
|
_, err = io.Copy(ioutil.Discard, reader)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Network error, ignore
|
// Network error, ignore
|
||||||
|
@ -391,20 +390,17 @@ func (session *session) handleDATA(cmd command) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *session) handleRSET(cmd command) {
|
func (session *session) handleRSET(cmd command) {
|
||||||
_ = cmd
|
|
||||||
session.reset()
|
session.reset()
|
||||||
session.reply(250, "Go ahead")
|
session.reply(250, "Go ahead")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *session) handleNOOP(cmd command) {
|
func (session *session) handleNOOP(cmd command) {
|
||||||
_ = cmd
|
|
||||||
session.reply(250, "Go ahead")
|
session.reply(250, "Go ahead")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *session) handleQUIT(cmd command) {
|
func (session *session) handleQUIT(cmd command) {
|
||||||
_ = cmd
|
|
||||||
session.reply(221, "OK, bye")
|
session.reply(221, "OK, bye")
|
||||||
session.close()
|
session.close()
|
||||||
return
|
return
|
||||||
|
|
162
session.go
162
session.go
|
@ -1,162 +0,0 @@
|
||||||
package smtpd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"bufio"
|
|
||||||
"strings"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type session struct {
|
|
||||||
server *Server
|
|
||||||
|
|
||||||
peer Peer
|
|
||||||
envelope *Envelope
|
|
||||||
|
|
||||||
conn net.Conn
|
|
||||||
|
|
||||||
reader *bufio.Reader
|
|
||||||
writer *bufio.Writer
|
|
||||||
scanner *bufio.Scanner
|
|
||||||
|
|
||||||
tls bool
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
func (session *session) serve() {
|
|
||||||
|
|
||||||
defer session.close()
|
|
||||||
|
|
||||||
if !session.server.EnableProxyProtocol {
|
|
||||||
session.welcome()
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
|
|
||||||
for session.scanner.Scan() {
|
|
||||||
line := session.scanner.Text()
|
|
||||||
session.logf("received: %s", strings.TrimSpace(line))
|
|
||||||
session.handle(line)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := session.scanner.Err()
|
|
||||||
|
|
||||||
if err == bufio.ErrTooLong {
|
|
||||||
|
|
||||||
session.reply(500, "Line too long")
|
|
||||||
|
|
||||||
// Advance reader to the next newline
|
|
||||||
|
|
||||||
session.reader.ReadString('\n')
|
|
||||||
session.scanner = bufio.NewScanner(session.reader)
|
|
||||||
|
|
||||||
// Reset and have the client start over.
|
|
||||||
|
|
||||||
session.reset()
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) reject() {
|
|
||||||
session.reply(421, "Too busy. Try again later.")
|
|
||||||
session.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) reset() {
|
|
||||||
session.envelope = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) welcome() {
|
|
||||||
|
|
||||||
if session.server.ConnectionChecker != nil {
|
|
||||||
err := session.server.ConnectionChecker(session.peer)
|
|
||||||
if err != nil {
|
|
||||||
session.error(err)
|
|
||||||
session.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
session.reply(220, session.server.WelcomeMessage)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) reply(code int, message string) {
|
|
||||||
session.logf("sending: %d %s", code, message)
|
|
||||||
fmt.Fprintf(session.writer, "%d %s\r\n", code, message)
|
|
||||||
session.flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) flush() {
|
|
||||||
session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout))
|
|
||||||
session.writer.Flush()
|
|
||||||
session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) error(err error) {
|
|
||||||
if smtpdError, ok := err.(Error); ok {
|
|
||||||
session.reply(smtpdError.Code, smtpdError.Message)
|
|
||||||
} else {
|
|
||||||
session.reply(502, fmt.Sprintf("%s", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) logf(format string, v ...interface{}) {
|
|
||||||
if session.server.ProtocolLogger == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
session.server.ProtocolLogger.Info(fmt.Sprintf(
|
|
||||||
"%s [peer:%s]",
|
|
||||||
fmt.Sprintf(format, v...),
|
|
||||||
session.peer.Addr,
|
|
||||||
))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) logError(err error, desc string) {
|
|
||||||
session.server.ProtocolLogger.Error(desc, "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) extensions() []string {
|
|
||||||
|
|
||||||
extensions := []string{
|
|
||||||
fmt.Sprintf("SIZE %d", session.server.MaxMessageSize),
|
|
||||||
"8BITMIME",
|
|
||||||
"PIPELINING",
|
|
||||||
}
|
|
||||||
|
|
||||||
if session.server.EnableXCLIENT {
|
|
||||||
extensions = append(extensions, "XCLIENT")
|
|
||||||
}
|
|
||||||
|
|
||||||
if session.server.TLSConfig != nil && !session.tls {
|
|
||||||
extensions = append(extensions, "STARTTLS")
|
|
||||||
}
|
|
||||||
|
|
||||||
if session.server.Authenticator != nil && session.tls {
|
|
||||||
extensions = append(extensions, "AUTH PLAIN LOGIN")
|
|
||||||
}
|
|
||||||
|
|
||||||
return extensions
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) deliver() error {
|
|
||||||
if session.server.Handler != nil {
|
|
||||||
return session.server.Handler(session.peer, *session.envelope)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (session *session) close() {
|
|
||||||
session.writer.Flush()
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
session.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
263
smtpd.go
263
smtpd.go
|
@ -4,12 +4,10 @@ package smtpd
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -50,15 +48,7 @@ type Server struct {
|
||||||
TLSConfig *tls.Config // Enable STARTTLS support.
|
TLSConfig *tls.Config // Enable STARTTLS support.
|
||||||
ForceTLS bool // Force STARTTLS usage.
|
ForceTLS bool // Force STARTTLS usage.
|
||||||
|
|
||||||
ProtocolLogger *slog.Logger
|
ProtocolLogger *log.Logger
|
||||||
|
|
||||||
// mu guards doneChan and makes closing it and listener atomic from
|
|
||||||
// perspective of Serve()
|
|
||||||
mu sync.Mutex
|
|
||||||
doneChan chan struct{}
|
|
||||||
listener *net.Listener
|
|
||||||
waitgrp sync.WaitGroup
|
|
||||||
inShutdown atomicBool // true when server is in shutdown
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Protocol represents the protocol used in the SMTP session
|
// Protocol represents the protocol used in the SMTP session
|
||||||
|
@ -83,6 +73,30 @@ type Peer struct {
|
||||||
TLS *tls.ConnectionState // TLS Connection details, if on TLS
|
TLS *tls.ConnectionState // TLS Connection details, if on TLS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Error represents an Error reported in the SMTP session.
|
||||||
|
type Error struct {
|
||||||
|
Code int // The integer error code
|
||||||
|
Message string // The error message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a string representation of the SMTP error
|
||||||
|
func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) }
|
||||||
|
|
||||||
|
type session struct {
|
||||||
|
server *Server
|
||||||
|
|
||||||
|
peer Peer
|
||||||
|
envelope *Envelope
|
||||||
|
|
||||||
|
conn net.Conn
|
||||||
|
|
||||||
|
reader *bufio.Reader
|
||||||
|
writer *bufio.Writer
|
||||||
|
scanner *bufio.Scanner
|
||||||
|
|
||||||
|
tls bool
|
||||||
|
}
|
||||||
|
|
||||||
func (srv *Server) newSession(c net.Conn) (s *session) {
|
func (srv *Server) newSession(c net.Conn) (s *session) {
|
||||||
|
|
||||||
s = &session{
|
s = &session{
|
||||||
|
@ -120,9 +134,6 @@ func (srv *Server) newSession(c net.Conn) (s *session) {
|
||||||
|
|
||||||
// ListenAndServe starts the SMTP server and listens on the address provided
|
// ListenAndServe starts the SMTP server and listens on the address provided
|
||||||
func (srv *Server) ListenAndServe(addr string) error {
|
func (srv *Server) ListenAndServe(addr string) error {
|
||||||
if srv.shuttingDown() {
|
|
||||||
return ErrServerClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
srv.configureDefaults()
|
srv.configureDefaults()
|
||||||
|
|
||||||
|
@ -136,32 +147,24 @@ func (srv *Server) ListenAndServe(addr string) error {
|
||||||
|
|
||||||
// Serve starts the SMTP server and listens on the Listener provided
|
// Serve starts the SMTP server and listens on the Listener provided
|
||||||
func (srv *Server) Serve(l net.Listener) error {
|
func (srv *Server) Serve(l net.Listener) error {
|
||||||
if srv.shuttingDown() {
|
|
||||||
return ErrServerClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
srv.configureDefaults()
|
srv.configureDefaults()
|
||||||
|
|
||||||
l = &onceCloseListener{Listener: l}
|
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
srv.listener = &l
|
|
||||||
|
|
||||||
var limiter chan struct{}
|
var limiter chan struct{}
|
||||||
|
|
||||||
if srv.MaxConnections > 0 {
|
if srv.MaxConnections > 0 {
|
||||||
limiter = make(chan struct{}, srv.MaxConnections)
|
limiter = make(chan struct{}, srv.MaxConnections)
|
||||||
|
} else {
|
||||||
|
limiter = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
||||||
conn, e := l.Accept()
|
conn, e := l.Accept()
|
||||||
if e != nil {
|
if e != nil {
|
||||||
select {
|
if ne, ok := e.(net.Error); ok && ne.Temporary() {
|
||||||
case <-srv.getDoneChan():
|
|
||||||
return ErrServerClosed
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if ne, ok := e.(net.Error); ok && ne.Timeout() {
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -170,10 +173,8 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||||
|
|
||||||
session := srv.newSession(conn)
|
session := srv.newSession(conn)
|
||||||
|
|
||||||
srv.waitgrp.Add(1)
|
if limiter != nil {
|
||||||
go func() {
|
go func() {
|
||||||
defer srv.waitgrp.Done()
|
|
||||||
if limiter != nil {
|
|
||||||
select {
|
select {
|
||||||
case limiter <- struct{}{}:
|
case limiter <- struct{}{}:
|
||||||
session.serve()
|
session.serve()
|
||||||
|
@ -181,53 +182,15 @@ func (srv *Server) Serve(l net.Listener) error {
|
||||||
default:
|
default:
|
||||||
session.reject()
|
session.reject()
|
||||||
}
|
}
|
||||||
} else {
|
}()
|
||||||
session.serve()
|
} else {
|
||||||
}
|
go session.serve()
|
||||||
}()
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown instructs the server to shutdown, starting by closing the
|
|
||||||
// associated listener. If wait is true, it will wait for the shutdown
|
|
||||||
// to complete. If wait is false, Wait must be called afterwards.
|
|
||||||
func (srv *Server) Shutdown(wait bool) error {
|
|
||||||
var lnerr error
|
|
||||||
srv.inShutdown.setTrue()
|
|
||||||
|
|
||||||
// First close the listener
|
|
||||||
srv.mu.Lock()
|
|
||||||
if srv.listener != nil {
|
|
||||||
lnerr = (*srv.listener).Close();
|
|
||||||
}
|
|
||||||
srv.closeDoneChanLocked()
|
|
||||||
srv.mu.Unlock()
|
|
||||||
|
|
||||||
// Now wait for all client connections to close
|
|
||||||
if wait {
|
|
||||||
srv.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
return lnerr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait waits for all client connections to close and the server to finish
|
|
||||||
// shutting down.
|
|
||||||
func (srv *Server) Wait() error {
|
|
||||||
if !srv.shuttingDown() {
|
|
||||||
return errors.New("Server has not been Shutdown")
|
|
||||||
}
|
|
||||||
|
|
||||||
srv.waitgrp.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address returns the listening address of the server
|
|
||||||
func (srv *Server) Address() net.Addr {
|
|
||||||
return (*srv.listener).Addr();
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) configureDefaults() {
|
func (srv *Server) configureDefaults() {
|
||||||
|
|
||||||
if srv.MaxMessageSize == 0 {
|
if srv.MaxMessageSize == 0 {
|
||||||
|
@ -268,34 +231,138 @@ func (srv *Server) configureDefaults() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// From net/http/server.go
|
func (session *session) serve() {
|
||||||
|
|
||||||
func (s *Server) shuttingDown() bool {
|
defer session.close()
|
||||||
return s.inShutdown.isSet()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getDoneChan() <-chan struct{} {
|
if !session.server.EnableProxyProtocol {
|
||||||
s.mu.Lock()
|
session.welcome()
|
||||||
defer s.mu.Unlock()
|
|
||||||
return s.getDoneChanLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getDoneChanLocked() chan struct{} {
|
|
||||||
if s.doneChan == nil {
|
|
||||||
s.doneChan = make(chan struct{})
|
|
||||||
}
|
}
|
||||||
return s.doneChan
|
|
||||||
|
for {
|
||||||
|
|
||||||
|
for session.scanner.Scan() {
|
||||||
|
line := session.scanner.Text()
|
||||||
|
session.logf("received: %s", strings.TrimSpace(line))
|
||||||
|
session.handle(line)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := session.scanner.Err()
|
||||||
|
|
||||||
|
if err == bufio.ErrTooLong {
|
||||||
|
|
||||||
|
session.reply(500, "Line too long")
|
||||||
|
|
||||||
|
// Advance reader to the next newline
|
||||||
|
|
||||||
|
session.reader.ReadString('\n')
|
||||||
|
session.scanner = bufio.NewScanner(session.reader)
|
||||||
|
|
||||||
|
// Reset and have the client start over.
|
||||||
|
|
||||||
|
session.reset()
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) closeDoneChanLocked() {
|
func (session *session) reject() {
|
||||||
ch := s.getDoneChanLocked()
|
session.reply(421, "Too busy. Try again later.")
|
||||||
select {
|
session.close()
|
||||||
case <-ch:
|
}
|
||||||
// Already closed. Don't close again.
|
|
||||||
default:
|
func (session *session) reset() {
|
||||||
// Safe to close here. We're the only closer, guarded
|
session.envelope = nil
|
||||||
// by s.mu.
|
}
|
||||||
close(ch)
|
|
||||||
|
func (session *session) welcome() {
|
||||||
|
|
||||||
|
if session.server.ConnectionChecker != nil {
|
||||||
|
err := session.server.ConnectionChecker(session.peer)
|
||||||
|
if err != nil {
|
||||||
|
session.error(err)
|
||||||
|
session.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
session.reply(220, session.server.WelcomeMessage)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *session) reply(code int, message string) {
|
||||||
|
session.logf("sending: %d %s", code, message)
|
||||||
|
fmt.Fprintf(session.writer, "%d %s\r\n", code, message)
|
||||||
|
session.flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *session) flush() {
|
||||||
|
session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout))
|
||||||
|
session.writer.Flush()
|
||||||
|
session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *session) error(err error) {
|
||||||
|
if smtpdError, ok := err.(Error); ok {
|
||||||
|
session.reply(smtpdError.Code, smtpdError.Message)
|
||||||
|
} else {
|
||||||
|
session.reply(502, fmt.Sprintf("%s", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (session *session) logf(format string, v ...interface{}) {
|
||||||
|
if session.server.ProtocolLogger == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session.server.ProtocolLogger.Output(2, fmt.Sprintf(
|
||||||
|
"%s [peer:%s]",
|
||||||
|
fmt.Sprintf(format, v...),
|
||||||
|
session.peer.Addr,
|
||||||
|
))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *session) logError(err error, desc string) {
|
||||||
|
session.logf("%s: %v ", desc, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *session) extensions() []string {
|
||||||
|
|
||||||
|
extensions := []string{
|
||||||
|
fmt.Sprintf("SIZE %d", session.server.MaxMessageSize),
|
||||||
|
"8BITMIME",
|
||||||
|
"PIPELINING",
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.server.EnableXCLIENT {
|
||||||
|
extensions = append(extensions, "XCLIENT")
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.server.TLSConfig != nil && !session.tls {
|
||||||
|
extensions = append(extensions, "STARTTLS")
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.server.Authenticator != nil && session.tls {
|
||||||
|
extensions = append(extensions, "AUTH PLAIN LOGIN")
|
||||||
|
}
|
||||||
|
|
||||||
|
return extensions
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *session) deliver() error {
|
||||||
|
if session.server.Handler != nil {
|
||||||
|
return session.server.Handler(session.peer, *session.envelope)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *session) close() {
|
||||||
|
session.writer.Flush()
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
session.conn.Close()
|
||||||
|
}
|
||||||
|
|
184
smtpd_test.go
184
smtpd_test.go
|
@ -5,7 +5,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/smtp"
|
"net/smtp"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
|
@ -14,7 +14,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.jmbit.de/jmb/smtpd"
|
"github.com/chrj/smtpd"
|
||||||
)
|
)
|
||||||
|
|
||||||
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
|
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
|
||||||
|
@ -158,7 +158,7 @@ func runsslserver(t *testing.T, server *smtpd.Server) (addr string, closer func(
|
||||||
func TestSMTP(t *testing.T) {
|
func TestSMTP(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
defer closer()
|
defer closer()
|
||||||
|
|
||||||
|
@ -233,7 +233,7 @@ func TestListenAndServe(t *testing.T) {
|
||||||
closer()
|
closer()
|
||||||
|
|
||||||
server := &smtpd.Server{
|
server := &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -258,7 +258,7 @@ func TestSTARTTLS(t *testing.T) {
|
||||||
addr, closer := runsslserver(t, &smtpd.Server{
|
addr, closer := runsslserver(t, &smtpd.Server{
|
||||||
Authenticator: func(peer smtpd.Peer, username, password string) error { return nil },
|
Authenticator: func(peer smtpd.Peer, username, password string) error { return nil },
|
||||||
ForceTLS: true,
|
ForceTLS: true,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -347,7 +347,7 @@ func TestAuthRejection(t *testing.T) {
|
||||||
return smtpd.Error{Code: 550, Message: "Denied"}
|
return smtpd.Error{Code: 550, Message: "Denied"}
|
||||||
},
|
},
|
||||||
ForceTLS: true,
|
ForceTLS: true,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -371,7 +371,7 @@ func TestAuthNotSupported(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runsslserver(t, &smtpd.Server{
|
addr, closer := runsslserver(t, &smtpd.Server{
|
||||||
ForceTLS: true,
|
ForceTLS: true,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -398,7 +398,7 @@ func TestAuthBypass(t *testing.T) {
|
||||||
return smtpd.Error{Code: 550, Message: "Denied"}
|
return smtpd.Error{Code: 550, Message: "Denied"}
|
||||||
},
|
},
|
||||||
ForceTLS: true,
|
ForceTLS: true,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -424,7 +424,7 @@ func TestConnectionCheck(t *testing.T) {
|
||||||
ConnectionChecker: func(peer smtpd.Peer) error {
|
ConnectionChecker: func(peer smtpd.Peer) error {
|
||||||
return smtpd.Error{Code: 552, Message: "Denied"}
|
return smtpd.Error{Code: 552, Message: "Denied"}
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -441,7 +441,7 @@ func TestConnectionCheckSimpleError(t *testing.T) {
|
||||||
ConnectionChecker: func(peer smtpd.Peer) error {
|
ConnectionChecker: func(peer smtpd.Peer) error {
|
||||||
return errors.New("Denied")
|
return errors.New("Denied")
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -461,7 +461,7 @@ func TestHELOCheck(t *testing.T) {
|
||||||
}
|
}
|
||||||
return smtpd.Error{Code: 552, Message: "Denied"}
|
return smtpd.Error{Code: 552, Message: "Denied"}
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -483,7 +483,7 @@ func TestSenderCheck(t *testing.T) {
|
||||||
SenderChecker: func(peer smtpd.Peer, addr string) error {
|
SenderChecker: func(peer smtpd.Peer, addr string) error {
|
||||||
return smtpd.Error{Code: 552, Message: "Denied"}
|
return smtpd.Error{Code: 552, Message: "Denied"}
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -505,7 +505,7 @@ func TestRecipientCheck(t *testing.T) {
|
||||||
RecipientChecker: func(peer smtpd.Peer, addr string) error {
|
RecipientChecker: func(peer smtpd.Peer, addr string) error {
|
||||||
return smtpd.Error{Code: 552, Message: "Denied"}
|
return smtpd.Error{Code: 552, Message: "Denied"}
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -529,7 +529,7 @@ func TestMaxMessageSize(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
MaxMessageSize: 5,
|
MaxMessageSize: 5,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -586,7 +586,7 @@ func TestHandler(t *testing.T) {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -631,7 +631,7 @@ func TestRejectHandler(t *testing.T) {
|
||||||
Handler: func(peer smtpd.Peer, env smtpd.Envelope) error {
|
Handler: func(peer smtpd.Peer, env smtpd.Envelope) error {
|
||||||
return smtpd.Error{Code: 550, Message: "Rejected"}
|
return smtpd.Error{Code: 550, Message: "Rejected"}
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -674,7 +674,7 @@ func TestMaxConnections(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
MaxConnections: 1,
|
MaxConnections: 1,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -696,7 +696,7 @@ func TestNoMaxConnections(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
MaxConnections: -1,
|
MaxConnections: -1,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -713,7 +713,7 @@ func TestMaxRecipients(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
MaxRecipients: 1,
|
MaxRecipients: 1,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -744,7 +744,7 @@ func TestMaxRecipients(t *testing.T) {
|
||||||
func TestInvalidHelo(t *testing.T) {
|
func TestInvalidHelo(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -763,7 +763,7 @@ func TestInvalidHelo(t *testing.T) {
|
||||||
func TestInvalidSender(t *testing.T) {
|
func TestInvalidSender(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -782,7 +782,7 @@ func TestInvalidSender(t *testing.T) {
|
||||||
func TestInvalidRecipient(t *testing.T) {
|
func TestInvalidRecipient(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -805,7 +805,7 @@ func TestInvalidRecipient(t *testing.T) {
|
||||||
func TestRCPTbeforeMAIL(t *testing.T) {
|
func TestRCPTbeforeMAIL(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -824,7 +824,7 @@ func TestRCPTbeforeMAIL(t *testing.T) {
|
||||||
func TestDATAbeforeRCPT(t *testing.T) {
|
func TestDATAbeforeRCPT(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -855,7 +855,7 @@ func TestInterruptedDATA(t *testing.T) {
|
||||||
t.Fatal("Accepted DATA despite disconnection")
|
t.Fatal("Accepted DATA despite disconnection")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -893,7 +893,7 @@ func TestTimeoutClose(t *testing.T) {
|
||||||
MaxConnections: 1,
|
MaxConnections: 1,
|
||||||
ReadTimeout: time.Second,
|
ReadTimeout: time.Second,
|
||||||
WriteTimeout: time.Second,
|
WriteTimeout: time.Second,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -930,7 +930,7 @@ func TestTLSTimeout(t *testing.T) {
|
||||||
addr, closer := runsslserver(t, &smtpd.Server{
|
addr, closer := runsslserver(t, &smtpd.Server{
|
||||||
ReadTimeout: time.Second * 2,
|
ReadTimeout: time.Second * 2,
|
||||||
WriteTimeout: time.Second * 2,
|
WriteTimeout: time.Second * 2,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -973,7 +973,7 @@ func TestTLSTimeout(t *testing.T) {
|
||||||
func TestLongLine(t *testing.T) {
|
func TestLongLine(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -1012,7 +1012,7 @@ func TestXCLIENT(t *testing.T) {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -1076,7 +1076,7 @@ func TestEnvelopeReceived(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ForceTLS: true,
|
ForceTLS: true,
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -1122,7 +1122,7 @@ func TestEnvelopeReceived(t *testing.T) {
|
||||||
func TestHELO(t *testing.T) {
|
func TestHELO(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -1158,7 +1158,7 @@ func TestLOGINAuth(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runsslserver(t, &smtpd.Server{
|
addr, closer := runsslserver(t, &smtpd.Server{
|
||||||
Authenticator: func(peer smtpd.Peer, username, password string) error { return nil },
|
Authenticator: func(peer smtpd.Peer, username, password string) error { return nil },
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -1213,7 +1213,7 @@ func TestLOGINAuth(t *testing.T) {
|
||||||
func TestNullSender(t *testing.T) {
|
func TestNullSender(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -1240,7 +1240,7 @@ func TestNullSender(t *testing.T) {
|
||||||
func TestNoBracketsSender(t *testing.T) {
|
func TestNoBracketsSender(t *testing.T) {
|
||||||
|
|
||||||
addr, closer := runserver(t, &smtpd.Server{
|
addr, closer := runserver(t, &smtpd.Server{
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -1273,7 +1273,7 @@ func TestErrors(t *testing.T) {
|
||||||
|
|
||||||
server := &smtpd.Server{
|
server := &smtpd.Server{
|
||||||
Authenticator: func(peer smtpd.Peer, username, password string) error { return nil },
|
Authenticator: func(peer smtpd.Peer, username, password string) error { return nil },
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, closer := runserver(t, server)
|
addr, closer := runserver(t, server)
|
||||||
|
@ -1356,7 +1356,7 @@ func TestMailformedMAILFROM(t *testing.T) {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
})
|
})
|
||||||
|
|
||||||
defer closer()
|
defer closer()
|
||||||
|
@ -1402,7 +1402,7 @@ func TestTLSListener(t *testing.T) {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
ProtocolLogger: log.New(os.Stdout, "log: ", log.Lshortfile),
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -1436,113 +1436,3 @@ func TestTLSListener(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShutdown(t *testing.T) {
|
|
||||||
fmt.Println("Starting test")
|
|
||||||
server := &smtpd.Server{
|
|
||||||
ProtocolLogger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
|
|
||||||
}
|
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Listen failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srvres := make(chan error)
|
|
||||||
go func() {
|
|
||||||
t.Log("Starting server")
|
|
||||||
srvres <- server.Serve(ln)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Connect a client
|
|
||||||
c, err := smtp.Dial(ln.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Dial failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Hello("localhost"); err != nil {
|
|
||||||
t.Fatalf("HELO failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// While the client connection is open, shut down the server (without
|
|
||||||
// waiting for it to finish)
|
|
||||||
err = server.Shutdown(false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Shutdown returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify that Shutdown() worked by attempting to connect another client
|
|
||||||
_, err = smtp.Dial(ln.Addr().String())
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("Dial did not fail as expected")
|
|
||||||
}
|
|
||||||
if _, typok := err.(*net.OpError); !typok {
|
|
||||||
t.Fatalf("Dial did not return net.OpError as expected: %v (%T)", err, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for shutdown to complete
|
|
||||||
shutres := make(chan error)
|
|
||||||
go func() {
|
|
||||||
t.Log("Waiting for server shutdown to finish")
|
|
||||||
shutres <- server.Wait()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Slight delay to ensure Shutdown() blocks
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
|
|
||||||
// Wait() should not have returned yet due to open client conn
|
|
||||||
select {
|
|
||||||
case shuterr := <-shutres:
|
|
||||||
t.Fatalf("Wait() returned early w/ error: %v", shuterr)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now close the client
|
|
||||||
t.Log("Closing client connection")
|
|
||||||
if err := c.Quit(); err != nil {
|
|
||||||
t.Fatalf("QUIT failed: %v", err)
|
|
||||||
}
|
|
||||||
c.Close()
|
|
||||||
|
|
||||||
// Wait for Wait() to return
|
|
||||||
t.Log("Waiting for Wait() to return")
|
|
||||||
select {
|
|
||||||
case shuterr := <-shutres:
|
|
||||||
if shuterr != nil {
|
|
||||||
t.Fatalf("Wait() returned error: %v", shuterr)
|
|
||||||
}
|
|
||||||
case <-time.After(15 * time.Second):
|
|
||||||
t.Fatalf("Timed out waiting for Wait() to return")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for Serve() to return
|
|
||||||
t.Log("Waiting for Serve() to return")
|
|
||||||
select {
|
|
||||||
case srverr := <-srvres:
|
|
||||||
if srverr != smtpd.ErrServerClosed {
|
|
||||||
t.Fatalf("Serve() returned error: %v", srverr)
|
|
||||||
}
|
|
||||||
case <-time.After(15 * time.Second):
|
|
||||||
t.Fatalf("Timed out waiting for Serve() to return")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServeFailsIfShutdown(t *testing.T) {
|
|
||||||
server := &smtpd.Server{}
|
|
||||||
err := server.Shutdown(true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Shutdown() failed: %v", err)
|
|
||||||
}
|
|
||||||
err = server.Serve(nil)
|
|
||||||
if err != smtpd.ErrServerClosed {
|
|
||||||
t.Fatalf("Serve() did not return ErrServerClosed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWaitFailsIfNotShutdown(t *testing.T) {
|
|
||||||
server := &smtpd.Server{}
|
|
||||||
err := server.Wait()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("Wait() did not fail as expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue