1package dns
2
3// A client implementation.
4
5import (
6	"bytes"
7	"crypto/tls"
8	"encoding/binary"
9	"io"
10	"net"
11	"time"
12)
13
// Package-level timeout defaults.
//
// dnsTimeout is the fallback used by Exchange and by the Client dial, read
// and write timeouts when the corresponding Client field is left at zero.
// tcpIdleTimeout is not referenced in this file (presumably used by the
// server side — verify).
const (
	dnsTimeout     time.Duration = 2 * time.Second
	tcpIdleTimeout time.Duration = 8 * time.Second
)
16
// A Conn represents a connection to a DNS server. It embeds a net.Conn and
// adds DNS message framing (Read, Write, ReadMsg, WriteMsg) together with the
// per-connection state used for TSIG signing/verification and round-trip
// time measurement.
type Conn struct {
	net.Conn                         // a net.Conn holding the connection
	UDPSize        uint16            // receive buffer size for non-stream reads; ReadMsgHeader never uses less than MinMsgSize
	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
	rtt            time.Duration     // time from t until the most recent ReadMsgHeader read completed
	t              time.Time         // when WriteMsg last handed a message to Write; start point for rtt
	tsigRequestMAC string            // MAC from the last TSIG-signed WriteMsg; used by ReadMsg to verify and by WriteMsg to sign the next message
}
26
// A Client defines parameters for a DNS client. With every field left at its
// zero value, Exchange sends queries over UDP using the default 2-second
// dial, write and read timeouts.
type Client struct {
	Net            string            // "" (UDP, the default), "udp", "tcp", "tcp4", "tcp6" or any other network accepted by net.DialTimeout; "tcp-tls", "tcp4-tls" and "tcp6-tls" select DNS over TLS on the matching TCP network
	UDPSize        uint16            // receive buffer size used when the query carries no EDNS0 OPT RR; values below MinMsgSize are ignored
	TLSConfig      *tls.Config       // TLS connection configuration, only used for the "-tls" networks
	Timeout        time.Duration     // a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout and WriteTimeout when non-zero
	DialTimeout    time.Duration     // net.DialTimeout, defaults to 2 seconds - overridden by Timeout when that value is non-zero
	ReadTimeout    time.Duration     // per-read timeout applied via net.Conn.SetReadDeadline, defaults to 2 seconds - overridden by Timeout when that value is non-zero
	WriteTimeout   time.Duration     // per-write timeout applied via net.Conn.SetWriteDeadline, defaults to 2 seconds - overridden by Timeout when that value is non-zero
	TsigSecret     map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
	SingleInflight bool              // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
	group          singleflight      // deduplicates concurrent Exchange calls when SingleInflight is set
}
40
41// Exchange performs a synchronous UDP query. It sends the message m to the address
42// contained in a and waits for a reply. Exchange does not retry a failed query, nor
43// will it fall back to TCP in case of truncation.
44// See client.Exchange for more information on setting larger buffer sizes.
45func Exchange(m *Msg, a string) (r *Msg, err error) {
46	var co *Conn
47	co, err = DialTimeout("udp", a, dnsTimeout)
48	if err != nil {
49		return nil, err
50	}
51
52	defer co.Close()
53
54	opt := m.IsEdns0()
55	// If EDNS0 is used use that for size.
56	if opt != nil && opt.UDPSize() >= MinMsgSize {
57		co.UDPSize = opt.UDPSize()
58	}
59
60	co.SetWriteDeadline(time.Now().Add(dnsTimeout))
61	if err = co.WriteMsg(m); err != nil {
62		return nil, err
63	}
64
65	co.SetReadDeadline(time.Now().Add(dnsTimeout))
66	r, err = co.ReadMsg()
67	if err == nil && r.Id != m.Id {
68		err = ErrId
69	}
70	return r, err
71}
72
73// ExchangeConn performs a synchronous query. It sends the message m via the connection
74// c and waits for a reply. The connection c is not closed by ExchangeConn.
75// This function is going away, but can easily be mimicked:
76//
77//	co := &dns.Conn{Conn: c} // c is your net.Conn
78//	co.WriteMsg(m)
79//	in, _  := co.ReadMsg()
80//	co.Close()
81//
82func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
83	println("dns: this function is deprecated")
84	co := new(Conn)
85	co.Conn = c
86	if err = co.WriteMsg(m); err != nil {
87		return nil, err
88	}
89	r, err = co.ReadMsg()
90	if err == nil && r.Id != m.Id {
91		err = ErrId
92	}
93	return r, err
94}
95
96// Exchange performs a synchronous query. It sends the message m to the address
97// contained in a and waits for a reply. Basic use pattern with a *dns.Client:
98//
99//	c := new(dns.Client)
100//	in, rtt, err := c.Exchange(message, "127.0.0.1:53")
101//
102// Exchange does not retry a failed query, nor will it fall back to TCP in
103// case of truncation.
104// It is up to the caller to create a message that allows for larger responses to be
105// returned. Specifically this means adding an EDNS0 OPT RR that will advertise a larger
106// buffer, see SetEdns0. Messsages without an OPT RR will fallback to the historic limit
107// of 512 bytes.
108func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
109	if !c.SingleInflight {
110		return c.exchange(m, a)
111	}
112	// This adds a bunch of garbage, TODO(miek).
113	t := "nop"
114	if t1, ok := TypeToString[m.Question[0].Qtype]; ok {
115		t = t1
116	}
117	cl := "nop"
118	if cl1, ok := ClassToString[m.Question[0].Qclass]; ok {
119		cl = cl1
120	}
121	r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
122		return c.exchange(m, a)
123	})
124	if err != nil {
125		return r, rtt, err
126	}
127	if shared {
128		return r.Copy(), rtt, nil
129	}
130	return r, rtt, nil
131}
132
133func (c *Client) dialTimeout() time.Duration {
134	if c.Timeout != 0 {
135		return c.Timeout
136	}
137	if c.DialTimeout != 0 {
138		return c.DialTimeout
139	}
140	return dnsTimeout
141}
142
143func (c *Client) readTimeout() time.Duration {
144	if c.ReadTimeout != 0 {
145		return c.ReadTimeout
146	}
147	return dnsTimeout
148}
149
150func (c *Client) writeTimeout() time.Duration {
151	if c.WriteTimeout != 0 {
152		return c.WriteTimeout
153	}
154	return dnsTimeout
155}
156
157func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
158	var co *Conn
159	network := "udp"
160	tls := false
161
162	switch c.Net {
163	case "tcp-tls":
164		network = "tcp"
165		tls = true
166	case "tcp4-tls":
167		network = "tcp4"
168		tls = true
169	case "tcp6-tls":
170		network = "tcp6"
171		tls = true
172	default:
173		if c.Net != "" {
174			network = c.Net
175		}
176	}
177
178	var deadline time.Time
179	if c.Timeout != 0 {
180		deadline = time.Now().Add(c.Timeout)
181	}
182
183	if tls {
184		co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout())
185	} else {
186		co, err = DialTimeout(network, a, c.dialTimeout())
187	}
188
189	if err != nil {
190		return nil, 0, err
191	}
192	defer co.Close()
193
194	opt := m.IsEdns0()
195	// If EDNS0 is used use that for size.
196	if opt != nil && opt.UDPSize() >= MinMsgSize {
197		co.UDPSize = opt.UDPSize()
198	}
199	// Otherwise use the client's configured UDP size.
200	if opt == nil && c.UDPSize >= MinMsgSize {
201		co.UDPSize = c.UDPSize
202	}
203
204	co.TsigSecret = c.TsigSecret
205	co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout()))
206	if err = co.WriteMsg(m); err != nil {
207		return nil, 0, err
208	}
209
210	co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout()))
211	r, err = co.ReadMsg()
212	if err == nil && r.Id != m.Id {
213		err = ErrId
214	}
215	return r, co.rtt, err
216}
217
218// ReadMsg reads a message from the connection co.
219// If the received message contains a TSIG record the transaction
220// signature is verified.
221func (co *Conn) ReadMsg() (*Msg, error) {
222	p, err := co.ReadMsgHeader(nil)
223	if err != nil {
224		return nil, err
225	}
226
227	m := new(Msg)
228	if err := m.Unpack(p); err != nil {
229		// If ErrTruncated was returned, we still want to allow the user to use
230		// the message, but naively they can just check err if they don't want
231		// to use a truncated message
232		if err == ErrTruncated {
233			return m, err
234		}
235		return nil, err
236	}
237	if t := m.IsTsig(); t != nil {
238		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
239			return m, ErrSecret
240		}
241		// Need to work on the original message p, as that was used to calculate the tsig.
242		err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
243	}
244	return m, err
245}
246
247// ReadMsgHeader reads a DNS message, parses and populates hdr (when hdr is not nil).
248// Returns message as a byte slice to be parsed with Msg.Unpack later on.
249// Note that error handling on the message body is not possible as only the header is parsed.
250func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
251	var (
252		p   []byte
253		n   int
254		err error
255	)
256
257	switch t := co.Conn.(type) {
258	case *net.TCPConn, *tls.Conn:
259		r := t.(io.Reader)
260
261		// First two bytes specify the length of the entire message.
262		l, err := tcpMsgLen(r)
263		if err != nil {
264			return nil, err
265		}
266		p = make([]byte, l)
267		n, err = tcpRead(r, p)
268		co.rtt = time.Since(co.t)
269	default:
270		if co.UDPSize > MinMsgSize {
271			p = make([]byte, co.UDPSize)
272		} else {
273			p = make([]byte, MinMsgSize)
274		}
275		n, err = co.Read(p)
276		co.rtt = time.Since(co.t)
277	}
278
279	if err != nil {
280		return nil, err
281	} else if n < headerSize {
282		return nil, ErrShortRead
283	}
284
285	p = p[:n]
286	if hdr != nil {
287		dh, _, err := unpackMsgHdr(p, 0)
288		if err != nil {
289			return nil, err
290		}
291		*hdr = dh
292	}
293	return p, err
294}
295
296// tcpMsgLen is a helper func to read first two bytes of stream as uint16 packet length.
297func tcpMsgLen(t io.Reader) (int, error) {
298	p := []byte{0, 0}
299	n, err := t.Read(p)
300	if err != nil {
301		return 0, err
302	}
303	if n != 2 {
304		return 0, ErrShortRead
305	}
306	l := binary.BigEndian.Uint16(p)
307	if l == 0 {
308		return 0, ErrShortRead
309	}
310	return int(l), nil
311}
312
// tcpRead fills p completely from the stream t, issuing as many reads as
// needed, and returns the number of bytes read.
//
// It delegates to io.ReadFull, so bytes that a Read call returns together
// with an error (which the io.Reader contract permits) are counted rather
// than discarded. If the stream ends before p is full, the error is io.EOF
// when nothing was read and io.ErrUnexpectedEOF otherwise. Other read errors
// are returned with the number of bytes read so far.
func tcpRead(t io.Reader, p []byte) (int, error) {
	return io.ReadFull(t, p)
}
328
329// Read implements the net.Conn read method.
330func (co *Conn) Read(p []byte) (n int, err error) {
331	if co.Conn == nil {
332		return 0, ErrConnEmpty
333	}
334	if len(p) < 2 {
335		return 0, io.ErrShortBuffer
336	}
337	switch t := co.Conn.(type) {
338	case *net.TCPConn, *tls.Conn:
339		r := t.(io.Reader)
340
341		l, err := tcpMsgLen(r)
342		if err != nil {
343			return 0, err
344		}
345		if l > len(p) {
346			return int(l), io.ErrShortBuffer
347		}
348		return tcpRead(r, p[:l])
349	}
350	// UDP connection
351	n, err = co.Conn.Read(p)
352	if err != nil {
353		return n, err
354	}
355	return n, err
356}
357
358// WriteMsg sends a message through the connection co.
359// If the message m contains a TSIG record the transaction
360// signature is calculated.
361func (co *Conn) WriteMsg(m *Msg) (err error) {
362	var out []byte
363	if t := m.IsTsig(); t != nil {
364		mac := ""
365		if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
366			return ErrSecret
367		}
368		out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
369		// Set for the next read, although only used in zone transfers
370		co.tsigRequestMAC = mac
371	} else {
372		out, err = m.Pack()
373	}
374	if err != nil {
375		return err
376	}
377	co.t = time.Now()
378	if _, err = co.Write(out); err != nil {
379		return err
380	}
381	return nil
382}
383
384// Write implements the net.Conn Write method.
385func (co *Conn) Write(p []byte) (n int, err error) {
386	switch t := co.Conn.(type) {
387	case *net.TCPConn, *tls.Conn:
388		w := t.(io.Writer)
389
390		lp := len(p)
391		if lp < 2 {
392			return 0, io.ErrShortBuffer
393		}
394		if lp > MaxMsgSize {
395			return 0, &Error{err: "message too large"}
396		}
397		l := make([]byte, 2, lp+2)
398		binary.BigEndian.PutUint16(l, uint16(lp))
399		p = append(l, p...)
400		n, err := io.Copy(w, bytes.NewReader(p))
401		return int(n), err
402	}
403	n, err = co.Conn.(*net.UDPConn).Write(p)
404	return n, err
405}
406
407// Dial connects to the address on the named network.
408func Dial(network, address string) (conn *Conn, err error) {
409	conn = new(Conn)
410	conn.Conn, err = net.Dial(network, address)
411	if err != nil {
412		return nil, err
413	}
414	return conn, nil
415}
416
417// DialTimeout acts like Dial but takes a timeout.
418func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
419	conn = new(Conn)
420	conn.Conn, err = net.DialTimeout(network, address, timeout)
421	if err != nil {
422		return nil, err
423	}
424	return conn, nil
425}
426
427// DialWithTLS connects to the address on the named network with TLS.
428func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) {
429	conn = new(Conn)
430	conn.Conn, err = tls.Dial(network, address, tlsConfig)
431	if err != nil {
432		return nil, err
433	}
434	return conn, nil
435}
436
437// DialTimeoutWithTLS acts like DialWithTLS but takes a timeout.
438func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) {
439	var dialer net.Dialer
440	dialer.Timeout = timeout
441
442	conn = new(Conn)
443	conn.Conn, err = tls.DialWithDialer(&dialer, network, address, tlsConfig)
444	if err != nil {
445		return nil, err
446	}
447	return conn, nil
448}
449
// deadlineOrTimeout returns deadline unchanged when it is set (non-zero).
// Otherwise it returns a fresh deadline timeout from now.
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
	if !deadline.IsZero() {
		return deadline
	}
	return time.Now().Add(timeout)
}
456