summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/miekg/dns/client.go')
-rw-r--r--vendor/github.com/miekg/dns/client.go78
1 files changed, 71 insertions, 7 deletions
diff --git a/vendor/github.com/miekg/dns/client.go b/vendor/github.com/miekg/dns/client.go
index 301dab9c1..1c14a19d8 100644
--- a/vendor/github.com/miekg/dns/client.go
+++ b/vendor/github.com/miekg/dns/client.go
@@ -4,6 +4,7 @@ package dns
import (
"bytes"
+ "context"
"crypto/tls"
"encoding/binary"
"io"
@@ -70,6 +71,43 @@ func Exchange(m *Msg, a string) (r *Msg, err error) {
return r, err
}
+// ExchangeContext performs a synchronous UDP query, like Exchange. It
+// additionally obeys deadlines from the passed Context.
+func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
+ // Combine context deadline with built-in timeout. Context chooses whichever
+ // is sooner.
+ timeoutCtx, cancel := context.WithTimeout(ctx, dnsTimeout)
+ defer cancel()
+ deadline, _ := timeoutCtx.Deadline()
+
+ co := new(Conn)
+ dialer := net.Dialer{}
+ co.Conn, err = dialer.DialContext(timeoutCtx, "udp", a)
+ if err != nil {
+ return nil, err
+ }
+
+ defer co.Conn.Close()
+
+ opt := m.IsEdns0()
+ // If EDNS0 is used use that for size.
+ if opt != nil && opt.UDPSize() >= MinMsgSize {
+ co.UDPSize = opt.UDPSize()
+ }
+
+ co.SetWriteDeadline(deadline)
+ if err = co.WriteMsg(m); err != nil {
+ return nil, err
+ }
+
+ co.SetReadDeadline(deadline)
+ r, err = co.ReadMsg()
+ if err == nil && r.Id != m.Id {
+ err = ErrId
+ }
+ return r, err
+}
+
// ExchangeConn performs a synchronous query. It sends the message m via the connection
// c and waits for a reply. The connection c is not closed by ExchangeConn.
// This function is going away, but can easily be mimicked:
@@ -106,8 +144,18 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
// of 512 bytes.
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
+ return c.ExchangeContext(context.Background(), m, a)
+}
+
+// ExchangeContext acts like Exchange, but honors the deadline on the provided
+// context, if present. If there is both a context deadline and a configured
+// timeout on the client, the earliest of the two takes effect.
+func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (
+ r *Msg,
+ rtt time.Duration,
+ err error) {
if !c.SingleInflight {
- return c.exchange(m, a)
+ return c.exchange(ctx, m, a)
}
// This adds a bunch of garbage, TODO(miek).
t := "nop"
@@ -119,7 +167,7 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
cl = cl1
}
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
- return c.exchange(m, a)
+ return c.exchange(ctx, m, a)
})
if r != nil && shared {
r = r.Copy()
@@ -154,7 +202,7 @@ func (c *Client) writeTimeout() time.Duration {
return dnsTimeout
}
-func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
+func (c *Client) exchange(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
var co *Conn
network := "udp"
tls := false
@@ -180,10 +228,13 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
deadline = time.Now().Add(c.Timeout)
}
+ dialDeadline := deadlineOrTimeoutOrCtx(ctx, deadline, c.dialTimeout())
+ dialTimeout := dialDeadline.Sub(time.Now())
+
if tls {
- co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout())
+ co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, dialTimeout)
} else {
- co, err = DialTimeout(network, a, c.dialTimeout())
+ co, err = DialTimeout(network, a, dialTimeout)
}
if err != nil {
@@ -202,12 +253,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
}
co.TsigSecret = c.TsigSecret
- co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout()))
+ co.SetWriteDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.writeTimeout()))
if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}
- co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout()))
+ co.SetReadDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.readTimeout()))
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
@@ -459,9 +510,22 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
return conn, nil
}
+// deadlineOrTimeout chooses between the provided deadline and timeout
+// by always preferring the deadline so long as it's non-zero (regardless
+// of which is bigger), and returns the equivalent deadline value.
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
if deadline.IsZero() {
return time.Now().Add(timeout)
}
return deadline
}
+
+// deadlineOrTimeoutOrCtx returns the earliest of: a context deadline, or the
+// output of deadlineOrtimeout.
+func deadlineOrTimeoutOrCtx(ctx context.Context, deadline time.Time, timeout time.Duration) time.Time {
+ result := deadlineOrTimeout(deadline, timeout)
+ if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(result) {
+ result = ctxDeadline
+ }
+ return result
+}