1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build darwin dragonfly freebsd linux netbsd openbsd solaris
6
7// DNS client: see RFC 1035.
8// Has to be linked into package net for Dial.
9
10// TODO(rsc):
11//	Could potentially handle many outstanding lookups faster.
12//	Could have a small cache.
13//	Random UDP source port (net.Dial should do that for us).
14//	Random request IDs.
15
16package net
17
18import (
19	"errors"
20	"io"
21	"math/rand"
22	"os"
23	"sync"
24	"time"
25)
26
27// A dnsDialer provides dialing suitable for DNS queries.
28type dnsDialer interface {
29	dialDNS(string, string) (dnsConn, error)
30}
31
32var testHookDNSDialer = func(d time.Duration) dnsDialer { return &Dialer{Timeout: d} }
33
34// A dnsConn represents a DNS transport endpoint.
35type dnsConn interface {
36	io.Closer
37
38	SetDeadline(time.Time) error
39
40	// readDNSResponse reads a DNS response message from the DNS
41	// transport endpoint and returns the received DNS response
42	// message.
43	readDNSResponse() (*dnsMsg, error)
44
45	// writeDNSQuery writes a DNS query message to the DNS
46	// connection endpoint.
47	writeDNSQuery(*dnsMsg) error
48}
49
50func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
51	b := make([]byte, 512) // see RFC 1035
52	n, err := c.Read(b)
53	if err != nil {
54		return nil, err
55	}
56	msg := &dnsMsg{}
57	if !msg.Unpack(b[:n]) {
58		return nil, errors.New("cannot unmarshal DNS message")
59	}
60	return msg, nil
61}
62
63func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
64	b, ok := msg.Pack()
65	if !ok {
66		return errors.New("cannot marshal DNS message")
67	}
68	if _, err := c.Write(b); err != nil {
69		return err
70	}
71	return nil
72}
73
74func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
75	b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
76	if _, err := io.ReadFull(c, b[:2]); err != nil {
77		return nil, err
78	}
79	l := int(b[0])<<8 | int(b[1])
80	if l > len(b) {
81		b = make([]byte, l)
82	}
83	n, err := io.ReadFull(c, b[:l])
84	if err != nil {
85		return nil, err
86	}
87	msg := &dnsMsg{}
88	if !msg.Unpack(b[:n]) {
89		return nil, errors.New("cannot unmarshal DNS message")
90	}
91	return msg, nil
92}
93
94func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
95	b, ok := msg.Pack()
96	if !ok {
97		return errors.New("cannot marshal DNS message")
98	}
99	l := uint16(len(b))
100	b = append([]byte{byte(l >> 8), byte(l)}, b...)
101	if _, err := c.Write(b); err != nil {
102		return err
103	}
104	return nil
105}
106
107func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
108	switch network {
109	case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
110	default:
111		return nil, UnknownNetworkError(network)
112	}
113	// Calling Dial here is scary -- we have to be sure not to
114	// dial a name that will require a DNS lookup, or Dial will
115	// call back here to translate it. The DNS config parser has
116	// already checked that all the cfg.servers[i] are IP
117	// addresses, which Dial will use without a DNS lookup.
118	c, err := d.Dial(network, server)
119	if err != nil {
120		return nil, err
121	}
122	switch network {
123	case "tcp", "tcp4", "tcp6":
124		return c.(*TCPConn), nil
125	case "udp", "udp4", "udp6":
126		return c.(*UDPConn), nil
127	}
128	panic("unreachable")
129}
130
131// exchange sends a query on the connection and hopes for a response.
132func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
133	d := testHookDNSDialer(timeout)
134	out := dnsMsg{
135		dnsMsgHdr: dnsMsgHdr{
136			recursion_desired: true,
137		},
138		question: []dnsQuestion{
139			{name, qtype, dnsClassINET},
140		},
141	}
142	for _, network := range []string{"udp", "tcp"} {
143		c, err := d.dialDNS(network, server)
144		if err != nil {
145			return nil, err
146		}
147		defer c.Close()
148		if timeout > 0 {
149			c.SetDeadline(time.Now().Add(timeout))
150		}
151		out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
152		if err := c.writeDNSQuery(&out); err != nil {
153			return nil, err
154		}
155		in, err := c.readDNSResponse()
156		if err != nil {
157			return nil, err
158		}
159		if in.id != out.id {
160			return nil, errors.New("DNS message ID mismatch")
161		}
162		if in.truncated { // see RFC 5966
163			continue
164		}
165		return in, nil
166	}
167	return nil, errors.New("no answer from DNS server")
168}
169
170// Do a lookup for a single name, which must be rooted
171// (otherwise answer will not find the answers).
172func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
173	if len(cfg.servers) == 0 {
174		return "", nil, &DNSError{Err: "no DNS servers", Name: name}
175	}
176	timeout := time.Duration(cfg.timeout) * time.Second
177	var lastErr error
178	for i := 0; i < cfg.attempts; i++ {
179		for _, server := range cfg.servers {
180			server = JoinHostPort(server, "53")
181			msg, err := exchange(server, name, qtype, timeout)
182			if err != nil {
183				lastErr = &DNSError{
184					Err:    err.Error(),
185					Name:   name,
186					Server: server,
187				}
188				if nerr, ok := err.(Error); ok && nerr.Timeout() {
189					lastErr.(*DNSError).IsTimeout = true
190				}
191				continue
192			}
193			cname, rrs, err := answer(name, server, msg, qtype)
194			// If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
195			// it means the response in msg was not useful and trying another
196			// server probably won't help. Return now in those cases.
197			// TODO: indicate this in a more obvious way, such as a field on DNSError?
198			if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError {
199				return cname, rrs, err
200			}
201			lastErr = err
202		}
203	}
204	return "", nil, lastErr
205}
206
207// addrRecordList converts and returns a list of IP addresses from DNS
208// address records (both A and AAAA). Other record types are ignored.
209func addrRecordList(rrs []dnsRR) []IPAddr {
210	addrs := make([]IPAddr, 0, 4)
211	for _, rr := range rrs {
212		switch rr := rr.(type) {
213		case *dnsRR_A:
214			addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
215		case *dnsRR_AAAA:
216			ip := make(IP, IPv6len)
217			copy(ip, rr.AAAA[:])
218			addrs = append(addrs, IPAddr{IP: ip})
219		}
220	}
221	return addrs
222}
223
224// A resolverConfig represents a DNS stub resolver configuration.
225type resolverConfig struct {
226	initOnce sync.Once // guards init of resolverConfig
227
228	// ch is used as a semaphore that only allows one lookup at a
229	// time to recheck resolv.conf.
230	ch          chan struct{} // guards lastChecked and modTime
231	lastChecked time.Time     // last time resolv.conf was checked
232	modTime     time.Time     // time of resolv.conf modification
233
234	mu        sync.RWMutex // protects dnsConfig
235	dnsConfig *dnsConfig   // parsed resolv.conf structure used in lookups
236}
237
238var resolvConf resolverConfig
239
240// init initializes conf and is only called via conf.initOnce.
241func (conf *resolverConfig) init() {
242	// Set dnsConfig, modTime, and lastChecked so we don't parse
243	// resolv.conf twice the first time.
244	conf.dnsConfig = systemConf().resolv
245	if conf.dnsConfig == nil {
246		conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
247	}
248
249	if fi, err := os.Stat("/etc/resolv.conf"); err == nil {
250		conf.modTime = fi.ModTime()
251	}
252	conf.lastChecked = time.Now()
253
254	// Prepare ch so that only one update of resolverConfig may
255	// run at once.
256	conf.ch = make(chan struct{}, 1)
257}
258
259// tryUpdate tries to update conf with the named resolv.conf file.
260// The name variable only exists for testing. It is otherwise always
261// "/etc/resolv.conf".
262func (conf *resolverConfig) tryUpdate(name string) {
263	conf.initOnce.Do(conf.init)
264
265	// Ensure only one update at a time checks resolv.conf.
266	if !conf.tryAcquireSema() {
267		return
268	}
269	defer conf.releaseSema()
270
271	now := time.Now()
272	if conf.lastChecked.After(now.Add(-5 * time.Second)) {
273		return
274	}
275	conf.lastChecked = now
276
277	if fi, err := os.Stat(name); err == nil {
278		if fi.ModTime().Equal(conf.modTime) {
279			return
280		}
281		conf.modTime = fi.ModTime()
282	} else {
283		// If modTime wasn't set prior, assume nothing has changed.
284		if conf.modTime.IsZero() {
285			return
286		}
287		conf.modTime = time.Time{}
288	}
289
290	dnsConf := dnsReadConfig(name)
291	conf.mu.Lock()
292	conf.dnsConfig = dnsConf
293	conf.mu.Unlock()
294}
295
296func (conf *resolverConfig) tryAcquireSema() bool {
297	select {
298	case conf.ch <- struct{}{}:
299		return true
300	default:
301		return false
302	}
303}
304
// releaseSema releases the semaphore taken by a successful
// tryAcquireSema by draining the token that call placed in conf.ch.
func (conf *resolverConfig) releaseSema() {
	<-conf.ch
}
308
309func lookup(name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
310	if !isDomainName(name) {
311		return "", nil, &DNSError{Err: "invalid domain name", Name: name}
312	}
313	resolvConf.tryUpdate("/etc/resolv.conf")
314	resolvConf.mu.RLock()
315	conf := resolvConf.dnsConfig
316	resolvConf.mu.RUnlock()
317	for _, fqdn := range conf.nameList(name) {
318		cname, rrs, err = tryOneName(conf, fqdn, qtype)
319		if err == nil {
320			break
321		}
322	}
323	if err, ok := err.(*DNSError); ok {
324		// Show original name passed to lookup, not suffixed one.
325		// In general we might have tried many suffixes; showing
326		// just one is misleading. See also golang.org/issue/6324.
327		err.Name = name
328	}
329	return
330}
331
332// nameList returns a list of names for sequential DNS queries.
333func (conf *dnsConfig) nameList(name string) []string {
334	// If name is rooted (trailing dot), try only that name.
335	rooted := len(name) > 0 && name[len(name)-1] == '.'
336	if rooted {
337		return []string{name}
338	}
339	// Build list of search choices.
340	names := make([]string, 0, 1+len(conf.search))
341	// If name has enough dots, try unsuffixed first.
342	if count(name, '.') >= conf.ndots {
343		names = append(names, name+".")
344	}
345	// Try suffixes.
346	for _, suffix := range conf.search {
347		suffixed := name + "." + suffix
348		if suffixed[len(suffixed)-1] != '.' {
349			suffixed += "."
350		}
351		names = append(names, suffixed)
352	}
353	// Try unsuffixed, if not tried first above.
354	if count(name, '.') < conf.ndots {
355		names = append(names, name+".")
356	}
357	return names
358}
359
// hostLookupOrder specifies the order of LookupHost lookup strategies.
// It is basically a simplified representation of nsswitch.conf, where
// "files" means /etc/hosts.
//
// The values are assigned by iota, so the declaration order of the
// constants below determines their numeric values.
type hostLookupOrder int

