summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/miekg/dns/server.go')
-rw-r--r--vendor/github.com/miekg/dns/server.go199
1 files changed, 131 insertions, 68 deletions
diff --git a/vendor/github.com/miekg/dns/server.go b/vendor/github.com/miekg/dns/server.go
index 685753f43..2d98f1488 100644
--- a/vendor/github.com/miekg/dns/server.go
+++ b/vendor/github.com/miekg/dns/server.go
@@ -9,12 +9,19 @@ import (
"io"
"net"
"sync"
+ "sync/atomic"
"time"
)
-// Maximum number of TCP queries before we close the socket.
+// Default maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128
+// Interval for stop worker if no load
+const idleWorkerTimeout = 10 * time.Second
+
+// Maximum number of workers
+const maxWorkersCount = 10000
+
// Handler is implemented by any value that implements ServeDNS.
type Handler interface {
ServeDNS(w ResponseWriter, r *Msg)
@@ -43,6 +50,7 @@ type ResponseWriter interface {
}
type response struct {
+ msg []byte
hijacked bool // connection has been hijacked by handler
tsigStatus error
tsigTimersOnly bool
@@ -51,7 +59,6 @@ type response struct {
udp *net.UDPConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
- remoteAddr net.Addr // address of the client
writer Writer // writer to output the raw DNS bits
}
@@ -296,12 +303,63 @@ type Server struct {
DecorateReader DecorateReader
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
DecorateWriter DecorateWriter
+ // Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
+ MaxTCPQueries int
+ // UDP packet or TCP connection queue
+ queue chan *response
+ // Workers count
+ workersCount int32
// Shutdown handling
lock sync.RWMutex
started bool
}
+func (srv *Server) worker(w *response) {
+ srv.serve(w)
+
+ for {
+ count := atomic.LoadInt32(&srv.workersCount)
+ if count > maxWorkersCount {
+ return
+ }
+ if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
+ break
+ }
+ }
+
+ defer atomic.AddInt32(&srv.workersCount, -1)
+
+ inUse := false
+ timeout := time.NewTimer(idleWorkerTimeout)
+ defer timeout.Stop()
+LOOP:
+ for {
+ select {
+ case w, ok := <-srv.queue:
+ if !ok {
+ break LOOP
+ }
+ inUse = true
+ srv.serve(w)
+ case <-timeout.C:
+ if !inUse {
+ break LOOP
+ }
+ inUse = false
+ timeout.Reset(idleWorkerTimeout)
+ }
+ }
+}
+
+func (srv *Server) spawnWorker(w *response) {
+ select {
+ case srv.queue <- w:
+ default:
+ go srv.worker(w)
+ }
+}
+
// ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error {
srv.lock.Lock()
@@ -309,6 +367,7 @@ func (srv *Server) ListenAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}
+
addr := srv.Addr
if addr == "" {
addr = ":domain"
@@ -316,6 +375,8 @@ func (srv *Server) ListenAndServe() error {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
+ srv.queue = make(chan *response)
+ defer close(srv.queue)
switch srv.Net {
case "tcp", "tcp4", "tcp6":
a, err := net.ResolveTCPAddr(srv.Net, addr)
@@ -380,8 +441,11 @@ func (srv *Server) ActivateAndServe() error {
if srv.started {
return &Error{err: "server already started"}
}
+
pConn := srv.PacketConn
l := srv.Listener
+ srv.queue = make(chan *response)
+ defer close(srv.queue)
if pConn != nil {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
@@ -439,7 +503,6 @@ func (srv *Server) getReadTimeout() time.Duration {
}
// serveTCP starts a TCP listener for the server.
-// Each request is handled in a separate goroutine.
func (srv *Server) serveTCP(l net.Listener) error {
defer l.Close()
@@ -447,17 +510,6 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc()
}
- reader := Reader(&defaultReader{srv})
- if srv.DecorateReader != nil {
- reader = srv.DecorateReader(reader)
- }
-
- handler := srv.Handler
- if handler == nil {
- handler = DefaultServeMux
- }
- rtimeout := srv.getReadTimeout()
- // deadline is not used here
for {
rw, err := l.Accept()
srv.lock.RLock()
@@ -472,19 +524,11 @@ func (srv *Server) serveTCP(l net.Listener) error {
}
return err
}
- go func() {
- m, err := reader.ReadTCP(rw, rtimeout)
- if err != nil {
- rw.Close()
- return
- }
- srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
- }()
+ srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw})
}
}
// serveUDP starts a UDP listener for the server.
-// Each request is handled in a separate goroutine.
func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close()
@@ -497,10 +541,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
reader = srv.DecorateReader(reader)
}
- handler := srv.Handler
- if handler == nil {
- handler = DefaultServeMux
- }
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
@@ -520,80 +560,98 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
if len(m) < headerSize {
continue
}
- go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
+ srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
}
}
-// Serve a new connection.
-func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
- w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
+func (srv *Server) serve(w *response) {
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}
- q := 0 // counter for the amount of TCP queries we get
+ if w.udp != nil {
+ // serve UDP
+ srv.serveDNS(w)
+ return
+ }
reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader)
}
-Redo:
+
+ defer func() {
+ if !w.hijacked {
+ w.Close()
+ }
+ }()
+
+ idleTimeout := tcpIdleTimeout
+ if srv.IdleTimeout != nil {
+ idleTimeout = srv.IdleTimeout()
+ }
+
+ timeout := srv.getReadTimeout()
+
+ limit := srv.MaxTCPQueries
+ if limit == 0 {
+ limit = maxTCPQueries
+ }
+
+ for q := 0; q < limit || limit == -1; q++ {
+ var err error
+ w.msg, err = reader.ReadTCP(w.tcp, timeout)
+ if err != nil {
+ // TODO(tmthrgd): handle error
+ break
+ }
+ srv.serveDNS(w)
+ if w.tcp == nil {
+ break // Close() was called
+ }
+ if w.hijacked {
+ break // client will call Close() themselves
+ }
+ // The first read uses the read timeout, the rest use the
+ // idle timeout.
+ timeout = idleTimeout
+ }
+}
+
+func (srv *Server) serveDNS(w *response) {
req := new(Msg)
- err := req.Unpack(m)
+ err := req.Unpack(w.msg)
if err != nil { // Send a FormatError back
x := new(Msg)
x.SetRcodeFormatError(req)
w.WriteMsg(x)
- goto Exit
+ return
}
if !srv.Unsafe && req.Response {
- goto Exit
+ return
}
w.tsigStatus = nil
if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil {
- secret := t.Hdr.Name
- if _, ok := w.tsigSecret[secret]; !ok {
- w.tsigStatus = ErrKeyAlg
+ if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
+ w.tsigStatus = TsigVerify(w.msg, secret, "", false)
+ } else {
+ w.tsigStatus = ErrSecret
}
- w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
}
}
- h.ServeDNS(w, req) // Writes back to the client
-Exit:
- if w.tcp == nil {
- return
- }
- // TODO(miek): make this number configurable?
- if q > maxTCPQueries { // close socket after this many queries
- w.Close()
- return
+ handler := srv.Handler
+ if handler == nil {
+ handler = DefaultServeMux
}
- if w.hijacked {
- return // client calls Close()
- }
- if u != nil { // UDP, "close" and return
- w.Close()
- return
- }
- idleTimeout := tcpIdleTimeout
- if srv.IdleTimeout != nil {
- idleTimeout = srv.IdleTimeout()
- }
- m, err = reader.ReadTCP(w.tcp, idleTimeout)
- if err == nil {
- q++
- goto Redo
- }
- w.Close()
- return
+ handler.ServeDNS(w, req) // Writes back to the client
}
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
@@ -696,7 +754,12 @@ func (w *response) LocalAddr() net.Addr {
}
// RemoteAddr implements the ResponseWriter.RemoteAddr method.
-func (w *response) RemoteAddr() net.Addr { return w.remoteAddr }
+func (w *response) RemoteAddr() net.Addr {
+ if w.tcp != nil {
+ return w.tcp.RemoteAddr()
+ }
+ return w.udpSession.RemoteAddr()
+}
// TsigStatus implements the ResponseWriter.TsigStatus method.
func (w *response) TsigStatus() error { return w.tsigStatus }