1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build darwin dragonfly freebsd linux netbsd openbsd solaris
6
7package net
8
9import (
10	"context"
11	"errors"
12	"fmt"
13	"internal/poll"
14	"io/ioutil"
15	"os"
16	"path"
17	"reflect"
18	"strings"
19	"sync"
20	"testing"
21	"time"
22)
23
24var goResolver = Resolver{PreferGo: true}
25
// TestAddr is 192.0.2.1, taken from the 192.0.2.0/24 (TEST-NET-1) block
// that RFC 5737 reserves for documentation.
const TestAddr uint32 = 192<<24 | 0<<16 | 2<<8 | 1

// TestAddr6 is 2001:db8::1, taken from the 2001:db8::/32 block that
// RFC 3849 reserves for documentation. Unlisted indexes are zero.
var TestAddr6 = [16]byte{0: 0x20, 1: 0x01, 2: 0x0d, 3: 0xb8, 15: 0x01}
31
32var dnsTransportFallbackTests = []struct {
33	server  string
34	name    string
35	qtype   uint16
36	timeout int
37	rcode   int
38}{
39	// Querying "com." with qtype=255 usually makes an answer
40	// which requires more than 512 bytes.
41	{"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
42	{"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
43}
44
45func TestDNSTransportFallback(t *testing.T) {
46	fake := fakeDNSServer{
47		rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
48			r := &dnsMsg{
49				dnsMsgHdr: dnsMsgHdr{
50					id:       q.id,
51					response: true,
52					rcode:    dnsRcodeSuccess,
53				},
54				question: q.question,
55			}
56			if n == "udp" {
57				r.truncated = true
58			}
59			return r, nil
60		},
61	}
62	r := Resolver{PreferGo: true, Dial: fake.DialContext}
63	for _, tt := range dnsTransportFallbackTests {
64		ctx, cancel := context.WithCancel(context.Background())
65		defer cancel()
66		msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
67		if err != nil {
68			t.Error(err)
69			continue
70		}
71		switch msg.rcode {
72		case tt.rcode:
73		default:
74			t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
75			continue
76		}
77	}
78}
79
80// See RFC 6761 for further information about the reserved, pseudo
81// domain names.
82var specialDomainNameTests = []struct {
83	name  string
84	qtype uint16
85	rcode int
86}{
87	// Name resolution APIs and libraries should not recognize the
88	// followings as special.
89	{"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
90	{"test.", dnsTypeALL, dnsRcodeNameError},
91	{"example.com.", dnsTypeALL, dnsRcodeSuccess},
92
93	// Name resolution APIs and libraries should recognize the
94	// followings as special and should not send any queries.
95	// Though, we test those names here for verifying negative
96	// answers at DNS query-response interaction level.
97	{"localhost.", dnsTypeALL, dnsRcodeNameError},
98	{"invalid.", dnsTypeALL, dnsRcodeNameError},
99}
100
101func TestSpecialDomainName(t *testing.T) {
102	fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
103		r := &dnsMsg{
104			dnsMsgHdr: dnsMsgHdr{
105				id:       q.id,
106				response: true,
107			},
108			question: q.question,
109		}
110
111		switch q.question[0].Name {
112		case "example.com.":
113			r.rcode = dnsRcodeSuccess
114		default:
115			r.rcode = dnsRcodeNameError
116		}
117
118		return r, nil
119	}}
120	r := Resolver{PreferGo: true, Dial: fake.DialContext}
121	server := "8.8.8.8:53"
122	for _, tt := range specialDomainNameTests {
123		ctx, cancel := context.WithCancel(context.Background())
124		defer cancel()
125		msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
126		if err != nil {
127			t.Error(err)
128			continue
129		}
130		switch msg.rcode {
131		case tt.rcode, dnsRcodeServerFailure:
132		default:
133			t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
134			continue
135		}
136	}
137}
138
139// Issue 13705: don't try to resolve onion addresses, etc
140func TestAvoidDNSName(t *testing.T) {
141	tests := []struct {
142		name  string
143		avoid bool
144	}{
145		{"foo.com", false},
146		{"foo.com.", false},
147
148		{"foo.onion.", true},
149		{"foo.onion", true},
150		{"foo.ONION", true},
151		{"foo.ONION.", true},
152
153		// But do resolve *.local address; Issue 16739
154		{"foo.local.", false},
155		{"foo.local", false},
156		{"foo.LOCAL", false},
157		{"foo.LOCAL.", false},
158
159		{"", true}, // will be rejected earlier too
160
161		// Without stuff before onion/local, they're fine to
162		// use DNS. With a search path,
163		// "onion.vegegtables.com" can use DNS. Without a
164		// search path (or with a trailing dot), the queries
165		// are just kinda useless, but don't reveal anything
166		// private.
167		{"local", false},
168		{"onion", false},
169		{"local.", false},
170		{"onion.", false},
171	}
172	for _, tt := range tests {
173		got := avoidDNS(tt.name)
174		if got != tt.avoid {
175			t.Errorf("avoidDNS(%q) = %v; want %v", tt.name, got, tt.avoid)
176		}
177	}
178}
179
180var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
181	r := &dnsMsg{
182		dnsMsgHdr: dnsMsgHdr{
183			id:       q.id,
184			response: true,
185		},
186		question: q.question,
187	}
188	if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA {
189		r.answer = []dnsRR{
190			&dnsRR_A{
191				Hdr: dnsRR_Header{
192					Name:     q.question[0].Name,
193					Rrtype:   dnsTypeA,
194					Class:    dnsClassINET,
195					Rdlength: 4,
196				},
197				A: TestAddr,
198			},
199		}
200	}
201	return r, nil
202}}
203
204// Issue 13705: don't try to resolve onion addresses, etc
205func TestLookupTorOnion(t *testing.T) {
206	r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
207	addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
208	if err != nil {
209		t.Fatalf("lookup = %v; want nil", err)
210	}
211	if len(addrs) > 0 {
212		t.Errorf("unexpected addresses: %v", addrs)
213	}
214}
215
// resolvConfTest lets a test swap the package's resolver configuration for
// one read from a temporary resolv.conf. teardown reloads /etc/resolv.conf
// and deletes the temporary directory.
type resolvConfTest struct {
	// dir is the temporary directory created by newResolvConfTest.
	dir string
	// path is the test resolv.conf inside dir.
	path string
	// The embedded config is set to &resolvConf by newResolvConfTest, so
	// updates made through this struct modify that package-level value
	// (presumably the configuration the Go resolver reads).
	*resolverConfig
}
221
222func newResolvConfTest() (*resolvConfTest, error) {
223	dir, err := ioutil.TempDir("", "go-resolvconftest")
224	if err != nil {
225		return nil, err
226	}
227	conf := &resolvConfTest{
228		dir:            dir,
229		path:           path.Join(dir, "resolv.conf"),
230		resolverConfig: &resolvConf,
231	}
232	conf.initOnce.Do(conf.init)
233	return conf, nil
234}
235
236func (conf *resolvConfTest) writeAndUpdate(lines []string) error {
237	f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
238	if err != nil {
239		return err
240	}
241	if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil {
242		f.Close()
243		return err
244	}
245	f.Close()
246	if err := conf.forceUpdate(conf.path, time.Now().Add(time.Hour)); err != nil {
247		return err
248	}
249	return nil
250}
251
252func (conf *resolvConfTest) forceUpdate(name string, lastChecked time.Time) error {
253	dnsConf := dnsReadConfig(name)
254	conf.mu.Lock()
255	conf.dnsConfig = dnsConf
256	conf.mu.Unlock()
257	for i := 0; i < 5; i++ {
258		if conf.tryAcquireSema() {
259			conf.lastChecked = lastChecked
260			conf.releaseSema()
261			return nil
262		}
263	}
264	return fmt.Errorf("tryAcquireSema for %s failed", name)
265}
266
267func (conf *resolvConfTest) servers() []string {
268	conf.mu.RLock()
269	servers := conf.dnsConfig.servers
270	conf.mu.RUnlock()
271	return servers
272}
273
274func (conf *resolvConfTest) teardown() error {
275	err := conf.forceUpdate("/etc/resolv.conf", time.Time{})
276	os.RemoveAll(conf.dir)
277	return err
278}
279
280var updateResolvConfTests = []struct {
281	name    string   // query name
282	lines   []string // resolver configuration lines
283	servers []string // expected name servers
284}{
285	{
286		name:    "golang.org",
287		lines:   []string{"nameserver 8.8.8.8"},
288		servers: []string{"8.8.8.8:53"},
289	},
290	{
291		name:    "",
292		lines:   nil, // an empty resolv.conf should use defaultNS as name servers
293		servers: defaultNS,
294	},
295	{
296		name:    "www.example.com",
297		lines:   []string{"nameserver 8.8.4.4"},
298		servers: []string{"8.8.4.4:53"},
299	},
300}
301
302func TestUpdateResolvConf(t *testing.T) {
303	r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
304
305	conf, err := newResolvConfTest()
306	if err != nil {
307		t.Fatal(err)
308	}
309	defer conf.teardown()
310
311	for i, tt := range updateResolvConfTests {
312		if err := conf.writeAndUpdate(tt.lines); err != nil {
313			t.Error(err)
314			continue
315		}
316		if tt.name != "" {
317			var wg sync.WaitGroup
318			const N = 10
319			wg.Add(N)
320			for j := 0; j < N; j++ {
321				go func(name string) {
322					defer wg.Done()
323					ips, err := r.LookupIPAddr(context.Background(), name)
324					if err != nil {
325						t.Error(err)
326						return
327					}
328					if len(ips) == 0 {
329						t.Errorf("no records for %s", name)
330						return
331					}
332				}(tt.name)
333			}
334			wg.Wait()
335		}
336		servers := conf.servers()
337		if !reflect.DeepEqual(servers, tt.servers) {
338			t.Errorf("#%d: got %v; want %v", i, servers, tt.servers)
339			continue
340		}
341	}
342}
343
// goLookupIPWithResolverConfigTests drives TestGoLookupIPWithResolverConfig.
// Each case installs lines as resolv.conf, looks up name, and checks the
// outcome against error and the a/aaaa flags.
var goLookupIPWithResolverConfigTests = []struct {
	name  string
	lines []string // resolver configuration lines
	// expected error, or nil; for a *DNSError only Name, Server and
	// IsTimeout are compared
	error
	a, aaaa bool // whether response contains A, AAAA-record
}{
	// no records, transport timeout
	{
		"jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
		[]string{
			"options timeout:1 attempts:1",
			"nameserver 255.255.255.255", // please forgive us for abuse of limited broadcast address
		},
		&DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "255.255.255.255:53", IsTimeout: true},
		false, false,
	},

	// no records, non-existent domain
	{
		"jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
		[]string{
			"options timeout:3 attempts:1",
			"nameserver 8.8.8.8",
		},
		&DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "8.8.8.8:53", IsTimeout: false},
		false, false,
	},

	// a few A records, no AAAA records
	{
		"ipv4.google.com.",
		[]string{
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		true, false,
	},
	{
		"ipv4.google.com",
		[]string{
			"domain golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, false,
	},
	{
		"ipv4.google.com",
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, false,
	},

	// no A records, a few AAAA records
	{
		"ipv6.google.com.",
		[]string{
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		false, true,
	},
	{
		"ipv6.google.com",
		[]string{
			"domain golang.org",
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		false, true,
	},
	{
		"ipv6.google.com",
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		false, true,
	},

	// both A and AAAA records
	{
		"hostname.as112.net", // see RFC 7534
		[]string{
			"domain golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, true,
	},
	{
		"hostname.as112.net", // see RFC 7534
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, true,
	},
}
456
457func TestGoLookupIPWithResolverConfig(t *testing.T) {
458	fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
459		switch s {
460		case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
461			break
462		default:
463			time.Sleep(10 * time.Millisecond)
464			return nil, poll.ErrTimeout
465		}
466		r := &dnsMsg{
467			dnsMsgHdr: dnsMsgHdr{
468				id:       q.id,
469				response: true,
470			},
471			question: q.question,
472		}
473		for _, question := range q.question {
474			switch question.Qtype {
475			case dnsTypeA:
476				switch question.Name {
477				case "hostname.as112.net.":
478					break
479				case "ipv4.google.com.":
480					r.answer = append(r.answer, &dnsRR_A{
481						Hdr: dnsRR_Header{
482							Name:     q.question[0].Name,
483							Rrtype:   dnsTypeA,
484							Class:    dnsClassINET,
485							Rdlength: 4,
486						},
487						A: TestAddr,
488					})
489				default:
490
491				}
492			case dnsTypeAAAA:
493				switch question.Name {
494				case "hostname.as112.net.":
495					break
496				case "ipv6.google.com.":
497					r.answer = append(r.answer, &dnsRR_AAAA{
498						Hdr: dnsRR_Header{
499							Name:     q.question[0].Name,
500							Rrtype:   dnsTypeAAAA,
501							Class:    dnsClassINET,
502							Rdlength: 16,
503						},
504						AAAA: TestAddr6,
505					})
506				}
507			}
508		}
509		return r, nil
510	}}
511	r := Resolver{PreferGo: true, Dial: fake.DialContext}
512
513	conf, err := newResolvConfTest()
514	if err != nil {
515		t.Fatal(err)
516	}
517	defer conf.teardown()
518
519	for _, tt := range goLookupIPWithResolverConfigTests {
520		if err := conf.writeAndUpdate(tt.lines); err != nil {
521			t.Error(err)
522			continue
523		}
524		addrs, err := r.LookupIPAddr(context.Background(), tt.name)
525		if err != nil {
526			if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
527				t.Errorf("got %v; want %v", err, tt.error)
528			}
529			continue
530		}
531		if len(addrs) == 0 {
532			t.Errorf("no records for %s", tt.name)
533		}
534		if !tt.a && !tt.aaaa && len(addrs) > 0 {
535			t.Errorf("unexpected %v for %s", addrs, tt.name)
536		}
537		for _, addr := range addrs {
538			if !tt.a && addr.IP.To4() != nil {
539				t.Errorf("got %v; must not be IPv4 address", addr)
540			}
541			if !tt.aaaa && addr.IP.To16() != nil && addr.IP.To4() == nil {
542				t.Errorf("got %v; must not be IPv6 address", addr)
543			}
544		}
545	}
546}
547
548// Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
549func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
550	fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
551		r := &dnsMsg{
552			dnsMsgHdr: dnsMsgHdr{
553				id:       q.id,
554				response: true,
555			},
556			question: q.question,
557		}
558		return r, nil
559	}}
560	r := Resolver{PreferGo: true, Dial: fake.DialContext}
561
562	// Add a config that simulates no dns servers being available.
563	conf, err := newResolvConfTest()
564	if err != nil {
565		t.Fatal(err)
566	}
567	if err := conf.writeAndUpdate([]string{}); err != nil {
568		t.Fatal(err)
569	}
570	// Redirect host file lookups.
571	defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
572	testHookHostsPath = "testdata/hosts"
573
574	for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} {
575		name := fmt.Sprintf("order %v", order)
576
577		// First ensure that we get an error when contacting a non-existent host.
578		_, _, err := r.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order)
579		if err == nil {
580			t.Errorf("%s: expected error while looking up name not in hosts file", name)
581			continue
582		}
583
584		// Now check that we get an address when the name appears in the hosts file.
585		addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
586		if err != nil {
587			t.Errorf("%s: expected to successfully lookup host entry", name)
588			continue
589		}
590		if len(addrs) != 1 {
591			t.Errorf("%s: expected exactly one result, but got %v", name, addrs)
592			continue
593		}
594		if got, want := addrs[0].String(), "127.1.1.1"; got != want {
595			t.Errorf("%s: address doesn't match expectation. got %v, want %v", name, got, want)
596		}
597	}
598	defer conf.teardown()
599}
600
601// Issue 12712.
602// When using search domains, return the error encountered
603// querying the original name instead of an error encountered
604// querying a generated name.
605func TestErrorForOriginalNameWhenSearching(t *testing.T) {
606	const fqdn = "doesnotexist.domain"
607
608	conf, err := newResolvConfTest()
609	if err != nil {
610		t.Fatal(err)
611	}
612	defer conf.teardown()
613
614	if err := conf.writeAndUpdate([]string{"search servfail"}); err != nil {
615		t.Fatal(err)
616	}
617
618	fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
619		r := &dnsMsg{
620			dnsMsgHdr: dnsMsgHdr{
621				id:       q.id,
622				response: true,
623			},
624			question: q.question,
625		}
626
627		switch q.question[0].Name {
628		case fqdn + ".servfail.":
629			r.rcode = dnsRcodeServerFailure
630		default:
631			r.rcode = dnsRcodeNameError
632		}
633
634		return r, nil
635	}}
636
637	cases := []struct {
638		strictErrors bool
639		wantErr      *DNSError
640	}{
641		{true, &DNSError{Name: fqdn, Err: "server misbehaving", IsTemporary: true}},
642		{false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error()}},
643	}
644	for _, tt := range cases {
645		r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext}
646		_, err = r.LookupIPAddr(context.Background(), fqdn)
647		if err == nil {
648			t.Fatal("expected an error")
649		}
650
651		want := tt.wantErr
652		if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err || err.IsTemporary != want.IsTemporary {
653			t.Errorf("got %v; want %v", err, want)
654		}
655	}
656}
657
658// Issue 15434. If a name server gives a lame referral, continue to the next.
659func TestIgnoreLameReferrals(t *testing.T) {
660	conf, err := newResolvConfTest()
661	if err != nil {
662		t.Fatal(err)
663	}
664	defer conf.teardown()
665
666	if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will give a lame referral
667		"nameserver 192.0.2.2"}); err != nil {
668		t.Fatal(err)
669	}
670
671	fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
672		t.Log(s, q)
673		r := &dnsMsg{
674			dnsMsgHdr: dnsMsgHdr{
675				id:       q.id,
676				response: true,
677			},
678			question: q.question,
679		}
680
681		if s == "192.0.2.2:53" {
682			r.recursion_available = true
683			if q.question[0].Qtype == dnsTypeA {
684				r.answer = []dnsRR{
685					&dnsRR_A{
686						Hdr: dnsRR_Header{
687							Name:     q.question[0].Name,
688							Rrtype:   dnsTypeA,
689							Class:    dnsClassINET,
690							Rdlength: 4,
691						},
692						A: TestAddr,
693					},
694				}
695			}
696		}
697
698		return r, nil
699	}}
700	r := Resolver{PreferGo: true, Dial: fake.DialContext}
701
702	addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org")
703	if err != nil {
704		t.Fatal(err)
705	}
706
707	if got := len(addrs); got != 1 {
708		t.Fatalf("got %d addresses, want 1", got)
709	}
710
711	if got, want := addrs[0].String(), "192.0.2.1"; got != want {
712		t.Fatalf("got address %v, want %v", got, want)
713	}
714}
715
716func BenchmarkGoLookupIP(b *testing.B) {
717	testHookUninstaller.Do(uninstallTestHooks)
718	ctx := context.Background()
719
720	for i := 0; i < b.N; i++ {
721		goResolver.LookupIPAddr(ctx, "www.example.com")
722	}
723}
724
725func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
726	testHookUninstaller.Do(uninstallTestHooks)
727	ctx := context.Background()
728
729	for i := 0; i < b.N; i++ {
730		goResolver.LookupIPAddr(ctx, "some.nonexistent")
731	}
732}
733
734func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
735	testHookUninstaller.Do(uninstallTestHooks)
736
737	conf, err := newResolvConfTest()
738	if err != nil {
739		b.Fatal(err)
740	}
741	defer conf.teardown()
742
743	lines := []string{
744		"nameserver 203.0.113.254", // use TEST-NET-3 block, see RFC 5737
745		"nameserver 8.8.8.8",
746	}
747	if err := conf.writeAndUpdate(lines); err != nil {
748		b.Fatal(err)
749	}
750	ctx := context.Background()
751
752	for i := 0; i < b.N; i++ {
753		goResolver.LookupIPAddr(ctx, "www.example.com")
754	}
755}
756
// fakeDNSServer is an in-process stand-in for a DNS server. DialContext
// returns a fakeDNSConn. Its Read passes the most recently written query
// to rh and returns the packed response.
type fakeDNSServer struct {
	// rh builds the response to query q, sent over network n to server
	// address s. t is the deadline last set via SetDeadline on the
	// connection. A non-nil error is returned from the connection's Read.
	rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
}
760
761func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
762	return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
763}
764
// fakeDNSConn is the connection returned by fakeDNSServer.DialContext.
// Write records the outgoing query, and Read asks the server's handler
// for the response.
type fakeDNSConn struct {
	// Conn is embedded (presumably to satisfy the rest of the Conn
	// interface). DialContext leaves it nil, so calling any Conn method
	// not defined on fakeDNSConn panics.
	Conn
	// server supplies the response handler.
	server *fakeDNSServer
	// n is the network passed to DialContext.
	n string
	// s is the address passed to DialContext.
	s string
	// q is the query most recently decoded by Write.
	q *dnsMsg
	// t is the deadline most recently set by SetDeadline.
	t time.Time
}
773
774func (f *fakeDNSConn) Close() error {
775	return nil
776}
777
778func (f *fakeDNSConn) Read(b []byte) (int, error) {
779	resp, err := f.server.rh(f.n, f.s, f.q, f.t)
780	if err != nil {
781		return 0, err
782	}
783
784	bb, ok := resp.Pack()
785	if !ok {
786		return 0, errors.New("cannot marshal DNS message")
787	}
788	if len(b) < len(bb) {
789		return 0, errors.New("read would fragment DNS message")
790	}
791
792	copy(b, bb)
793	return len(bb), nil
794}
795
796func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
797	return 0, nil, nil
798}
799
800func (f *fakeDNSConn) Write(b []byte) (int, error) {
801	f.q = new(dnsMsg)
802	if !f.q.Unpack(b) {
803		return 0, errors.New("cannot unmarshal DNS message")
804	}
805	return len(b), nil
806}
807
808func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
809	return 0, nil
810}
811
812func (f *fakeDNSConn) SetDeadline(t time.Time) error {
813	f.t = t
814	return nil
815}
816
817// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
818func TestIgnoreDNSForgeries(t *testing.T) {
819	c, s := Pipe()
820	go func() {
821		b := make([]byte, 512)
822		n, err := s.Read(b)
823		if err != nil {
824			t.Error(err)
825			return
826		}
827
828		msg := &dnsMsg{}
829		if !msg.Unpack(b[:n]) {
830			t.Error("invalid DNS query")
831			return
832		}
833
834		s.Write([]byte("garbage DNS response packet"))
835
836		msg.response = true
837		msg.id++ // make invalid ID
838		b, ok := msg.Pack()
839		if !ok {
840			t.Error("failed to pack DNS response")
841			return
842		}
843		s.Write(b)
844
845		msg.id-- // restore original ID
846		msg.answer = []dnsRR{
847			&dnsRR_A{
848				Hdr: dnsRR_Header{
849					Name:     "www.example.com.",
850					Rrtype:   dnsTypeA,
851					Class:    dnsClassINET,
852					Rdlength: 4,
853				},
854				A: TestAddr,
855			},
856		}
857
858		b, ok = msg.Pack()
859		if !ok {
860			t.Error("failed to pack DNS response")
861			return
862		}
863		s.Write(b)
864	}()
865
866	msg := &dnsMsg{
867		dnsMsgHdr: dnsMsgHdr{
868			id: 42,
869		},
870		question: []dnsQuestion{
871			{
872				Name:   "www.example.com.",
873				Qtype:  dnsTypeA,
874				Qclass: dnsClassINET,
875			},
876		},
877	}
878
879	dc := &dnsPacketConn{c}
880	resp, err := dc.dnsRoundTrip(msg)
881	if err != nil {
882		t.Fatalf("dnsRoundTripUDP failed: %v", err)
883	}
884
885	if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
886		t.Errorf("got address %v, want %v", got, TestAddr)
887	}
888}
889
890// Issue 16865. If a name server times out, continue to the next.
891func TestRetryTimeout(t *testing.T) {
892	conf, err := newResolvConfTest()
893	if err != nil {
894		t.Fatal(err)
895	}
896	defer conf.teardown()
897
898	testConf := []string{
899		"nameserver 192.0.2.1", // the one that will timeout
900		"nameserver 192.0.2.2",
901	}
902	if err := conf.writeAndUpdate(testConf); err != nil {
903		t.Fatal(err)
904	}
905
906	var deadline0 time.Time
907
908	fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
909		t.Log(s, q, deadline)
910
911		if deadline.IsZero() {
912			t.Error("zero deadline")
913		}
914
915		if s == "192.0.2.1:53" {
916			deadline0 = deadline
917			time.Sleep(10 * time.Millisecond)
918			return nil, poll.ErrTimeout
919		}
920
921		if deadline.Equal(deadline0) {
922			t.Error("deadline didn't change")
923		}
924
925		return mockTXTResponse(q), nil
926	}}
927	r := &Resolver{PreferGo: true, Dial: fake.DialContext}
928
929	_, err = r.LookupTXT(context.Background(), "www.golang.org")
930	if err != nil {
931		t.Fatal(err)
932	}
933
934	if deadline0.IsZero() {
935		t.Error("deadline0 still zero", deadline0)
936	}
937}
938
939func TestRotate(t *testing.T) {
940	// without rotation, always uses the first server
941	testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"})
942
943	// with rotation, rotates through back to first
944	testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"})
945}
946
947func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
948	conf, err := newResolvConfTest()
949	if err != nil {
950		t.Fatal(err)
951	}
952	defer conf.teardown()
953
954	var confLines []string
955	for _, ns := range nameservers {
956		confLines = append(confLines, "nameserver "+ns)
957	}
958	if rotate {
959		confLines = append(confLines, "options rotate")
960	}
961
962	if err := conf.writeAndUpdate(confLines); err != nil {
963		t.Fatal(err)
964	}
965
966	var usedServers []string
967	fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
968		usedServers = append(usedServers, s)
969		return mockTXTResponse(q), nil
970	}}
971	r := Resolver{PreferGo: true, Dial: fake.DialContext}
972
973	// len(nameservers) + 1 to allow rotation to get back to start
974	for i := 0; i < len(nameservers)+1; i++ {
975		if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil {
976			t.Fatal(err)
977		}
978	}
979
980	if !reflect.DeepEqual(usedServers, wantServers) {
981		t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers)
982	}
983}
984
985func mockTXTResponse(q *dnsMsg) *dnsMsg {
986	r := &dnsMsg{
987		dnsMsgHdr: dnsMsgHdr{
988			id:                  q.id,
989			response:            true,
990			recursion_available: true,
991		},
992		question: q.question,
993		answer: []dnsRR{
994			&dnsRR_TXT{
995				Hdr: dnsRR_Header{
996					Name:   q.question[0].Name,
997					Rrtype: dnsTypeTXT,
998					Class:  dnsClassINET,
999				},
1000				Txt: "ok",
1001			},
1002		},
1003	}
1004
1005	return r
1006}
1007
1008// Issue 17448. With StrictErrors enabled, temporary errors should make
1009// LookupIP fail rather than return a partial result.
1010func TestStrictErrorsLookupIP(t *testing.T) {
1011	conf, err := newResolvConfTest()
1012	if err != nil {
1013		t.Fatal(err)
1014	}
1015	defer conf.teardown()
1016
1017	confData := []string{
1018		"nameserver 192.0.2.53",
1019		"search x.golang.org y.golang.org",
1020	}
1021	if err := conf.writeAndUpdate(confData); err != nil {
1022		t.Fatal(err)
1023	}
1024
1025	const name = "test-issue19592"
1026	const server = "192.0.2.53:53"
1027	const searchX = "test-issue19592.x.golang.org."
1028	const searchY = "test-issue19592.y.golang.org."
1029	const ip4 = "192.0.2.1"
1030	const ip6 = "2001:db8::1"
1031
1032	type resolveWhichEnum int
1033	const (
1034		resolveOK resolveWhichEnum = iota
1035		resolveOpError
1036		resolveServfail
1037		resolveTimeout
1038	)
1039
1040	makeTempError := func(err string) error {
1041		return &DNSError{
1042			Err:         err,
1043			Name:        name,
1044			Server:      server,
1045			IsTemporary: true,
1046		}
1047	}
1048	makeTimeout := func() error {
1049		return &DNSError{
1050			Err:       poll.ErrTimeout.Error(),
1051			Name:      name,
1052			Server:    server,
1053			IsTimeout: true,
1054		}
1055	}
1056	makeNxDomain := func() error {
1057		return &DNSError{
1058			Err:    errNoSuchHost.Error(),
1059			Name:   name,
1060			Server: server,
1061		}
1062	}
1063
1064	cases := []struct {
1065		desc          string
1066		resolveWhich  func(quest *dnsQuestion) resolveWhichEnum
1067		wantStrictErr error
1068		wantLaxErr    error
1069		wantIPs       []string
1070	}{
1071		{
1072			desc: "No errors",
1073			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1074				return resolveOK
1075			},
1076			wantIPs: []string{ip4, ip6},
1077		},
1078		{
1079			desc: "searchX error fails in strict mode",
1080			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1081				if quest.Name == searchX {
1082					return resolveTimeout
1083				}
1084				return resolveOK
1085			},
1086			wantStrictErr: makeTimeout(),
1087			wantIPs:       []string{ip4, ip6},
1088		},
1089		{
1090			desc: "searchX IPv4-only timeout fails in strict mode",
1091			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1092				if quest.Name == searchX && quest.Qtype == dnsTypeA {
1093					return resolveTimeout
1094				}
1095				return resolveOK
1096			},
1097			wantStrictErr: makeTimeout(),
1098			wantIPs:       []string{ip4, ip6},
1099		},
1100		{
1101			desc: "searchX IPv6-only servfail fails in strict mode",
1102			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1103				if quest.Name == searchX && quest.Qtype == dnsTypeAAAA {
1104					return resolveServfail
1105				}
1106				return resolveOK
1107			},
1108			wantStrictErr: makeTempError("server misbehaving"),
1109			wantIPs:       []string{ip4, ip6},
1110		},
1111		{
1112			desc: "searchY error always fails",
1113			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1114				if quest.Name == searchY {
1115					return resolveTimeout
1116				}
1117				return resolveOK
1118			},
1119			wantStrictErr: makeTimeout(),
1120			wantLaxErr:    makeNxDomain(), // This one reaches the "test." FQDN.
1121		},
1122		{
1123			desc: "searchY IPv4-only socket error fails in strict mode",
1124			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1125				if quest.Name == searchY && quest.Qtype == dnsTypeA {
1126					return resolveOpError
1127				}
1128				return resolveOK
1129			},
1130			wantStrictErr: makeTempError("write: socket on fire"),
1131			wantIPs:       []string{ip6},
1132		},
1133		{
1134			desc: "searchY IPv6-only timeout fails in strict mode",
1135			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1136				if quest.Name == searchY && quest.Qtype == dnsTypeAAAA {
1137					return resolveTimeout
1138				}
1139				return resolveOK
1140			},
1141			wantStrictErr: makeTimeout(),
1142			wantIPs:       []string{ip4},
1143		},
1144	}
1145
1146	for i, tt := range cases {
1147		fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
1148			t.Log(s, q)
1149
1150			switch tt.resolveWhich(&q.question[0]) {
1151			case resolveOK:
1152				// Handle below.
1153			case resolveOpError:
1154				return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
1155			case resolveServfail:
1156				return &dnsMsg{
1157					dnsMsgHdr: dnsMsgHdr{
1158						id:       q.id,
1159						response: true,
1160						rcode:    dnsRcodeServerFailure,
1161					},
1162					question: q.question,
1163				}, nil
1164			case resolveTimeout:
1165				return nil, poll.ErrTimeout
1166			default:
1167				t.Fatal("Impossible resolveWhich")
1168			}
1169
1170			switch q.question[0].Name {
1171			case searchX, name + ".":
1172				// Return NXDOMAIN to utilize the search list.
1173				return &dnsMsg{
1174					dnsMsgHdr: dnsMsgHdr{
1175						id:       q.id,
1176						response: true,
1177						rcode:    dnsRcodeNameError,
1178					},
1179					question: q.question,
1180				}, nil
1181			case searchY:
1182				// Return records below.
1183			default:
1184				return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
1185			}
1186
1187			r := &dnsMsg{
1188				dnsMsgHdr: dnsMsgHdr{
1189					id:       q.id,
1190					response: true,
1191				},
1192				question: q.question,
1193			}
1194			switch q.question[0].Qtype {
1195			case dnsTypeA:
1196				r.answer = []dnsRR{
1197					&dnsRR_A{
1198						Hdr: dnsRR_Header{
1199							Name:     q.question[0].Name,
1200							Rrtype:   dnsTypeA,
1201							Class:    dnsClassINET,
1202							Rdlength: 4,
1203						},
1204						A: TestAddr,
1205					},
1206				}
1207			case dnsTypeAAAA:
1208				r.answer = []dnsRR{
1209					&dnsRR_AAAA{
1210						Hdr: dnsRR_Header{
1211							Name:     q.question[0].Name,
1212							Rrtype:   dnsTypeAAAA,
1213							Class:    dnsClassINET,
1214							Rdlength: 16,
1215						},
1216						AAAA: TestAddr6,
1217					},
1218				}
1219			default:
1220				return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype)
1221			}
1222			return r, nil
1223		}}
1224
1225		for _, strict := range []bool{true, false} {
1226			r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext}
1227			ips, err := r.LookupIPAddr(context.Background(), name)
1228
1229			var wantErr error
1230			if strict {
1231				wantErr = tt.wantStrictErr
1232			} else {
1233				wantErr = tt.wantLaxErr
1234			}
1235			if !reflect.DeepEqual(err, wantErr) {
1236				t.Errorf("#%d (%s) strict=%v: got err %#v; want %#v", i, tt.desc, strict, err, wantErr)
1237			}
1238
1239			gotIPs := map[string]struct{}{}
1240			for _, ip := range ips {
1241				gotIPs[ip.String()] = struct{}{}
1242			}
1243			wantIPs := map[string]struct{}{}
1244			if wantErr == nil {
1245				for _, ip := range tt.wantIPs {
1246					wantIPs[ip] = struct{}{}
1247				}
1248			}
1249			if !reflect.DeepEqual(gotIPs, wantIPs) {
1250				t.Errorf("#%d (%s) strict=%v: got ips %v; want %v", i, tt.desc, strict, gotIPs, wantIPs)
1251			}
1252		}
1253	}
1254}
1255
1256// Issue 17448. With StrictErrors enabled, temporary errors should make
1257// LookupTXT stop walking the search list.
1258func TestStrictErrorsLookupTXT(t *testing.T) {
1259	conf, err := newResolvConfTest()
1260	if err != nil {
1261		t.Fatal(err)
1262	}
1263	defer conf.teardown()
1264
1265	confData := []string{
1266		"nameserver 192.0.2.53",
1267		"search x.golang.org y.golang.org",
1268	}
1269	if err := conf.writeAndUpdate(confData); err != nil {
1270		t.Fatal(err)
1271	}
1272
1273	const name = "test"
1274	const server = "192.0.2.53:53"
1275	const searchX = "test.x.golang.org."
1276	const searchY = "test.y.golang.org."
1277	const txt = "Hello World"
1278
1279	fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
1280		t.Log(s, q)
1281
1282		switch q.question[0].Name {
1283		case searchX:
1284			return nil, poll.ErrTimeout
1285		case searchY:
1286			return mockTXTResponse(q), nil
1287		default:
1288			return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
1289		}
1290	}}
1291
1292	for _, strict := range []bool{true, false} {
1293		r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
1294		_, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT)
1295		var wantErr error
1296		var wantRRs int
1297		if strict {
1298			wantErr = &DNSError{
1299				Err:       poll.ErrTimeout.Error(),
1300				Name:      name,
1301				Server:    server,
1302				IsTimeout: true,
1303			}
1304		} else {
1305			wantRRs = 1
1306		}
1307		if !reflect.DeepEqual(err, wantErr) {
1308			t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
1309		}
1310		if len(rrs) != wantRRs {
1311			t.Errorf("strict=%v: got %v; want %v", strict, len(rrs), wantRRs)
1312		}
1313	}
1314}
1315