From d5e1f7e2982c2fcc888ccac550b34095efbee217 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Fri, 18 May 2018 07:32:31 -0700 Subject: Upgrading server dependency. (#8807) --- vendor/github.com/miekg/dns/server.go | 199 ++++++++++++++++++++++------------ 1 file changed, 131 insertions(+), 68 deletions(-) (limited to 'vendor/github.com/miekg/dns/server.go') diff --git a/vendor/github.com/miekg/dns/server.go b/vendor/github.com/miekg/dns/server.go index 685753f43..2d98f1488 100644 --- a/vendor/github.com/miekg/dns/server.go +++ b/vendor/github.com/miekg/dns/server.go @@ -9,12 +9,19 @@ import ( "io" "net" "sync" + "sync/atomic" "time" ) -// Maximum number of TCP queries before we close the socket. +// Default maximum number of TCP queries before we close the socket. const maxTCPQueries = 128 +// Interval for stop worker if no load +const idleWorkerTimeout = 10 * time.Second + +// Maximum number of workers +const maxWorkersCount = 10000 + // Handler is implemented by any value that implements ServeDNS. type Handler interface { ServeDNS(w ResponseWriter, r *Msg) @@ -43,6 +50,7 @@ type ResponseWriter interface { } type response struct { + msg []byte hijacked bool // connection has been hijacked by handler tsigStatus error tsigTimersOnly bool @@ -51,7 +59,6 @@ type response struct { udp *net.UDPConn // i/o connection if UDP was used tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right - remoteAddr net.Addr // address of the client writer Writer // writer to output the raw DNS bits } @@ -296,12 +303,63 @@ type Server struct { DecorateReader DecorateReader // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. DecorateWriter DecorateWriter + // Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1). + MaxTCPQueries int + // UDP packet or TCP connection queue + queue chan *response + // Workers count + workersCount int32 // Shutdown handling lock sync.RWMutex started bool } +func (srv *Server) worker(w *response) { + srv.serve(w) + + for { + count := atomic.LoadInt32(&srv.workersCount) + if count > maxWorkersCount { + return + } + if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) { + break + } + } + + defer atomic.AddInt32(&srv.workersCount, -1) + + inUse := false + timeout := time.NewTimer(idleWorkerTimeout) + defer timeout.Stop() +LOOP: + for { + select { + case w, ok := <-srv.queue: + if !ok { + break LOOP + } + inUse = true + srv.serve(w) + case <-timeout.C: + if !inUse { + break LOOP + } + inUse = false + timeout.Reset(idleWorkerTimeout) + } + } +} + +func (srv *Server) spawnWorker(w *response) { + select { + case srv.queue <- w: + default: + go srv.worker(w) + } +} + // ListenAndServe starts a nameserver on the configured address in *Server. func (srv *Server) ListenAndServe() error { srv.lock.Lock() @@ -309,6 +367,7 @@ func (srv *Server) ListenAndServe() error { if srv.started { return &Error{err: "server already started"} } + addr := srv.Addr if addr == "" { addr = ":domain" @@ -316,6 +375,8 @@ func (srv *Server) ListenAndServe() error { if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize } + srv.queue = make(chan *response) + defer close(srv.queue) switch srv.Net { case "tcp", "tcp4", "tcp6": a, err := net.ResolveTCPAddr(srv.Net, addr) @@ -380,8 +441,11 @@ func (srv *Server) ActivateAndServe() error { if srv.started { return &Error{err: "server already started"} } + pConn := srv.PacketConn l := srv.Listener + srv.queue = make(chan *response) + defer close(srv.queue) if pConn != nil { if srv.UDPSize == 0 { srv.UDPSize = MinMsgSize @@ -439,7 +503,6 @@ func (srv *Server) getReadTimeout() time.Duration { } // serveTCP starts a TCP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveTCP(l net.Listener) error { defer l.Close() @@ -447,17 +510,6 @@ func (srv *Server) serveTCP(l net.Listener) error { srv.NotifyStartedFunc() } - reader := Reader(&defaultReader{srv}) - if srv.DecorateReader != nil { - reader = srv.DecorateReader(reader) - } - - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } - rtimeout := srv.getReadTimeout() - // deadline is not used here for { rw, err := l.Accept() srv.lock.RLock() @@ -472,19 +524,11 @@ func (srv *Server) serveTCP(l net.Listener) error { } return err } - go func() { - m, err := reader.ReadTCP(rw, rtimeout) - if err != nil { - rw.Close() - return - } - srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) - }() + srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw}) } } // serveUDP starts a UDP listener for the server. -// Each request is handled in a separate goroutine. func (srv *Server) serveUDP(l *net.UDPConn) error { defer l.Close() @@ -497,10 +541,6 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { reader = srv.DecorateReader(reader) } - handler := srv.Handler - if handler == nil { - handler = DefaultServeMux - } rtimeout := srv.getReadTimeout() // deadline is not used here for { @@ -520,80 +560,98 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { if len(m) < headerSize { continue } - go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) + srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s}) } } -// Serve a new connection. -func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} +func (srv *Server) serve(w *response) { if srv.DecorateWriter != nil { w.writer = srv.DecorateWriter(w) } else { w.writer = w } - q := 0 // counter for the amount of TCP queries we get + if w.udp != nil { + // serve UDP + srv.serveDNS(w) + return + } reader := Reader(&defaultReader{srv}) if srv.DecorateReader != nil { reader = srv.DecorateReader(reader) } -Redo: + + defer func() { + if !w.hijacked { + w.Close() + } + }() + + idleTimeout := tcpIdleTimeout + if srv.IdleTimeout != nil { + idleTimeout = srv.IdleTimeout() + } + + timeout := srv.getReadTimeout() + + limit := srv.MaxTCPQueries + if limit == 0 { + limit = maxTCPQueries + } + + for q := 0; q < limit || limit == -1; q++ { + var err error + w.msg, err = reader.ReadTCP(w.tcp, timeout) + if err != nil { + // TODO(tmthrgd): handle error + break + } + srv.serveDNS(w) + if w.tcp == nil { + break // Close() was called + } + if w.hijacked { + break // client will call Close() themselves + } + // The first read uses the read timeout, the rest use the + // idle timeout. + timeout = idleTimeout + } +} + +func (srv *Server) serveDNS(w *response) { req := new(Msg) - err := req.Unpack(m) + err := req.Unpack(w.msg) if err != nil { // Send a FormatError back x := new(Msg) x.SetRcodeFormatError(req) w.WriteMsg(x) - goto Exit + return } if !srv.Unsafe && req.Response { - goto Exit + return } w.tsigStatus = nil if w.tsigSecret != nil { if t := req.IsTsig(); t != nil { - secret := t.Hdr.Name - if _, ok := w.tsigSecret[secret]; !ok { - w.tsigStatus = ErrKeyAlg + if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { + w.tsigStatus = TsigVerify(w.msg, secret, "", false) + } else { + w.tsigStatus = ErrSecret } - w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) w.tsigTimersOnly = false w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC } } - h.ServeDNS(w, req) // Writes back to the client -Exit: - if w.tcp == nil { - return - } - // TODO(miek): make this number configurable? - if q > maxTCPQueries { // close socket after this many queries - w.Close() - return + handler := srv.Handler + if handler == nil { + handler = DefaultServeMux } - if w.hijacked { - return // client calls Close() - } - if u != nil { // UDP, "close" and return - w.Close() - return - } - idleTimeout := tcpIdleTimeout - if srv.IdleTimeout != nil { - idleTimeout = srv.IdleTimeout() - } - m, err = reader.ReadTCP(w.tcp, idleTimeout) - if err == nil { - q++ - goto Redo - } - w.Close() - return + handler.ServeDNS(w, req) // Writes back to the client } func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { @@ -696,7 +754,12 @@ func (w *response) LocalAddr() net.Addr { } // RemoteAddr implements the ResponseWriter.RemoteAddr method. -func (w *response) RemoteAddr() net.Addr { return w.remoteAddr } +func (w *response) RemoteAddr() net.Addr { + if w.tcp != nil { + return w.tcp.RemoteAddr() + } + return w.udpSession.RemoteAddr() +} // TsigStatus implements the ResponseWriter.TsigStatus method. func (w *response) TsigStatus() error { return w.tsigStatus } -- cgit v1.2.3-1-g7c22