const (
	// hostLookupCgo means defer to cgo.
	hostLookupCgo      hostLookupOrder = iota
	hostLookupFilesDNS                 // files first
	hostLookupDNSFiles                 // dns first
	hostLookupFiles                    // only files
	hostLookupDNS                      // only DNS
)
373
374var lookupOrderName = map[hostLookupOrder]string{
375	hostLookupCgo:      "cgo",
376	hostLookupFilesDNS: "files,dns",
377	hostLookupDNSFiles: "dns,files",
378	hostLookupFiles:    "files",
379	hostLookupDNS:      "dns",
380}
381
382func (o hostLookupOrder) String() string {
383	if s, ok := lookupOrderName[o]; ok {
384		return s
385	}
386	return "hostLookupOrder=" + itoa(int(o)) + "??"
387}
388
389// goLookupHost is the native Go implementation of LookupHost.
390// Used only if cgoLookupHost refuses to handle the request
391// (that is, only if cgoLookupHost is the stub in cgo_stub.go).
392// Normally we let cgo use the C library resolver instead of
393// depending on our lookup code, so that Go and C get the same
394// answers.
395func goLookupHost(name string) (addrs []string, err error) {
396	return goLookupHostOrder(name, hostLookupFilesDNS)
397}
398
399func goLookupHostOrder(name string, order hostLookupOrder) (addrs []string, err error) {
400	if order == hostLookupFilesDNS || order == hostLookupFiles {
401		// Use entries from /etc/hosts if they match.
402		addrs = lookupStaticHost(name)
403		if len(addrs) > 0 || order == hostLookupFiles {
404			return
405		}
406	}
407	ips, err := goLookupIPOrder(name, order)
408	if err != nil {
409		return
410	}
411	addrs = make([]string, 0, len(ips))
412	for _, ip := range ips {
413		addrs = append(addrs, ip.String())
414	}
415	return
416}
417
418// lookup entries from /etc/hosts
419func goLookupIPFiles(name string) (addrs []IPAddr) {
420	for _, haddr := range lookupStaticHost(name) {
421		haddr, zone := splitHostZone(haddr)
422		if ip := ParseIP(haddr); ip != nil {
423			addr := IPAddr{IP: ip, Zone: zone}
424			addrs = append(addrs, addr)
425		}
426	}
427	sortByRFC6724(addrs)
428	return
429}
430
431// goLookupIP is the native Go implementation of LookupIP.
432// The libc versions are in cgo_*.go.
433func goLookupIP(name string) (addrs []IPAddr, err error) {
434	return goLookupIPOrder(name, hostLookupFilesDNS)
435}
436
437func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err error) {
438	if order == hostLookupFilesDNS || order == hostLookupFiles {
439		addrs = goLookupIPFiles(name)
440		if len(addrs) > 0 || order == hostLookupFiles {
441			return addrs, nil
442		}
443	}
444	if !isDomainName(name) {
445		return nil, &DNSError{Err: "invalid domain name", Name: name}
446	}
447	resolvConf.tryUpdate("/etc/resolv.conf")
448	resolvConf.mu.RLock()
449	conf := resolvConf.dnsConfig
450	resolvConf.mu.RUnlock()
451	type racer struct {
452		fqdn string
453		rrs  []dnsRR
454		error
455	}
456	lane := make(chan racer, 1)
457	qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA}
458	var lastErr error
459	for _, fqdn := range conf.nameList(name) {
460		for _, qtype := range qtypes {
461			go func(qtype uint16) {
462				_, rrs, err := tryOneName(conf, fqdn, qtype)
463				lane <- racer{fqdn, rrs, err}
464			}(qtype)
465		}
466		for range qtypes {
467			racer := <-lane
468			if racer.error != nil {
469				// Prefer error for original name.
470				if lastErr == nil || racer.fqdn == name+"." {
471					lastErr = racer.error
472				}
473				continue
474			}
475			addrs = append(addrs, addrRecordList(racer.rrs)...)
476		}
477		if len(addrs) > 0 {
478			break
479		}
480	}
481	if lastErr, ok := lastErr.(*DNSError); ok {
482		// Show original name passed to lookup, not suffixed one.
483		// In general we might have tried many suffixes; showing
484		// just one is misleading. See also golang.org/issue/6324.
485		lastErr.Name = name
486	}
487	sortByRFC6724(addrs)
488	if len(addrs) == 0 {
489		if order == hostLookupDNSFiles {
490			addrs = goLookupIPFiles(name)
491		}
492		if len(addrs) == 0 && lastErr != nil {
493			return nil, lastErr
494		}
495	}
496	return addrs, nil
497}
498
499// goLookupCNAME is the native Go implementation of LookupCNAME.
500// Used only if cgoLookupCNAME refuses to handle the request
501// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go).
502// Normally we let cgo use the C library resolver instead of
503// depending on our lookup code, so that Go and C get the same
504// answers.
505func goLookupCNAME(name string) (cname string, err error) {
506	_, rrs, err := lookup(name, dnsTypeCNAME)
507	if err != nil {
508		return
509	}
510	cname = rrs[0].(*dnsRR_CNAME).Cname
511	return
512}
513
514// goLookupPTR is the native Go implementation of LookupAddr.
515// Used only if cgoLookupPTR refuses to handle the request (that is,
516// only if cgoLookupPTR is the stub in cgo_stub.go).
517// Normally we let cgo use the C library resolver instead of depending
518// on our lookup code, so that Go and C get the same answers.
519func goLookupPTR(addr string) ([]string, error) {
520	names := lookupStaticAddr(addr)
521	if len(names) > 0 {
522		return names, nil
523	}
524	arpa, err := reverseaddr(addr)
525	if err != nil {
526		return nil, err
527	}
528	_, rrs, err := lookup(arpa, dnsTypePTR)
529	if err != nil {
530		return nil, err
531	}
532	ptrs := make([]string, len(rrs))
533	for i, rr := range rrs {
534		ptrs[i] = rr.(*dnsRR_PTR).Ptr
535	}
536	return ptrs, nil
537}
538