1package dns
2
3// A client implementation.
4
5import (
6	"bytes"
7	"context"
8	"crypto/tls"
9	"encoding/binary"
10	"fmt"
11	"io"
12	"io/ioutil"
13	"net"
14	"net/http"
15	"strings"
16	"time"
17)
18
// Default timeouts and the DNS-over-HTTPS media type.
const (
	// dnsTimeout is the fallback dial, read and write timeout used when the
	// Client does not configure one.
	dnsTimeout = 2 * time.Second

	// tcpIdleTimeout is not referenced in this part of the file; presumably it
	// bounds how long an idle TCP connection is kept — TODO confirm at its use site.
	tcpIdleTimeout = 8 * time.Second

	// dohMimeType is sent as Content-Type and Accept on DNS-over-HTTPS requests
	// and is the Content-Type required on their responses.
	dohMimeType = "application/dns-message"
)
25
// A Conn represents a connection to a DNS server. On stream transports
// (*net.TCPConn and *tls.Conn) its Read and Write methods frame every message
// with a two-byte big-endian length prefix; other connections (UDP) are used
// one message per Read/Write.
type Conn struct {
	net.Conn                         // a net.Conn holding the connection
	UDPSize        uint16            // receive buffer size for UDP messages; MinMsgSize is used when this is not larger
	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
	tsigRequestMAC string            // MAC produced by the last TSIG-signed WriteMsg; used when verifying the next ReadMsg
}
33
// A Client defines parameters for a DNS client. A Client whose fields are all
// left at their zero values sends queries over UDP (Dial treats an empty Net as
// "udp") using the package's default 2-second timeouts.
type Client struct {
	Net       string      // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, if "https" a DNS-over-HTTPS query, otherwise an UDP one (default is "" for UDP)
	UDPSize   uint16      // minimum receive buffer for UDP messages; only used when the query carries no EDNS0 OPT record
	TLSConfig *tls.Config // TLS connection configuration
	Dialer    *net.Dialer // a net.Dialer used to set local address, timeouts and more; when non-nil it is used as-is for dialing
	// Timeout, when non-zero, overrides DialTimeout, ReadTimeout and WriteTimeout; defaults to 0 (disabled).
	// It is applied to the write and the read phase separately (and to dialing when Dialer is nil),
	// so it is not a cumulative bound on the whole exchange. A shorter non-zero net.Dialer.Timeout
	// (see Client.Dialer) caps it; Client.ExchangeContext expresses a context deadline that way.
	Timeout        time.Duration
	DialTimeout    time.Duration     // dial timeout, defaults to 2 seconds; ignored when Dialer is non-nil (the Dialer's own Timeout applies) - overridden by Timeout when that value is non-zero
	ReadTimeout    time.Duration     // offset used for net.Conn.SetReadDeadline, defaults to 2 seconds, capped by a shorter non-zero Dialer.Timeout - overridden by Timeout when that value is non-zero
	WriteTimeout   time.Duration     // offset used for net.Conn.SetWriteDeadline, defaults to 2 seconds, capped by a shorter non-zero Dialer.Timeout - overridden by Timeout when that value is non-zero
	HTTPClient     *http.Client      // The http.Client to use for DNS-over-HTTPS; http.DefaultClient when nil
	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
	SingleInflight bool              // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
	group          singleflight      // shares one in-flight exchange between identical queries when SingleInflight is set
}
52
53// Exchange performs a synchronous UDP query. It sends the message m to the address
54// contained in a and waits for a reply. Exchange does not retry a failed query, nor
55// will it fall back to TCP in case of truncation.
56// See client.Exchange for more information on setting larger buffer sizes.
57func Exchange(m *Msg, a string) (r *Msg, err error) {
58	client := Client{Net: "udp"}
59	r, _, err = client.Exchange(m, a)
60	return r, err
61}
62
63func (c *Client) dialTimeout() time.Duration {
64	if c.Timeout != 0 {
65		return c.Timeout
66	}
67	if c.DialTimeout != 0 {
68		return c.DialTimeout
69	}
70	return dnsTimeout
71}
72
73func (c *Client) readTimeout() time.Duration {
74	if c.ReadTimeout != 0 {
75		return c.ReadTimeout
76	}
77	return dnsTimeout
78}
79
80func (c *Client) writeTimeout() time.Duration {
81	if c.WriteTimeout != 0 {
82		return c.WriteTimeout
83	}
84	return dnsTimeout
85}
86
87// Dial connects to the address on the named network.
88func (c *Client) Dial(address string) (conn *Conn, err error) {
89	// create a new dialer with the appropriate timeout
90	var d net.Dialer
91	if c.Dialer == nil {
92		d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())}
93	} else {
94		d = *c.Dialer
95	}
96
97	network := c.Net
98	if network == "" {
99		network = "udp"
100	}
101
102	useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls")
103
104	conn = new(Conn)
105	if useTLS {
106		network = strings.TrimSuffix(network, "-tls")
107
108		conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig)
109	} else {
110		conn.Conn, err = d.Dial(network, address)
111	}
112	if err != nil {
113		return nil, err
114	}
115
116	return conn, nil
117}
118
119// Exchange performs a synchronous query. It sends the message m to the address
120// contained in a and waits for a reply. Basic use pattern with a *dns.Client:
121//
122//	c := new(dns.Client)
123//	in, rtt, err := c.Exchange(message, "127.0.0.1:53")
124//
125// Exchange does not retry a failed query, nor will it fall back to TCP in
126// case of truncation.
127// It is up to the caller to create a message that allows for larger responses to be
128// returned. Specifically this means adding an EDNS0 OPT RR that will advertise a larger
129// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
130// of 512 bytes
131// To specify a local address or a timeout, the caller has to set the `Client.Dialer`
132// attribute appropriately
133func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
134	if !c.SingleInflight {
135		if c.Net == "https" {
136			// TODO(tmthrgd): pipe timeouts into exchangeDOH
137			return c.exchangeDOH(context.TODO(), m, address)
138		}
139
140		return c.exchange(m, address)
141	}
142
143	t := "nop"
144	if t1, ok := TypeToString[m.Question[0].Qtype]; ok {
145		t = t1
146	}
147	cl := "nop"
148	if cl1, ok := ClassToString[m.Question[0].Qclass]; ok {
149		cl = cl1
150	}
151	r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
152		if c.Net == "https" {
153			// TODO(tmthrgd): pipe timeouts into exchangeDOH
154			return c.exchangeDOH(context.TODO(), m, address)
155		}
156
157		return c.exchange(m, address)
158	})
159	if r != nil && shared {
160		r = r.Copy()
161	}
162	return r, rtt, err
163}
164
165func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
166	var co *Conn
167
168	co, err = c.Dial(a)
169
170	if err != nil {
171		return nil, 0, err
172	}
173	defer co.Close()
174
175	opt := m.IsEdns0()
176	// If EDNS0 is used use that for size.
177	if opt != nil && opt.UDPSize() >= MinMsgSize {
178		co.UDPSize = opt.UDPSize()
179	}
180	// Otherwise use the client's configured UDP size.
181	if opt == nil && c.UDPSize >= MinMsgSize {
182		co.UDPSize = c.UDPSize
183	}
184
185	co.TsigSecret = c.TsigSecret
186	t := time.Now()
187	// write with the appropriate write timeout
188	co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout())))
189	if err = co.WriteMsg(m); err != nil {
190		return nil, 0, err
191	}
192
193	co.SetReadDeadline(time.Now().Add(c.getTimeoutForRequest(c.readTimeout())))
194	r, err = co.ReadMsg()
195	if err == nil && r.Id != m.Id {
196		err = ErrId
197	}
198	rtt = time.Since(t)
199	return r, rtt, err
200}
201
202func (c *Client) exchangeDOH(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
203	p, err := m.Pack()
204	if err != nil {
205		return nil, 0, err
206	}
207
208	req, err := http.NewRequest(http.MethodPost, a, bytes.NewReader(p))
209	if err != nil {
210		return nil, 0, err
211	}
212
213	req.Header.Set("Content-Type", dohMimeType)
214	req.Header.Set("Accept", dohMimeType)
215
216	hc := http.DefaultClient
217	if c.HTTPClient != nil {
218		hc = c.HTTPClient
219	}
220
221	if ctx != context.Background() && ctx != context.TODO() {
222		req = req.WithContext(ctx)
223	}
224
225	t := time.Now()
226
227	resp, err := hc.Do(req)
228	if err != nil {
229		return nil, 0, err
230	}
231	defer closeHTTPBody(resp.Body)
232
233	if resp.StatusCode != http.StatusOK {
234		return nil, 0, fmt.Errorf("dns: server returned HTTP %d error: %q", resp.StatusCode, resp.Status)
235	}
236
237	if ct := resp.Header.Get("Content-Type"); ct != dohMimeType {
238		return nil, 0, fmt.Errorf("dns: unexpected Content-Type %q; expected %q", ct, dohMimeType)
239	}
240
241	p, err = ioutil.ReadAll(resp.Body)
242	if err != nil {
243		return nil, 0, err
244	}
245
246	rtt = time.Since(t)
247
248	r = new(Msg)
249	if err := r.Unpack(p); err != nil {
250		return r, 0, err
251	}
252
253	// TODO: TSIG? Is it even supported over DoH?
254
255	return r, rtt, nil
256}
257
// closeHTTPBody discards up to 8 MiB of any unread data in r and then closes
// it. Draining lets the HTTP transport reuse the connection. Errors while
// draining are deliberately ignored; only the Close error is returned.
func closeHTTPBody(r io.ReadCloser) error {
	const maxDrain = 8 << 20
	_, _ = io.Copy(ioutil.Discard, io.LimitReader(r, maxDrain))
	return r.Close()
}
262
263// ReadMsg reads a message from the connection co.
264// If the received message contains a TSIG record the transaction signature
265// is verified. This method always tries to return the message, however if an
266// error is returned there are no guarantees that the returned message is a
267// valid representation of the packet read.
268func (co *Conn) ReadMsg() (*Msg, error) {
269	p, err := co.ReadMsgHeader(nil)
270	if err != nil {
271		return nil, err
272	}
273
274	m := new(Msg)
275	if err := m.Unpack(p); err != nil {
276		// If an error was returned, we still want to allow the user to use
277		// the message, but naively they can just check err if they don't want
278		// to use an erroneous message
279		return m, err
280	}
281	if t := m.IsTsig(); t != nil {
282		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
283			return m, ErrSecret
284		}
285		// Need to work on the original message p, as that was used to calculate the tsig.
286		err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
287	}
288	return m, err
289}
290
291// ReadMsgHeader reads a DNS message, parses and populates hdr (when hdr is not nil).
292// Returns message as a byte slice to be parsed with Msg.Unpack later on.
293// Note that error handling on the message body is not possible as only the header is parsed.
294func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
295	var (
296		p   []byte
297		n   int
298		err error
299	)
300
301	switch t := co.Conn.(type) {
302	case *net.TCPConn, *tls.Conn:
303		r := t.(io.Reader)
304
305		// First two bytes specify the length of the entire message.
306		l, err := tcpMsgLen(r)
307		if err != nil {
308			return nil, err
309		}
310		p = make([]byte, l)
311		n, err = tcpRead(r, p)
312	default:
313		if co.UDPSize > MinMsgSize {
314			p = make([]byte, co.UDPSize)
315		} else {
316			p = make([]byte, MinMsgSize)
317		}
318		n, err = co.Read(p)
319	}
320
321	if err != nil {
322		return nil, err
323	} else if n < headerSize {
324		return nil, ErrShortRead
325	}
326
327	p = p[:n]
328	if hdr != nil {
329		dh, _, err := unpackMsgHdr(p, 0)
330		if err != nil {
331			return nil, err
332		}
333		*hdr = dh
334	}
335	return p, err
336}
337
338// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length.
339func tcpMsgLen(t io.Reader) (int, error) {
340	p := []byte{0, 0}
341	n, err := t.Read(p)
342	if err != nil {
343		return 0, err
344	}
345
346	// As seen with my local router/switch, returns 1 byte on the above read,
347	// resulting a a ShortRead. Just write it out (instead of loop) and read the
348	// other byte.
349	if n == 1 {
350		n1, err := t.Read(p[1:])
351		if err != nil {
352			return 0, err
353		}
354		n += n1
355	}
356
357	if n != 2 {
358		return 0, ErrShortRead
359	}
360	l := binary.BigEndian.Uint16(p)
361	if l == 0 {
362		return 0, ErrShortRead
363	}
364	return int(l), nil
365}
366
// tcpRead fills p completely from the stream t, issuing as many Read calls as
// needed, and returns the number of bytes read.
//
// It follows io.ReadFull semantics. err is nil exactly when n == len(p), even
// if the reader reported io.EOF alongside the final bytes; the previous loop
// discarded a complete message in that case. A stream that ends part-way yields
// io.ErrUnexpectedEOF, and n counts every byte that did arrive.
func tcpRead(t io.Reader, p []byte) (int, error) {
	return io.ReadFull(t, p)
}
382
383// Read implements the net.Conn read method.
384func (co *Conn) Read(p []byte) (n int, err error) {
385	if co.Conn == nil {
386		return 0, ErrConnEmpty
387	}
388	if len(p) < 2 {
389		return 0, io.ErrShortBuffer
390	}
391	switch t := co.Conn.(type) {
392	case *net.TCPConn, *tls.Conn:
393		r := t.(io.Reader)
394
395		l, err := tcpMsgLen(r)
396		if err != nil {
397			return 0, err
398		}
399		if l > len(p) {
400			return int(l), io.ErrShortBuffer
401		}
402		return tcpRead(r, p[:l])
403	}
404	// UDP connection
405	n, err = co.Conn.Read(p)
406	if err != nil {
407		return n, err
408	}
409	return n, err
410}
411
412// WriteMsg sends a message through the connection co.
413// If the message m contains a TSIG record the transaction
414// signature is calculated.
415func (co *Conn) WriteMsg(m *Msg) (err error) {
416	var out []byte
417	if t := m.IsTsig(); t != nil {
418		mac := ""
419		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
420			return ErrSecret
421		}
422		out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
423		// Set for the next read, although only used in zone transfers
424		co.tsigRequestMAC = mac
425	} else {
426		out, err = m.Pack()
427	}
428	if err != nil {
429		return err
430	}
431	if _, err = co.Write(out); err != nil {
432		return err
433	}
434	return nil
435}
436
437// Write implements the net.Conn Write method.
438func (co *Conn) Write(p []byte) (n int, err error) {
439	switch t := co.Conn.(type) {
440	case *net.TCPConn, *tls.Conn:
441		w := t.(io.Writer)
442
443		lp := len(p)
444		if lp < 2 {
445			return 0, io.ErrShortBuffer
446		}
447		if lp > MaxMsgSize {
448			return 0, &Error{err: "message too large"}
449		}
450		l := make([]byte, 2, lp+2)
451		binary.BigEndian.PutUint16(l, uint16(lp))
452		p = append(l, p...)
453		n, err := io.Copy(w, bytes.NewReader(p))
454		return int(n), err
455	}
456	n, err = co.Conn.Write(p)
457	return n, err
458}
459
460// Return the appropriate timeout for a specific request
461func (c *Client) getTimeoutForRequest(timeout time.Duration) time.Duration {
462	var requestTimeout time.Duration
463	if c.Timeout != 0 {
464		requestTimeout = c.Timeout
465	} else {
466		requestTimeout = timeout
467	}
468	// net.Dialer.Timeout has priority if smaller than the timeouts computed so
469	// far
470	if c.Dialer != nil && c.Dialer.Timeout != 0 {
471		if c.Dialer.Timeout < requestTimeout {
472			requestTimeout = c.Dialer.Timeout
473		}
474	}
475	return requestTimeout
476}
477
478// Dial connects to the address on the named network.
479func Dial(network, address string) (conn *Conn, err error) {
480	conn = new(Conn)
481	conn.Conn, err = net.Dial(network, address)
482	if err != nil {
483		return nil, err
484	}
485	return conn, nil
486}
487
488// ExchangeContext performs a synchronous UDP query, like Exchange. It
489// additionally obeys deadlines from the passed Context.
490func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
491	client := Client{Net: "udp"}
492	r, _, err = client.ExchangeContext(ctx, m, a)
493	// ignorint rtt to leave the original ExchangeContext API unchanged, but
494	// this function will go away
495	return r, err
496}
497
498// ExchangeConn performs a synchronous query. It sends the message m via the connection
499// c and waits for a reply. The connection c is not closed by ExchangeConn.
500// This function is going away, but can easily be mimicked:
501//
502//	co := &dns.Conn{Conn: c} // c is your net.Conn
503//	co.WriteMsg(m)
504//	in, _  := co.ReadMsg()
505//	co.Close()
506//
507func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
508	println("dns: ExchangeConn: this function is deprecated")
509	co := new(Conn)
510	co.Conn = c
511	if err = co.WriteMsg(m); err != nil {
512		return nil, err
513	}
514	r, err = co.ReadMsg()
515	if err == nil && r.Id != m.Id {
516		err = ErrId
517	}
518	return r, err
519}
520
521// DialTimeout acts like Dial but takes a timeout.
522func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
523	client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}}
524	conn, err = client.Dial(address)
525	if err != nil {
526		return nil, err
527	}
528	return conn, nil
529}
530
531// DialWithTLS connects to the address on the named network with TLS.
532func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) {
533	if !strings.HasSuffix(network, "-tls") {
534		network += "-tls"
535	}
536	client := Client{Net: network, TLSConfig: tlsConfig}
537	conn, err = client.Dial(address)
538
539	if err != nil {
540		return nil, err
541	}
542	return conn, nil
543}
544
545// DialTimeoutWithTLS acts like DialWithTLS but takes a timeout.
546func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) {
547	if !strings.HasSuffix(network, "-tls") {
548		network += "-tls"
549	}
550	client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig}
551	conn, err = client.Dial(address)
552	if err != nil {
553		return nil, err
554	}
555	return conn, nil
556}
557
558// ExchangeContext acts like Exchange, but honors the deadline on the provided
559// context, if present. If there is both a context deadline and a configured
560// timeout on the client, the earliest of the two takes effect.
561func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
562	if !c.SingleInflight && c.Net == "https" {
563		return c.exchangeDOH(ctx, m, a)
564	}
565
566	var timeout time.Duration
567	if deadline, ok := ctx.Deadline(); !ok {
568		timeout = 0
569	} else {
570		timeout = time.Until(deadline)
571	}
572	// not passing the context to the underlying calls, as the API does not support
573	// context. For timeouts you should set up Client.Dialer and call Client.Exchange.
574	// TODO(tmthrgd): this is a race condition
575	c.Dialer = &net.Dialer{Timeout: timeout}
576	return c.Exchange(m, a)
577}
578