1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
6
7package net
8
9import (
10	"context"
11	"errors"
12	"fmt"
13	"internal/poll"
14	"io/ioutil"
15	"os"
16	"path"
17	"reflect"
18	"strings"
19	"sync"
20	"testing"
21	"time"
22)
23
24var goResolver = Resolver{PreferGo: true}
25
// TestAddr is 192.0.2.1, an address from the 192.0.2.0/24 block reserved
// by RFC 5737 for documentation. The first octet is the most significant
// byte of the uint32.
const TestAddr uint32 = 192<<24 | 2<<8 | 1

// VarTestAddr6 is 2001:db8::1, an address from the 2001:db8::/32 block
// reserved by RFC 3849 for documentation. Unlisted bytes are zero.
var VarTestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 15: 1}
31
32var dnsTransportFallbackTests = []struct {
33	server  string
34	name    string
35	qtype   uint16
36	timeout int
37	rcode   int
38}{
39	// Querying "com." with qtype=255 usually makes an answer
40	// which requires more than 512 bytes.
41	{"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
42	{"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
43}
44
45func TestDNSTransportFallback(t *testing.T) {
46	fake := fakeDNSServer{
47		rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
48			r := &dnsMsg{
49				dnsMsgHdr: dnsMsgHdr{
50					id:       q.id,
51					response: true,
52					rcode:    dnsRcodeSuccess,
53				},
54				question: q.question,
55			}
56			if n == "udp" {
57				r.truncated = true
58			}
59			return r, nil
60		},
61	}
62	r := Resolver{PreferGo: true, Dial: fake.DialContext}
63	for _, tt := range dnsTransportFallbackTests {
64		ctx, cancel := context.WithCancel(context.Background())
65		defer cancel()
66		msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
67		if err != nil {
68			t.Error(err)
69			continue
70		}
71		switch msg.rcode {
72		case tt.rcode:
73		default:
74			t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
75			continue
76		}
77	}
78}
79
80// See RFC 6761 for further information about the reserved, pseudo
81// domain names.
82var specialDomainNameTests = []struct {
83	name  string
84	qtype uint16
85	rcode int
86}{
87	// Name resolution APIs and libraries should not recognize the
88	// followings as special.
89	{"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
90	{"test.", dnsTypeALL, dnsRcodeNameError},
91	{"example.com.", dnsTypeALL, dnsRcodeSuccess},
92
93	// Name resolution APIs and libraries should recognize the
94	// followings as special and should not send any queries.
95	// Though, we test those names here for verifying negative
96	// answers at DNS query-response interaction level.
97	{"localhost.", dnsTypeALL, dnsRcodeNameError},
98	{"invalid.", dnsTypeALL, dnsRcodeNameError},
99}
100
101func TestSpecialDomainName(t *testing.T) {
102	fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
103		r := &dnsMsg{
104			dnsMsgHdr: dnsMsgHdr{
105				id:       q.id,
106				response: true,
107			},
108			question: q.question,
109		}
110
111		switch q.question[0].Name {
112		case "example.com.":
113			r.rcode = dnsRcodeSuccess
114		default:
115			r.rcode = dnsRcodeNameError
116		}
117
118		return r, nil
119	}}
120	r := Resolver{PreferGo: true, Dial: fake.DialContext}
121	server := "8.8.8.8:53"
122	for _, tt := range specialDomainNameTests {
123		ctx, cancel := context.WithCancel(context.Background())
124		defer cancel()
125		msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
126		if err != nil {
127			t.Error(err)
128			continue
129		}
130		switch msg.rcode {
131		case tt.rcode, dnsRcodeServerFailure:
132		default:
133			t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
134			continue
135		}
136	}
137}
138
139// Issue 13705: don't try to resolve onion addresses, etc
140func TestAvoidDNSName(t *testing.T) {
141	tests := []struct {
142		name  string
143		avoid bool
144	}{
145		{"foo.com", false},
146		{"foo.com.", false},
147
148		{"foo.onion.", true},
149		{"foo.onion", true},
150		{"foo.ONION", true},
151		{"foo.ONION.", true},
152
153		// But do resolve *.local address; Issue 16739
154		{"foo.local.", false},
155		{"foo.local", false},
156		{"foo.LOCAL", false},
157		{"foo.LOCAL.", false},
158
159		{"", true}, // will be rejected earlier too
160
161		// Without stuff before onion/local, they're fine to
162		// use DNS. With a search path,
163		// "onion.vegegtables.com" can use DNS. Without a
164		// search path (or with a trailing dot), the queries
165		// are just kinda useless, but don't reveal anything
166		// private.
167		{"local", false},
168		{"onion", false},
169		{"local.", false},
170		{"onion.", false},
171	}
172	for _, tt := range tests {
173		got := avoidDNS(tt.name)
174		if got != tt.avoid {
175			t.Errorf("avoidDNS(%q) = %v; want %v", tt.name, got, tt.avoid)
176		}
177	}
178}
179
180var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
181	r := &dnsMsg{
182		dnsMsgHdr: dnsMsgHdr{
183			id:       q.id,
184			response: true,
185		},
186		question: q.question,
187	}
188	if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA {
189		r.answer = []dnsRR{
190			&dnsRR_A{
191				Hdr: dnsRR_Header{
192					Name:     q.question[0].Name,
193					Rrtype:   dnsTypeA,
194					Class:    dnsClassINET,
195					Rdlength: 4,
196				},
197				A: TestAddr,
198			},
199		}
200	}
201	return r, nil
202}}
203
204// Issue 13705: don't try to resolve onion addresses, etc
205func TestLookupTorOnion(t *testing.T) {
206	defer dnsWaitGroup.Wait()
207	r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
208	addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
209	if err != nil {
210		t.Fatalf("lookup = %v; want nil", err)
211	}
212	if len(addrs) > 0 {
213		t.Errorf("unexpected addresses: %v", addrs)
214	}
215}
216
217type resolvConfTest struct {
218	dir  string
219	path string
220	*resolverConfig
221}
222
223func newResolvConfTest() (*resolvConfTest, error) {
224	dir, err := ioutil.TempDir("", "go-resolvconftest")
225	if err != nil {
226		return nil, err
227	}
228	conf := &resolvConfTest{
229		dir:            dir,
230		path:           path.Join(dir, "resolv.conf"),
231		resolverConfig: &resolvConf,
232	}
233	conf.initOnce.Do(conf.init)
234	return conf, nil
235}
236
237func (conf *resolvConfTest) writeAndUpdate(lines []string) error {
238	f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
239	if err != nil {
240		return err
241	}
242	if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil {
243		f.Close()
244		return err
245	}
246	f.Close()
247	if err := conf.forceUpdate(conf.path, time.Now().Add(time.Hour)); err != nil {
248		return err
249	}
250	return nil
251}
252
253func (conf *resolvConfTest) forceUpdate(name string, lastChecked time.Time) error {
254	dnsConf := dnsReadConfig(name)
255	conf.mu.Lock()
256	conf.dnsConfig = dnsConf
257	conf.mu.Unlock()
258	for i := 0; i < 5; i++ {
259		if conf.tryAcquireSema() {
260			conf.lastChecked = lastChecked
261			conf.releaseSema()
262			return nil
263		}
264	}
265	return fmt.Errorf("tryAcquireSema for %s failed", name)
266}
267
268func (conf *resolvConfTest) servers() []string {
269	conf.mu.RLock()
270	servers := conf.dnsConfig.servers
271	conf.mu.RUnlock()
272	return servers
273}
274
275func (conf *resolvConfTest) teardown() error {
276	err := conf.forceUpdate("/etc/resolv.conf", time.Time{})
277	os.RemoveAll(conf.dir)
278	return err
279}
280
281var updateResolvConfTests = []struct {
282	name    string   // query name
283	lines   []string // resolver configuration lines
284	servers []string // expected name servers
285}{
286	{
287		name:    "golang.org",
288		lines:   []string{"nameserver 8.8.8.8"},
289		servers: []string{"8.8.8.8:53"},
290	},
291	{
292		name:    "",
293		lines:   nil, // an empty resolv.conf should use defaultNS as name servers
294		servers: defaultNS,
295	},
296	{
297		name:    "www.example.com",
298		lines:   []string{"nameserver 8.8.4.4"},
299		servers: []string{"8.8.4.4:53"},
300	},
301}
302
303func TestUpdateResolvConf(t *testing.T) {
304	defer dnsWaitGroup.Wait()
305
306	r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
307
308	conf, err := newResolvConfTest()
309	if err != nil {
310		t.Fatal(err)
311	}
312	defer conf.teardown()
313
314	for i, tt := range updateResolvConfTests {
315		if err := conf.writeAndUpdate(tt.lines); err != nil {
316			t.Error(err)
317			continue
318		}
319		if tt.name != "" {
320			var wg sync.WaitGroup
321			const N = 10
322			wg.Add(N)
323			for j := 0; j < N; j++ {
324				go func(name string) {
325					defer wg.Done()
326					ips, err := r.LookupIPAddr(context.Background(), name)
327					if err != nil {
328						t.Error(err)
329						return
330					}
331					if len(ips) == 0 {
332						t.Errorf("no records for %s", name)
333						return
334					}
335				}(tt.name)
336			}
337			wg.Wait()
338		}
339		servers := conf.servers()
340		if !reflect.DeepEqual(servers, tt.servers) {
341			t.Errorf("#%d: got %v; want %v", i, servers, tt.servers)
342			continue
343		}
344	}
345}
346
// goLookupIPWithResolverConfigTests drives TestGoLookupIPWithResolverConfig.
// Each entry installs lines as the resolv.conf and then looks up name.
// The embedded error is the *DNSError expected from the lookup, compared on
// Name, Server and IsTimeout; it is nil for entries whose a/aaaa flags
// describe a successful answer. a and aaaa say whether the returned
// addresses should include IPv4 and IPv6 addresses, respectively.
var goLookupIPWithResolverConfigTests = []struct {
	name  string
	lines []string // resolver configuration lines
	error
	a, aaaa bool // whether response contains A, AAAA-record
}{
	// no records, transport timeout
	{
		"jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
		[]string{
			"options timeout:1 attempts:1",
			"nameserver 255.255.255.255", // please forgive us for abuse of limited broadcast address
		},
		&DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "255.255.255.255:53", IsTimeout: true},
		false, false,
	},

	// no records, non-existent domain
	{
		"jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
		[]string{
			"options timeout:3 attempts:1",
			"nameserver 8.8.8.8",
		},
		&DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "8.8.8.8:53", IsTimeout: false},
		false, false,
	},

	// a few A records, no AAAA records
	{
		"ipv4.google.com.",
		[]string{
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		true, false,
	},
	{
		"ipv4.google.com",
		[]string{
			"domain golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, false,
	},
	{
		"ipv4.google.com",
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, false,
	},

	// no A records, a few AAAA records
	{
		"ipv6.google.com.",
		[]string{
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		false, true,
	},
	{
		"ipv6.google.com",
		[]string{
			"domain golang.org",
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		false, true,
	},
	{
		"ipv6.google.com",
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 8.8.8.8",
			"nameserver 2001:4860:4860::8888",
		},
		nil,
		false, true,
	},

	// both A and AAAA records
	{
		"hostname.as112.net", // see RFC 7534
		[]string{
			"domain golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, true,
	},
	{
		"hostname.as112.net", // see RFC 7534
		[]string{
			"search x.golang.org y.golang.org",
			"nameserver 2001:4860:4860::8888",
			"nameserver 8.8.8.8",
		},
		nil,
		true, true,
	},
}
459
460func TestGoLookupIPWithResolverConfig(t *testing.T) {
461	defer dnsWaitGroup.Wait()
462
463	fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
464		switch s {
465		case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
466			break
467		default:
468			time.Sleep(10 * time.Millisecond)
469			return nil, poll.ErrTimeout
470		}
471		r := &dnsMsg{
472			dnsMsgHdr: dnsMsgHdr{
473				id:       q.id,
474				response: true,
475			},
476			question: q.question,
477		}
478		for _, question := range q.question {
479			switch question.Qtype {
480			case dnsTypeA:
481				switch question.Name {
482				case "hostname.as112.net.":
483					break
484				case "ipv4.google.com.":
485					r.answer = append(r.answer, &dnsRR_A{
486						Hdr: dnsRR_Header{
487							Name:     q.question[0].Name,
488							Rrtype:   dnsTypeA,
489							Class:    dnsClassINET,
490							Rdlength: 4,
491						},
492						A: TestAddr,
493					})
494				default:
495
496				}
497			case dnsTypeAAAA:
498				switch question.Name {
499				case "hostname.as112.net.":
500					break
501				case "ipv6.google.com.":
502					r.answer = append(r.answer, &dnsRR_AAAA{
503						Hdr: dnsRR_Header{
504							Name:     q.question[0].Name,
505							Rrtype:   dnsTypeAAAA,
506							Class:    dnsClassINET,
507							Rdlength: 16,
508						},
509						AAAA: VarTestAddr6,
510					})
511				}
512			}
513		}
514		return r, nil
515	}}
516	r := Resolver{PreferGo: true, Dial: fake.DialContext}
517
518	conf, err := newResolvConfTest()
519	if err != nil {
520		t.Fatal(err)
521	}
522	defer conf.teardown()
523
524	for _, tt := range goLookupIPWithResolverConfigTests {
525		if err := conf.writeAndUpdate(tt.lines); err != nil {
526			t.Error(err)
527			continue
528		}
529		addrs, err := r.LookupIPAddr(context.Background(), tt.name)
530		if err != nil {
531			if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
532				t.Errorf("got %v; want %v", err, tt.error)
533			}
534			continue
535		}
536		if len(addrs) == 0 {
537			t.Errorf("no records for %s", tt.name)
538		}
539		if !tt.a && !tt.aaaa && len(addrs) > 0 {
540			t.Errorf("unexpected %v for %s", addrs, tt.name)
541		}
542		for _, addr := range addrs {
543			if !tt.a && addr.IP.To4() != nil {
544				t.Errorf("got %v; must not be IPv4 address", addr)
545			}
546			if !tt.aaaa && addr.IP.To16() != nil && addr.IP.To4() == nil {
547				t.Errorf("got %v; must not be IPv6 address", addr)
548			}
549		}
550	}
551}
552
553// Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
554func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
555	defer dnsWaitGroup.Wait()
556
557	fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
558		r := &dnsMsg{
559			dnsMsgHdr: dnsMsgHdr{
560				id:       q.id,
561				response: true,
562			},
563			question: q.question,
564		}
565		return r, nil
566	}}
567	r := Resolver{PreferGo: true, Dial: fake.DialContext}
568
569	// Add a config that simulates no dns servers being available.
570	conf, err := newResolvConfTest()
571	if err != nil {
572		t.Fatal(err)
573	}
574	if err := conf.writeAndUpdate([]string{}); err != nil {
575		t.Fatal(err)
576	}
577	// Redirect host file lookups.
578	defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
579	testHookHostsPath = "testdata/hosts"
580
581	for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} {
582		name := fmt.Sprintf("order %v", order)
583
584		// First ensure that we get an error when contacting a non-existent host.
585		_, _, err := r.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order)
586		if err == nil {
587			t.Errorf("%s: expected error while looking up name not in hosts file", name)
588			continue
589		}
590
591		// Now check that we get an address when the name appears in the hosts file.
592		addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
593		if err != nil {
594			t.Errorf("%s: expected to successfully lookup host entry", name)
595			continue
596		}
597		if len(addrs) != 1 {
598			t.Errorf("%s: expected exactly one result, but got %v", name, addrs)
599			continue
600		}
601		if got, want := addrs[0].String(), "127.1.1.1"; got != want {
602			t.Errorf("%s: address doesn't match expectation. got %v, want %v", name, got, want)
603		}
604	}
605	defer conf.teardown()
606}
607
608// Issue 12712.
609// When using search domains, return the error encountered
610// querying the original name instead of an error encountered
611// querying a generated name.
612func TestErrorForOriginalNameWhenSearching(t *testing.T) {
613	defer dnsWaitGroup.Wait()
614
615	const fqdn = "doesnotexist.domain"
616
617	conf, err := newResolvConfTest()
618	if err != nil {
619		t.Fatal(err)
620	}
621	defer conf.teardown()
622
623	if err := conf.writeAndUpdate([]string{"search servfail"}); err != nil {
624		t.Fatal(err)
625	}
626
627	fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
628		r := &dnsMsg{
629			dnsMsgHdr: dnsMsgHdr{
630				id:       q.id,
631				response: true,
632			},
633			question: q.question,
634		}
635
636		switch q.question[0].Name {
637		case fqdn + ".servfail.":
638			r.rcode = dnsRcodeServerFailure
639		default:
640			r.rcode = dnsRcodeNameError
641		}
642
643		return r, nil
644	}}
645
646	cases := []struct {
647		strictErrors bool
648		wantErr      *DNSError
649	}{
650		{true, &DNSError{Name: fqdn, Err: "server misbehaving", IsTemporary: true}},
651		{false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error()}},
652	}
653	for _, tt := range cases {
654		r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext}
655		_, err = r.LookupIPAddr(context.Background(), fqdn)
656		if err == nil {
657			t.Fatal("expected an error")
658		}
659
660		want := tt.wantErr
661		if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err || err.IsTemporary != want.IsTemporary {
662			t.Errorf("got %v; want %v", err, want)
663		}
664	}
665}
666
667// Issue 15434. If a name server gives a lame referral, continue to the next.
668func TestIgnoreLameReferrals(t *testing.T) {
669	defer dnsWaitGroup.Wait()
670
671	conf, err := newResolvConfTest()
672	if err != nil {
673		t.Fatal(err)
674	}
675	defer conf.teardown()
676
677	if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will give a lame referral
678		"nameserver 192.0.2.2"}); err != nil {
679		t.Fatal(err)
680	}
681
682	fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
683		t.Log(s, q)
684		r := &dnsMsg{
685			dnsMsgHdr: dnsMsgHdr{
686				id:       q.id,
687				response: true,
688			},
689			question: q.question,
690		}
691
692		if s == "192.0.2.2:53" {
693			r.recursion_available = true
694			if q.question[0].Qtype == dnsTypeA {
695				r.answer = []dnsRR{
696					&dnsRR_A{
697						Hdr: dnsRR_Header{
698							Name:     q.question[0].Name,
699							Rrtype:   dnsTypeA,
700							Class:    dnsClassINET,
701							Rdlength: 4,
702						},
703						A: TestAddr,
704					},
705				}
706			}
707		}
708
709		return r, nil
710	}}
711	r := Resolver{PreferGo: true, Dial: fake.DialContext}
712
713	addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org")
714	if err != nil {
715		t.Fatal(err)
716	}
717
718	if got := len(addrs); got != 1 {
719		t.Fatalf("got %d addresses, want 1", got)
720	}
721
722	if got, want := addrs[0].String(), "192.0.2.1"; got != want {
723		t.Fatalf("got address %v, want %v", got, want)
724	}
725}
726
727func BenchmarkGoLookupIP(b *testing.B) {
728	testHookUninstaller.Do(uninstallTestHooks)
729	ctx := context.Background()
730
731	for i := 0; i < b.N; i++ {
732		goResolver.LookupIPAddr(ctx, "www.example.com")
733	}
734}
735
736func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
737	testHookUninstaller.Do(uninstallTestHooks)
738	ctx := context.Background()
739
740	for i := 0; i < b.N; i++ {
741		goResolver.LookupIPAddr(ctx, "some.nonexistent")
742	}
743}
744
745func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
746	testHookUninstaller.Do(uninstallTestHooks)
747
748	conf, err := newResolvConfTest()
749	if err != nil {
750		b.Fatal(err)
751	}
752	defer conf.teardown()
753
754	lines := []string{
755		"nameserver 203.0.113.254", // use TEST-NET-3 block, see RFC 5737
756		"nameserver 8.8.8.8",
757	}
758	if err := conf.writeAndUpdate(lines); err != nil {
759		b.Fatal(err)
760	}
761	ctx := context.Background()
762
763	for i := 0; i < b.N; i++ {
764		goResolver.LookupIPAddr(ctx, "www.example.com")
765	}
766}
767
768type fakeDNSServer struct {
769	rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
770}
771
772func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
773	return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
774}
775
776type fakeDNSConn struct {
777	Conn
778	server *fakeDNSServer
779	n      string
780	s      string
781	q      *dnsMsg
782	t      time.Time
783}
784
785func (f *fakeDNSConn) Close() error {
786	return nil
787}
788
789func (f *fakeDNSConn) Read(b []byte) (int, error) {
790	resp, err := f.server.rh(f.n, f.s, f.q, f.t)
791	if err != nil {
792		return 0, err
793	}
794
795	bb, ok := resp.Pack()
796	if !ok {
797		return 0, errors.New("cannot marshal DNS message")
798	}
799	if len(b) < len(bb) {
800		return 0, errors.New("read would fragment DNS message")
801	}
802
803	copy(b, bb)
804	return len(bb), nil
805}
806
807func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
808	return 0, nil, nil
809}
810
811func (f *fakeDNSConn) Write(b []byte) (int, error) {
812	f.q = new(dnsMsg)
813	if !f.q.Unpack(b) {
814		return 0, errors.New("cannot unmarshal DNS message")
815	}
816	return len(b), nil
817}
818
819func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
820	return 0, nil
821}
822
823func (f *fakeDNSConn) SetDeadline(t time.Time) error {
824	f.t = t
825	return nil
826}
827
828// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
829func TestIgnoreDNSForgeries(t *testing.T) {
830	c, s := Pipe()
831	go func() {
832		b := make([]byte, 512)
833		n, err := s.Read(b)
834		if err != nil {
835			t.Error(err)
836			return
837		}
838
839		msg := &dnsMsg{}
840		if !msg.Unpack(b[:n]) {
841			t.Error("invalid DNS query")
842			return
843		}
844
845		s.Write([]byte("garbage DNS response packet"))
846
847		msg.response = true
848		msg.id++ // make invalid ID
849		b, ok := msg.Pack()
850		if !ok {
851			t.Error("failed to pack DNS response")
852			return
853		}
854		s.Write(b)
855
856		msg.id-- // restore original ID
857		msg.answer = []dnsRR{
858			&dnsRR_A{
859				Hdr: dnsRR_Header{
860					Name:     "www.example.com.",
861					Rrtype:   dnsTypeA,
862					Class:    dnsClassINET,
863					Rdlength: 4,
864				},
865				A: TestAddr,
866			},
867		}
868
869		b, ok = msg.Pack()
870		if !ok {
871			t.Error("failed to pack DNS response")
872			return
873		}
874		s.Write(b)
875	}()
876
877	msg := &dnsMsg{
878		dnsMsgHdr: dnsMsgHdr{
879			id: 42,
880		},
881		question: []dnsQuestion{
882			{
883				Name:   "www.example.com.",
884				Qtype:  dnsTypeA,
885				Qclass: dnsClassINET,
886			},
887		},
888	}
889
890	dc := &dnsPacketConn{c}
891	resp, err := dc.dnsRoundTrip(msg)
892	if err != nil {
893		t.Fatalf("dnsRoundTripUDP failed: %v", err)
894	}
895
896	if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
897		t.Errorf("got address %v, want %v", got, TestAddr)
898	}
899}
900
901// Issue 16865. If a name server times out, continue to the next.
902func TestRetryTimeout(t *testing.T) {
903	defer dnsWaitGroup.Wait()
904
905	conf, err := newResolvConfTest()
906	if err != nil {
907		t.Fatal(err)
908	}
909	defer conf.teardown()
910
911	testConf := []string{
912		"nameserver 192.0.2.1", // the one that will timeout
913		"nameserver 192.0.2.2",
914	}
915	if err := conf.writeAndUpdate(testConf); err != nil {
916		t.Fatal(err)
917	}
918
919	var deadline0 time.Time
920
921	fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
922		t.Log(s, q, deadline)
923
924		if deadline.IsZero() {
925			t.Error("zero deadline")
926		}
927
928		if s == "192.0.2.1:53" {
929			deadline0 = deadline
930			time.Sleep(10 * time.Millisecond)
931			return nil, poll.ErrTimeout
932		}
933
934		if deadline.Equal(deadline0) {
935			t.Error("deadline didn't change")
936		}
937
938		return mockTXTResponse(q), nil
939	}}
940	r := &Resolver{PreferGo: true, Dial: fake.DialContext}
941
942	_, err = r.LookupTXT(context.Background(), "www.golang.org")
943	if err != nil {
944		t.Fatal(err)
945	}
946
947	if deadline0.IsZero() {
948		t.Error("deadline0 still zero", deadline0)
949	}
950}
951
952func TestRotate(t *testing.T) {
953	// without rotation, always uses the first server
954	testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"})
955
956	// with rotation, rotates through back to first
957	testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"})
958}
959
960func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
961	defer dnsWaitGroup.Wait()
962
963	conf, err := newResolvConfTest()
964	if err != nil {
965		t.Fatal(err)
966	}
967	defer conf.teardown()
968
969	var confLines []string
970	for _, ns := range nameservers {
971		confLines = append(confLines, "nameserver "+ns)
972	}
973	if rotate {
974		confLines = append(confLines, "options rotate")
975	}
976
977	if err := conf.writeAndUpdate(confLines); err != nil {
978		t.Fatal(err)
979	}
980
981	var usedServers []string
982	fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
983		usedServers = append(usedServers, s)
984		return mockTXTResponse(q), nil
985	}}
986	r := Resolver{PreferGo: true, Dial: fake.DialContext}
987
988	// len(nameservers) + 1 to allow rotation to get back to start
989	for i := 0; i < len(nameservers)+1; i++ {
990		if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil {
991			t.Fatal(err)
992		}
993	}
994
995	if !reflect.DeepEqual(usedServers, wantServers) {
996		t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers)
997	}
998}
999
1000func mockTXTResponse(q *dnsMsg) *dnsMsg {
1001	r := &dnsMsg{
1002		dnsMsgHdr: dnsMsgHdr{
1003			id:                  q.id,
1004			response:            true,
1005			recursion_available: true,
1006		},
1007		question: q.question,
1008		answer: []dnsRR{
1009			&dnsRR_TXT{
1010				Hdr: dnsRR_Header{
1011					Name:   q.question[0].Name,
1012					Rrtype: dnsTypeTXT,
1013					Class:  dnsClassINET,
1014				},
1015				Txt: "ok",
1016			},
1017		},
1018	}
1019
1020	return r
1021}
1022
1023// Issue 17448. With StrictErrors enabled, temporary errors should make
1024// LookupIP fail rather than return a partial result.
1025func TestStrictErrorsLookupIP(t *testing.T) {
1026	defer dnsWaitGroup.Wait()
1027
1028	conf, err := newResolvConfTest()
1029	if err != nil {
1030		t.Fatal(err)
1031	}
1032	defer conf.teardown()
1033
1034	confData := []string{
1035		"nameserver 192.0.2.53",
1036		"search x.golang.org y.golang.org",
1037	}
1038	if err := conf.writeAndUpdate(confData); err != nil {
1039		t.Fatal(err)
1040	}
1041
1042	const name = "test-issue19592"
1043	const server = "192.0.2.53:53"
1044	const searchX = "test-issue19592.x.golang.org."
1045	const searchY = "test-issue19592.y.golang.org."
1046	const ip4 = "192.0.2.1"
1047	const ip6 = "2001:db8::1"
1048
1049	type resolveWhichEnum int
1050	const (
1051		resolveOK resolveWhichEnum = iota
1052		resolveOpError
1053		resolveServfail
1054		resolveTimeout
1055	)
1056
1057	makeTempError := func(err string) error {
1058		return &DNSError{
1059			Err:         err,
1060			Name:        name,
1061			Server:      server,
1062			IsTemporary: true,
1063		}
1064	}
1065	makeTimeout := func() error {
1066		return &DNSError{
1067			Err:       poll.ErrTimeout.Error(),
1068			Name:      name,
1069			Server:    server,
1070			IsTimeout: true,
1071		}
1072	}
1073	makeNxDomain := func() error {
1074		return &DNSError{
1075			Err:    errNoSuchHost.Error(),
1076			Name:   name,
1077			Server: server,
1078		}
1079	}
1080
1081	cases := []struct {
1082		desc          string
1083		resolveWhich  func(quest *dnsQuestion) resolveWhichEnum
1084		wantStrictErr error
1085		wantLaxErr    error
1086		wantIPs       []string
1087	}{
1088		{
1089			desc: "No errors",
1090			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1091				return resolveOK
1092			},
1093			wantIPs: []string{ip4, ip6},
1094		},
1095		{
1096			desc: "searchX error fails in strict mode",
1097			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1098				if quest.Name == searchX {
1099					return resolveTimeout
1100				}
1101				return resolveOK
1102			},
1103			wantStrictErr: makeTimeout(),
1104			wantIPs:       []string{ip4, ip6},
1105		},
1106		{
1107			desc: "searchX IPv4-only timeout fails in strict mode",
1108			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1109				if quest.Name == searchX && quest.Qtype == dnsTypeA {
1110					return resolveTimeout
1111				}
1112				return resolveOK
1113			},
1114			wantStrictErr: makeTimeout(),
1115			wantIPs:       []string{ip4, ip6},
1116		},
1117		{
1118			desc: "searchX IPv6-only servfail fails in strict mode",
1119			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1120				if quest.Name == searchX && quest.Qtype == dnsTypeAAAA {
1121					return resolveServfail
1122				}
1123				return resolveOK
1124			},
1125			wantStrictErr: makeTempError("server misbehaving"),
1126			wantIPs:       []string{ip4, ip6},
1127		},
1128		{
1129			desc: "searchY error always fails",
1130			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1131				if quest.Name == searchY {
1132					return resolveTimeout
1133				}
1134				return resolveOK
1135			},
1136			wantStrictErr: makeTimeout(),
1137			wantLaxErr:    makeNxDomain(), // This one reaches the "test." FQDN.
1138		},
1139		{
1140			desc: "searchY IPv4-only socket error fails in strict mode",
1141			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1142				if quest.Name == searchY && quest.Qtype == dnsTypeA {
1143					return resolveOpError
1144				}
1145				return resolveOK
1146			},
1147			wantStrictErr: makeTempError("write: socket on fire"),
1148			wantIPs:       []string{ip6},
1149		},
1150		{
1151			desc: "searchY IPv6-only timeout fails in strict mode",
1152			resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
1153				if quest.Name == searchY && quest.Qtype == dnsTypeAAAA {
1154					return resolveTimeout
1155				}
1156				return resolveOK
1157			},
1158			wantStrictErr: makeTimeout(),
1159			wantIPs:       []string{ip4},
1160		},
1161	}
1162
1163	for i, tt := range cases {
1164		fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
1165			t.Log(s, q)
1166
1167			switch tt.resolveWhich(&q.question[0]) {
1168			case resolveOK:
1169				// Handle below.
1170			case resolveOpError:
1171				return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
1172			case resolveServfail:
1173				return &dnsMsg{
1174					dnsMsgHdr: dnsMsgHdr{
1175						id:       q.id,
1176						response: true,
1177						rcode:    dnsRcodeServerFailure,
1178					},
1179					question: q.question,
1180				}, nil
1181			case resolveTimeout:
1182				return nil, poll.ErrTimeout
1183			default:
1184				t.Fatal("Impossible resolveWhich")
1185			}
1186
1187			switch q.question[0].Name {
1188			case searchX, name + ".":
1189				// Return NXDOMAIN to utilize the search list.
1190				return &dnsMsg{
1191					dnsMsgHdr: dnsMsgHdr{
1192						id:       q.id,
1193						response: true,
1194						rcode:    dnsRcodeNameError,
1195					},
1196					question: q.question,
1197				}, nil
1198			case searchY:
1199				// Return records below.
1200			default:
1201				return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
1202			}
1203
1204			r := &dnsMsg{
1205				dnsMsgHdr: dnsMsgHdr{
1206					id:       q.id,
1207					response: true,
1208				},
1209				question: q.question,
1210			}
1211			switch q.question[0].Qtype {
1212			case dnsTypeA:
1213				r.answer = []dnsRR{
1214					&dnsRR_A{
1215						Hdr: dnsRR_Header{
1216							Name:     q.question[0].Name,
1217							Rrtype:   dnsTypeA,
1218							Class:    dnsClassINET,
1219							Rdlength: 4,
1220						},
1221						A: TestAddr,
1222					},
1223				}
1224			case dnsTypeAAAA:
1225				r.answer = []dnsRR{
1226					&dnsRR_AAAA{
1227						Hdr: dnsRR_Header{
1228							Name:     q.question[0].Name,
1229							Rrtype:   dnsTypeAAAA,
1230							Class:    dnsClassINET,
1231							Rdlength: 16,
1232						},
1233						AAAA: VarTestAddr6,
1234					},
1235				}
1236			default:
1237				return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype)
1238			}
1239			return r, nil
1240		}}
1241
1242		for _, strict := range []bool{true, false} {
1243			r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext}
1244			ips, err := r.LookupIPAddr(context.Background(), name)
1245
1246			var wantErr error
1247			if strict {
1248				wantErr = tt.wantStrictErr
1249			} else {
1250				wantErr = tt.wantLaxErr
1251			}
1252			if !reflect.DeepEqual(err, wantErr) {
1253				t.Errorf("#%d (%s) strict=%v: got err %#v; want %#v", i, tt.desc, strict, err, wantErr)
1254			}
1255
1256			gotIPs := map[string]struct{}{}
1257			for _, ip := range ips {
1258				gotIPs[ip.String()] = struct{}{}
1259			}
1260			wantIPs := map[string]struct{}{}
1261			if wantErr == nil {
1262				for _, ip := range tt.wantIPs {
1263					wantIPs[ip] = struct{}{}
1264				}
1265			}
1266			if !reflect.DeepEqual(gotIPs, wantIPs) {
1267				t.Errorf("#%d (%s) strict=%v: got ips %v; want %v", i, tt.desc, strict, gotIPs, wantIPs)
1268			}
1269		}
1270	}
1271}
1272
1273// Issue 17448. With StrictErrors enabled, temporary errors should make
1274// LookupTXT stop walking the search list.
1275func TestStrictErrorsLookupTXT(t *testing.T) {
1276	defer dnsWaitGroup.Wait()
1277
1278	conf, err := newResolvConfTest()
1279	if err != nil {
1280		t.Fatal(err)
1281	}
1282	defer conf.teardown()
1283
1284	confData := []string{
1285		"nameserver 192.0.2.53",
1286		"search x.golang.org y.golang.org",
1287	}
1288	if err := conf.writeAndUpdate(confData); err != nil {
1289		t.Fatal(err)
1290	}
1291
1292	const name = "test"
1293	const server = "192.0.2.53:53"
1294	const searchX = "test.x.golang.org."
1295	const searchY = "test.y.golang.org."
1296	const txt = "Hello World"
1297
1298	fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
1299		t.Log(s, q)
1300
1301		switch q.question[0].Name {
1302		case searchX:
1303			return nil, poll.ErrTimeout
1304		case searchY:
1305			return mockTXTResponse(q), nil
1306		default:
1307			return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
1308		}
1309	}}
1310
1311	for _, strict := range []bool{true, false} {
1312		r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
1313		_, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT)
1314		var wantErr error
1315		var wantRRs int
1316		if strict {
1317			wantErr = &DNSError{
1318				Err:       poll.ErrTimeout.Error(),
1319				Name:      name,
1320				Server:    server,
1321				IsTimeout: true,
1322			}
1323		} else {
1324			wantRRs = 1
1325		}
1326		if !reflect.DeepEqual(err, wantErr) {
1327			t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
1328		}
1329		if len(rrs) != wantRRs {
1330			t.Errorf("strict=%v: got %v; want %v", strict, len(rrs), wantRRs)
1331		}
1332	}
1333}
1334
1335// Test for a race between uninstalling the test hooks and closing a
1336// socket connection. This used to fail when testing with -race.
1337func TestDNSGoroutineRace(t *testing.T) {
1338	defer dnsWaitGroup.Wait()
1339
1340	fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) {
1341		time.Sleep(10 * time.Microsecond)
1342		return nil, poll.ErrTimeout
1343	}}
1344	r := Resolver{PreferGo: true, Dial: fake.DialContext}
1345
1346	// The timeout here is less than the timeout used by the server,
1347	// so the goroutine started to query the (fake) server will hang
1348	// around after this test is done if we don't call dnsWaitGroup.Wait.
1349	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond)
1350	defer cancel()
1351	_, err := r.LookupIPAddr(ctx, "where.are.they.now")
1352	if err == nil {
1353		t.Fatal("fake DNS lookup unexpectedly succeeded")
1354	}
1355}
1356