summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/miekg/dns/server.go')
-rw-r--r--vendor/github.com/miekg/dns/server.go302
1 files changed, 217 insertions, 85 deletions
diff --git a/vendor/github.com/miekg/dns/server.go b/vendor/github.com/miekg/dns/server.go
index 2d98f1488..2901f8724 100644
--- a/vendor/github.com/miekg/dns/server.go
+++ b/vendor/github.com/miekg/dns/server.go
@@ -4,10 +4,13 @@ package dns
import (
"bytes"
+ "context"
"crypto/tls"
"encoding/binary"
+ "errors"
"io"
"net"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -16,11 +19,22 @@ import (
// Default maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128
-// Interval for stop worker if no load
+// The maximum number of idle workers.
+//
+// This controls the maximum number of workers that are allowed to stay
+// idle waiting for incoming requests before being torn down.
+//
+// If this limit is reached, the server will just keep spawning new
+// workers (goroutines) for each incoming request. In this case, each
+// worker will only be used for a single request.
+const maxIdleWorkersCount = 10000
+
+// The maximum length of time a worker may idle for before being destroyed.
const idleWorkerTimeout = 10 * time.Second
-// Maximum number of workers
-const maxWorkersCount = 10000
+// aLongTimeAgo is a non-zero time, far in the past, used for
+// immediate cancelation of network operations.
+var aLongTimeAgo = time.Unix(1, 0)
// Handler is implemented by any value that implements ServeDNS.
type Handler interface {
@@ -49,6 +63,12 @@ type ResponseWriter interface {
Hijack()
}
+// A ConnectionStater interface is used by a DNS Handler to access TLS connection state
+// when available.
+type ConnectionStater interface {
+ ConnectionState() *tls.ConnectionState
+}
+
type response struct {
msg []byte
hijacked bool // connection has been hijacked by handler
@@ -60,6 +80,7 @@ type response struct {
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
writer Writer // writer to output the raw DNS bits
+ wg *sync.WaitGroup // for gracefull shutdown
}
// ServeMux is an DNS request multiplexer. It matches the
@@ -151,7 +172,7 @@ func (mux *ServeMux) match(q string, t uint16) Handler {
for i := 0; i < l; i++ {
b[i] = q[off+i]
if b[i] >= 'A' && b[i] <= 'Z' {
- b[i] |= ('a' - 'A')
+ b[i] |= 'a' - 'A'
}
}
if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key
@@ -305,14 +326,30 @@ type Server struct {
DecorateWriter DecorateWriter
// Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
MaxTCPQueries int
+ // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
+ // It is only supported on go1.11+ and when using ListenAndServe.
+ ReusePort bool
// UDP packet or TCP connection queue
queue chan *response
// Workers count
workersCount int32
+
// Shutdown handling
- lock sync.RWMutex
- started bool
+ lock sync.RWMutex
+ started bool
+ shutdown chan struct{}
+ conns map[net.Conn]struct{}
+
+ // A pool for UDP message buffers.
+ udpPool sync.Pool
+}
+
+func (srv *Server) isStarted() bool {
+ srv.lock.RLock()
+ started := srv.started
+ srv.lock.RUnlock()
+ return started
}
func (srv *Server) worker(w *response) {
@@ -320,7 +357,7 @@ func (srv *Server) worker(w *response) {
for {
count := atomic.LoadInt32(&srv.workersCount)
- if count > maxWorkersCount {
+ if count > maxIdleWorkersCount {
return
}
if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
@@ -360,10 +397,36 @@ func (srv *Server) spawnWorker(w *response) {
}
}
+func makeUDPBuffer(size int) func() interface{} {
+ return func() interface{} {
+ return make([]byte, size)
+ }
+}
+
+func (srv *Server) init() {
+ srv.queue = make(chan *response)
+
+ srv.shutdown = make(chan struct{})
+ srv.conns = make(map[net.Conn]struct{})
+
+ if srv.UDPSize == 0 {
+ srv.UDPSize = MinMsgSize
+ }
+
+ srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
+}
+
+func unlockOnce(l sync.Locker) func() {
+ var once sync.Once
+ return func() { once.Do(l.Unlock) }
+}
+
// ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error {
+ unlock := unlockOnce(&srv.lock)
srv.lock.Lock()
- defer srv.lock.Unlock()
+ defer unlock()
+
if srv.started {
return &Error{err: "server already started"}
}
@@ -372,63 +435,47 @@ func (srv *Server) ListenAndServe() error {
if addr == "" {
addr = ":domain"
}
- if srv.UDPSize == 0 {
- srv.UDPSize = MinMsgSize
- }
- srv.queue = make(chan *response)
+
+ srv.init()
defer close(srv.queue)
+
switch srv.Net {
case "tcp", "tcp4", "tcp6":
- a, err := net.ResolveTCPAddr(srv.Net, addr)
- if err != nil {
- return err
- }
- l, err := net.ListenTCP(srv.Net, a)
+ l, err := listenTCP(srv.Net, addr, srv.ReusePort)
if err != nil {
return err
}
srv.Listener = l
srv.started = true
- srv.lock.Unlock()
- err = srv.serveTCP(l)
- srv.lock.Lock() // to satisfy the defer at the top
- return err
+ unlock()
+ return srv.serveTCP(l)
case "tcp-tls", "tcp4-tls", "tcp6-tls":
- network := "tcp"
- if srv.Net == "tcp4-tls" {
- network = "tcp4"
- } else if srv.Net == "tcp6-tls" {
- network = "tcp6"
+ if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
+ return errors.New("dns: neither Certificates nor GetCertificate set in Config")
}
-
- l, err := tls.Listen(network, addr, srv.TLSConfig)
+ network := strings.TrimSuffix(srv.Net, "-tls")
+ l, err := listenTCP(network, addr, srv.ReusePort)
if err != nil {
return err
}
+ l = tls.NewListener(l, srv.TLSConfig)
srv.Listener = l
srv.started = true
- srv.lock.Unlock()
- err = srv.serveTCP(l)
- srv.lock.Lock() // to satisfy the defer at the top
- return err
+ unlock()
+ return srv.serveTCP(l)
case "udp", "udp4", "udp6":
- a, err := net.ResolveUDPAddr(srv.Net, addr)
- if err != nil {
- return err
- }
- l, err := net.ListenUDP(srv.Net, a)
+ l, err := listenUDP(srv.Net, addr, srv.ReusePort)
if err != nil {
return err
}
- if e := setUDPSocketOptions(l); e != nil {
+ u := l.(*net.UDPConn)
+ if e := setUDPSocketOptions(u); e != nil {
return e
}
srv.PacketConn = l
srv.started = true
- srv.lock.Unlock()
- err = srv.serveUDP(l)
- srv.lock.Lock() // to satisfy the defer at the top
- return err
+ unlock()
+ return srv.serveUDP(u)
}
return &Error{err: "bad network"}
}
@@ -436,20 +483,20 @@ func (srv *Server) ListenAndServe() error {
// ActivateAndServe starts a nameserver with the PacketConn or Listener
// configured in *Server. Its main use is to start a server from systemd.
func (srv *Server) ActivateAndServe() error {
+ unlock := unlockOnce(&srv.lock)
srv.lock.Lock()
- defer srv.lock.Unlock()
+ defer unlock()
+
if srv.started {
return &Error{err: "server already started"}
}
+ srv.init()
+ defer close(srv.queue)
+
pConn := srv.PacketConn
l := srv.Listener
- srv.queue = make(chan *response)
- defer close(srv.queue)
if pConn != nil {
- if srv.UDPSize == 0 {
- srv.UDPSize = MinMsgSize
- }
// Check PacketConn interface's type is valid and value
// is not nil
if t, ok := pConn.(*net.UDPConn); ok && t != nil {
@@ -457,18 +504,14 @@ func (srv *Server) ActivateAndServe() error {
return e
}
srv.started = true
- srv.lock.Unlock()
- e := srv.serveUDP(t)
- srv.lock.Lock() // to satisfy the defer at the top
- return e
+ unlock()
+ return srv.serveUDP(t)
}
}
if l != nil {
srv.started = true
- srv.lock.Unlock()
- e := srv.serveTCP(l)
- srv.lock.Lock() // to satisfy the defer at the top
- return e
+ unlock()
+ return srv.serveTCP(l)
}
return &Error{err: "bad listeners"}
}
@@ -476,23 +519,58 @@ func (srv *Server) ActivateAndServe() error {
// Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and
// ActivateAndServe will return.
func (srv *Server) Shutdown() error {
+ return srv.ShutdownContext(context.Background())
+}
+
+// ShutdownContext shuts down a server. After a call to ShutdownContext,
+// ListenAndServe and ActivateAndServe will return.
+//
+// A context.Context may be passed to limit how long to wait for connections
+// to terminate.
+func (srv *Server) ShutdownContext(ctx context.Context) error {
srv.lock.Lock()
- if !srv.started {
- srv.lock.Unlock()
- return &Error{err: "server not started"}
- }
+ started := srv.started
srv.started = false
srv.lock.Unlock()
+ if !started {
+ return &Error{err: "server not started"}
+ }
+
if srv.PacketConn != nil {
- srv.PacketConn.Close()
+ srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
}
+
if srv.Listener != nil {
srv.Listener.Close()
}
- return nil
+
+ srv.lock.Lock()
+ for rw := range srv.conns {
+ rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
+ }
+ srv.lock.Unlock()
+
+ if testShutdownNotify != nil {
+ testShutdownNotify.Broadcast()
+ }
+
+ var ctxErr error
+ select {
+ case <-srv.shutdown:
+ case <-ctx.Done():
+ ctxErr = ctx.Err()
+ }
+
+ if srv.PacketConn != nil {
+ srv.PacketConn.Close()
+ }
+
+ return ctxErr
}
+var testShutdownNotify *sync.Cond
+
// getReadTimeout is a helper func to use system timeout if server did not intend to change it.
func (srv *Server) getReadTimeout() time.Duration {
rtimeout := dnsTimeout
@@ -510,22 +588,36 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc()
}
- for {
+ var wg sync.WaitGroup
+ defer func() {
+ wg.Wait()
+ close(srv.shutdown)
+ }()
+
+ for srv.isStarted() {
rw, err := l.Accept()
- srv.lock.RLock()
- if !srv.started {
- srv.lock.RUnlock()
- return nil
- }
- srv.lock.RUnlock()
if err != nil {
+ if !srv.isStarted() {
+ return nil
+ }
if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
continue
}
return err
}
- srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw})
+ srv.lock.Lock()
+ // Track the connection to allow unblocking reads on shutdown.
+ srv.conns[rw] = struct{}{}
+ srv.lock.Unlock()
+ wg.Add(1)
+ srv.spawnWorker(&response{
+ tsigSecret: srv.TsigSecret,
+ tcp: rw,
+ wg: &wg,
+ })
}
+
+ return nil
}
// serveUDP starts a UDP listener for the server.
@@ -541,27 +633,42 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
reader = srv.DecorateReader(reader)
}
+ var wg sync.WaitGroup
+ defer func() {
+ wg.Wait()
+ close(srv.shutdown)
+ }()
+
rtimeout := srv.getReadTimeout()
// deadline is not used here
- for {
+ for srv.isStarted() {
m, s, err := reader.ReadUDP(l, rtimeout)
- srv.lock.RLock()
- if !srv.started {
- srv.lock.RUnlock()
- return nil
- }
- srv.lock.RUnlock()
if err != nil {
+ if !srv.isStarted() {
+ return nil
+ }
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
continue
}
return err
}
if len(m) < headerSize {
+ if cap(m) == srv.UDPSize {
+ srv.udpPool.Put(m[:srv.UDPSize])
+ }
continue
}
- srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
+ wg.Add(1)
+ srv.spawnWorker(&response{
+ msg: m,
+ tsigSecret: srv.TsigSecret,
+ udp: l,
+ udpSession: s,
+ wg: &wg,
+ })
}
+
+ return nil
}
func (srv *Server) serve(w *response) {
@@ -574,20 +681,28 @@ func (srv *Server) serve(w *response) {
if w.udp != nil {
// serve UDP
srv.serveDNS(w)
- return
- }
- reader := Reader(&defaultReader{srv})
- if srv.DecorateReader != nil {
- reader = srv.DecorateReader(reader)
+ w.wg.Done()
+ return
}
defer func() {
if !w.hijacked {
w.Close()
}
+
+ srv.lock.Lock()
+ delete(srv.conns, w.tcp)
+ srv.lock.Unlock()
+
+ w.wg.Done()
}()
+ reader := Reader(&defaultReader{srv})
+ if srv.DecorateReader != nil {
+ reader = srv.DecorateReader(reader)
+ }
+
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
@@ -600,7 +715,7 @@ func (srv *Server) serve(w *response) {
limit = maxTCPQueries
}
- for q := 0; q < limit || limit == -1; q++ {
+ for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
var err error
w.msg, err = reader.ReadTCP(w.tcp, timeout)
if err != nil {
@@ -623,6 +738,10 @@ func (srv *Server) serve(w *response) {
func (srv *Server) serveDNS(w *response) {
req := new(Msg)
err := req.Unpack(w.msg)
+ if w.udp != nil && cap(w.msg) == srv.UDPSize {
+ srv.udpPool.Put(w.msg[:srv.UDPSize])
+ }
+ w.msg = nil
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
@@ -691,9 +810,10 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
conn.SetReadDeadline(time.Now().Add(timeout))
- m := make([]byte, srv.UDPSize)
+ m := srv.udpPool.Get().([]byte)
n, s, err := ReadFromSessionUDP(conn, m)
if err != nil {
+ srv.udpPool.Put(m)
return nil, nil, err
}
m = m[:n]
@@ -780,3 +900,15 @@ func (w *response) Close() error {
}
return nil
}
+
+// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
+func (w *response) ConnectionState() *tls.ConnectionState {
+ type tlsConnectionStater interface {
+ ConnectionState() tls.ConnectionState
+ }
+ if v, ok := w.tcp.(tlsConnectionStater); ok {
+ t := v.ConnectionState()
+ return &t
+ }
+ return nil
+}