summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/miekg/dns/client.go')
-rw-r--r--vendor/github.com/miekg/dns/client.go93
1 files changed, 93 insertions, 0 deletions
diff --git a/vendor/github.com/miekg/dns/client.go b/vendor/github.com/miekg/dns/client.go
index 282565afd..6aa4235d1 100644
--- a/vendor/github.com/miekg/dns/client.go
+++ b/vendor/github.com/miekg/dns/client.go
@@ -7,8 +7,12 @@ import (
"context"
"crypto/tls"
"encoding/binary"
+ "fmt"
"io"
+ "io/ioutil"
"net"
+ "net/http"
+ "net/url"
"strings"
"time"
)
@@ -16,6 +20,8 @@ import (
const dnsTimeout time.Duration = 2 * time.Second
const tcpIdleTimeout time.Duration = 8 * time.Second
+const dohMimeType = "application/dns-udpwireformat"
+
// A Conn represents a connection to a DNS server.
type Conn struct {
net.Conn // a net.Conn holding the connection
@@ -37,6 +43,7 @@ type Client struct {
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds, or net.Dialer.Timeout if expiring earlier - overridden by Timeout when that value is non-zero
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
+ HTTPClient *http.Client // The http.Client to use for DNS-over-HTTPS
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
group singleflight
@@ -134,6 +141,11 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
// attribute appropriately
func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
if !c.SingleInflight {
+ if c.Net == "https" {
+ // TODO(tmthrgd): pipe timeouts into exchangeDOH
+ return c.exchangeDOH(context.TODO(), m, address)
+ }
+
return c.exchange(m, address)
}
@@ -146,6 +158,11 @@ func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, er
cl = cl1
}
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
+ if c.Net == "https" {
+ // TODO(tmthrgd): pipe timeouts into exchangeDOH
+ return c.exchangeDOH(context.TODO(), m, address)
+ }
+
return c.exchange(m, address)
})
if r != nil && shared {
@@ -191,6 +208,77 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
return r, rtt, err
}
+func (c *Client) exchangeDOH(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
+ p, err := m.Pack()
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // TODO(tmthrgd): Allow the path to be customised?
+ u := &url.URL{
+ Scheme: "https",
+ Host: a,
+ Path: "/.well-known/dns-query",
+ }
+ if u.Port() == "443" {
+ u.Host = u.Hostname()
+ }
+
+ req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(p))
+ if err != nil {
+ return nil, 0, err
+ }
+
+ req.Header.Set("Content-Type", dohMimeType)
+ req.Header.Set("Accept", dohMimeType)
+
+ t := time.Now()
+
+ hc := http.DefaultClient
+ if c.HTTPClient != nil {
+ hc = c.HTTPClient
+ }
+
+ if ctx != context.Background() && ctx != context.TODO() {
+ req = req.WithContext(ctx)
+ }
+
+ resp, err := hc.Do(req)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer closeHTTPBody(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, 0, fmt.Errorf("dns: server returned HTTP %d error: %q", resp.StatusCode, resp.Status)
+ }
+
+ if ct := resp.Header.Get("Content-Type"); ct != dohMimeType {
+ return nil, 0, fmt.Errorf("dns: unexpected Content-Type %q; expected %q", ct, dohMimeType)
+ }
+
+ p, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ rtt = time.Since(t)
+
+ r = new(Msg)
+ if err := r.Unpack(p); err != nil {
+ return r, 0, err
+ }
+
+ // TODO: TSIG? Is it even supported over DoH?
+
+ return r, rtt, nil
+}
+
+func closeHTTPBody(r io.ReadCloser) error {
+ io.Copy(ioutil.Discard, io.LimitReader(r, 8<<20))
+ return r.Close()
+}
+
// ReadMsg reads a message from the connection co.
// If the received message contains a TSIG record the transaction signature
// is verified. This method always tries to return the message, however if an
@@ -490,6 +578,10 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
// context, if present. If there is both a context deadline and a configured
// timeout on the client, the earliest of the two takes effect.
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
+ if !c.SingleInflight && c.Net == "https" {
+ return c.exchangeDOH(ctx, m, a)
+ }
+
var timeout time.Duration
if deadline, ok := ctx.Deadline(); !ok {
timeout = 0
@@ -498,6 +590,7 @@ func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg,
}
// not passing the context to the underlying calls, as the API does not support
// context. For timeouts you should set up Client.Dialer and call Client.Exchange.
+ // TODO(tmthrgd): this is a race condition
c.Dialer = &net.Dialer{Timeout: timeout}
return c.Exchange(m, a)
}