1package dns
2
3import (
4	"context"
5	"crypto/tls"
6	"errors"
7	"fmt"
8	"net"
9	"strconv"
10	"strings"
11	"testing"
12	"time"
13)
14
15func TestDialUDP(t *testing.T) {
16	HandleFunc("miek.nl.", HelloServer)
17	defer HandleRemove("miek.nl.")
18
19	s, addrstr, _, err := RunLocalUDPServer(":0")
20	if err != nil {
21		t.Fatalf("unable to run test server: %v", err)
22	}
23	defer s.Shutdown()
24
25	m := new(Msg)
26	m.SetQuestion("miek.nl.", TypeSOA)
27
28	c := new(Client)
29	conn, err := c.Dial(addrstr)
30	if err != nil {
31		t.Fatalf("failed to dial: %v", err)
32	}
33	if conn == nil {
34		t.Fatalf("conn is nil")
35	}
36}
37
38func TestClientSync(t *testing.T) {
39	HandleFunc("miek.nl.", HelloServer)
40	defer HandleRemove("miek.nl.")
41
42	s, addrstr, _, err := RunLocalUDPServer(":0")
43	if err != nil {
44		t.Fatalf("unable to run test server: %v", err)
45	}
46	defer s.Shutdown()
47
48	m := new(Msg)
49	m.SetQuestion("miek.nl.", TypeSOA)
50
51	c := new(Client)
52	r, _, err := c.Exchange(m, addrstr)
53	if err != nil {
54		t.Fatalf("failed to exchange: %v", err)
55	}
56	if r == nil {
57		t.Fatal("response is nil")
58	}
59	if r.Rcode != RcodeSuccess {
60		t.Errorf("failed to get an valid answer\n%v", r)
61	}
62	// And now with plain Exchange().
63	r, err = Exchange(m, addrstr)
64	if err != nil {
65		t.Errorf("failed to exchange: %v", err)
66	}
67	if r == nil || r.Rcode != RcodeSuccess {
68		t.Errorf("failed to get an valid answer\n%v", r)
69	}
70}
71
72func TestClientLocalAddress(t *testing.T) {
73	HandleFunc("miek.nl.", HelloServerEchoAddrPort)
74	defer HandleRemove("miek.nl.")
75
76	s, addrstr, _, err := RunLocalUDPServer(":0")
77	if err != nil {
78		t.Fatalf("unable to run test server: %v", err)
79	}
80	defer s.Shutdown()
81
82	m := new(Msg)
83	m.SetQuestion("miek.nl.", TypeSOA)
84
85	c := new(Client)
86	laddr := net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 12345, Zone: ""}
87	c.Dialer = &net.Dialer{LocalAddr: &laddr}
88	r, _, err := c.Exchange(m, addrstr)
89	if err != nil {
90		t.Fatalf("failed to exchange: %v", err)
91	}
92	if r != nil && r.Rcode != RcodeSuccess {
93		t.Errorf("failed to get an valid answer\n%v", r)
94	}
95	if len(r.Extra) != 1 {
96		t.Fatalf("failed to get additional answers\n%v", r)
97	}
98	txt := r.Extra[0].(*TXT)
99	if txt == nil {
100		t.Errorf("invalid TXT response\n%v", txt)
101	}
102	if len(txt.Txt) != 1 || !strings.Contains(txt.Txt[0], ":12345") {
103		t.Errorf("invalid TXT response\n%v", txt.Txt)
104	}
105}
106
107func TestClientTLSSyncV4(t *testing.T) {
108	HandleFunc("miek.nl.", HelloServer)
109	defer HandleRemove("miek.nl.")
110
111	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
112	if err != nil {
113		t.Fatalf("unable to build certificate: %v", err)
114	}
115
116	config := tls.Config{
117		Certificates: []tls.Certificate{cert},
118	}
119
120	s, addrstr, _, err := RunLocalTLSServer(":0", &config)
121	if err != nil {
122		t.Fatalf("unable to run test server: %v", err)
123	}
124	defer s.Shutdown()
125
126	m := new(Msg)
127	m.SetQuestion("miek.nl.", TypeSOA)
128
129	c := new(Client)
130
131	// test tcp-tls
132	c.Net = "tcp-tls"
133	c.TLSConfig = &tls.Config{
134		InsecureSkipVerify: true,
135	}
136
137	r, _, err := c.Exchange(m, addrstr)
138	if err != nil {
139		t.Fatalf("failed to exchange: %v", err)
140	}
141	if r == nil {
142		t.Fatal("response is nil")
143	}
144	if r.Rcode != RcodeSuccess {
145		t.Errorf("failed to get an valid answer\n%v", r)
146	}
147
148	// test tcp4-tls
149	c.Net = "tcp4-tls"
150	c.TLSConfig = &tls.Config{
151		InsecureSkipVerify: true,
152	}
153
154	r, _, err = c.Exchange(m, addrstr)
155	if err != nil {
156		t.Fatalf("failed to exchange: %v", err)
157	}
158	if r == nil {
159		t.Fatal("response is nil")
160	}
161	if r.Rcode != RcodeSuccess {
162		t.Errorf("failed to get an valid answer\n%v", r)
163	}
164}
165
// isNetworkTimeout reports whether err, or any error in its wrap chain, is a
// net.Error whose Timeout method returns true. A nil error is never a timeout.
func isNetworkTimeout(err error) bool {
	// TODO: when Go 1.14 support is dropped, do this: https://golang.org/doc/go1.15#net
	var ne net.Error
	if !errors.As(err, &ne) {
		return false
	}
	return ne.Timeout()
}
171
172func TestClientSyncBadID(t *testing.T) {
173	HandleFunc("miek.nl.", HelloServerBadID)
174	defer HandleRemove("miek.nl.")
175
176	s, addrstr, _, err := RunLocalUDPServer(":0")
177	if err != nil {
178		t.Fatalf("unable to run test server: %v", err)
179	}
180	defer s.Shutdown()
181
182	m := new(Msg)
183	m.SetQuestion("miek.nl.", TypeSOA)
184
185	c := &Client{
186		Timeout: 50 * time.Millisecond,
187	}
188	if _, _, err := c.Exchange(m, addrstr); err == nil || !isNetworkTimeout(err) {
189		t.Errorf("query did not time out")
190	}
191	// And now with plain Exchange().
192	if _, err = Exchange(m, addrstr); err == nil || !isNetworkTimeout(err) {
193		t.Errorf("query did not time out")
194	}
195}
196
197func TestClientSyncBadThenGoodID(t *testing.T) {
198	HandleFunc("miek.nl.", HelloServerBadThenGoodID)
199	defer HandleRemove("miek.nl.")
200
201	s, addrstr, _, err := RunLocalUDPServer(":0")
202	if err != nil {
203		t.Fatalf("unable to run test server: %v", err)
204	}
205	defer s.Shutdown()
206
207	m := new(Msg)
208	m.SetQuestion("miek.nl.", TypeSOA)
209
210	c := new(Client)
211	r, _, err := c.Exchange(m, addrstr)
212	if err != nil {
213		t.Errorf("failed to exchange: %v", err)
214	}
215	if r.Id != m.Id {
216		t.Errorf("failed to get response with expected Id")
217	}
218	// And now with plain Exchange().
219	r, err = Exchange(m, addrstr)
220	if err != nil {
221		t.Errorf("failed to exchange: %v", err)
222	}
223	if r.Id != m.Id {
224		t.Errorf("failed to get response with expected Id")
225	}
226}
227
228func TestClientSyncTCPBadID(t *testing.T) {
229	HandleFunc("miek.nl.", HelloServerBadID)
230	defer HandleRemove("miek.nl.")
231
232	s, addrstr, _, err := RunLocalTCPServer(":0")
233	if err != nil {
234		t.Fatalf("unable to run test server: %v", err)
235	}
236	defer s.Shutdown()
237
238	m := new(Msg)
239	m.SetQuestion("miek.nl.", TypeSOA)
240
241	c := &Client{
242		Net: "tcp",
243	}
244	if _, _, err := c.Exchange(m, addrstr); err != ErrId {
245		t.Errorf("did not find a bad Id")
246	}
247}
248
249func TestClientEDNS0(t *testing.T) {
250	HandleFunc("miek.nl.", HelloServer)
251	defer HandleRemove("miek.nl.")
252
253	s, addrstr, _, err := RunLocalUDPServer(":0")
254	if err != nil {
255		t.Fatalf("unable to run test server: %v", err)
256	}
257	defer s.Shutdown()
258
259	m := new(Msg)
260	m.SetQuestion("miek.nl.", TypeDNSKEY)
261
262	m.SetEdns0(2048, true)
263
264	c := new(Client)
265	r, _, err := c.Exchange(m, addrstr)
266	if err != nil {
267		t.Fatalf("failed to exchange: %v", err)
268	}
269
270	if r != nil && r.Rcode != RcodeSuccess {
271		t.Errorf("failed to get a valid answer\n%v", r)
272	}
273}
274
275// Validates the transmission and parsing of local EDNS0 options.
276func TestClientEDNS0Local(t *testing.T) {
277	optStr1 := "1979:0x0707"
278	optStr2 := strconv.Itoa(EDNS0LOCALSTART) + ":0x0601"
279
280	handler := func(w ResponseWriter, req *Msg) {
281		m := new(Msg)
282		m.SetReply(req)
283
284		m.Extra = make([]RR, 1, 2)
285		m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello local edns"}}
286
287		// If the local options are what we expect, then reflect them back.
288		ec1 := req.Extra[0].(*OPT).Option[0].(*EDNS0_LOCAL).String()
289		ec2 := req.Extra[0].(*OPT).Option[1].(*EDNS0_LOCAL).String()
290		if ec1 == optStr1 && ec2 == optStr2 {
291			m.Extra = append(m.Extra, req.Extra[0])
292		}
293
294		w.WriteMsg(m)
295	}
296
297	HandleFunc("miek.nl.", handler)
298	defer HandleRemove("miek.nl.")
299
300	s, addrstr, _, err := RunLocalUDPServer(":0")
301	if err != nil {
302		t.Fatalf("unable to run test server: %s", err)
303	}
304	defer s.Shutdown()
305
306	m := new(Msg)
307	m.SetQuestion("miek.nl.", TypeTXT)
308
309	// Add two local edns options to the query.
310	ec1 := &EDNS0_LOCAL{Code: 1979, Data: []byte{7, 7}}
311	ec2 := &EDNS0_LOCAL{Code: EDNS0LOCALSTART, Data: []byte{6, 1}}
312	o := &OPT{Hdr: RR_Header{Name: ".", Rrtype: TypeOPT}, Option: []EDNS0{ec1, ec2}}
313	m.Extra = append(m.Extra, o)
314
315	c := new(Client)
316	r, _, err := c.Exchange(m, addrstr)
317	if err != nil {
318		t.Fatalf("failed to exchange: %s", err)
319	}
320
321	if r == nil {
322		t.Fatal("response is nil")
323	}
324	if r.Rcode != RcodeSuccess {
325		t.Fatal("failed to get a valid answer")
326	}
327
328	txt := r.Extra[0].(*TXT).Txt[0]
329	if txt != "Hello local edns" {
330		t.Error("Unexpected result for miek.nl", txt, "!= Hello local edns")
331	}
332
333	// Validate the local options in the reply.
334	got := r.Extra[1].(*OPT).Option[0].(*EDNS0_LOCAL).String()
335	if got != optStr1 {
336		t.Errorf("failed to get local edns0 answer; got %s, expected %s", got, optStr1)
337	}
338
339	got = r.Extra[1].(*OPT).Option[1].(*EDNS0_LOCAL).String()
340	if got != optStr2 {
341		t.Errorf("failed to get local edns0 answer; got %s, expected %s", got, optStr2)
342	}
343}
344
345func TestClientConn(t *testing.T) {
346	HandleFunc("miek.nl.", HelloServer)
347	defer HandleRemove("miek.nl.")
348
349	// This uses TCP just to make it slightly different than TestClientSync
350	s, addrstr, _, err := RunLocalTCPServer(":0")
351	if err != nil {
352		t.Fatalf("unable to run test server: %v", err)
353	}
354	defer s.Shutdown()
355
356	m := new(Msg)
357	m.SetQuestion("miek.nl.", TypeSOA)
358
359	cn, err := Dial("tcp", addrstr)
360	if err != nil {
361		t.Errorf("failed to dial %s: %v", addrstr, err)
362	}
363
364	err = cn.WriteMsg(m)
365	if err != nil {
366		t.Errorf("failed to exchange: %v", err)
367	}
368	r, err := cn.ReadMsg()
369	if err != nil {
370		t.Errorf("failed to get a valid answer: %v", err)
371	}
372	if r == nil || r.Rcode != RcodeSuccess {
373		t.Errorf("failed to get an valid answer\n%v", r)
374	}
375
376	err = cn.WriteMsg(m)
377	if err != nil {
378		t.Errorf("failed to exchange: %v", err)
379	}
380	h := new(Header)
381	buf, err := cn.ReadMsgHeader(h)
382	if buf == nil {
383		t.Errorf("failed to get an valid answer\n%v", r)
384	}
385	if err != nil {
386		t.Errorf("failed to get a valid answer: %v", err)
387	}
388	if int(h.Bits&0xF) != RcodeSuccess {
389		t.Errorf("failed to get an valid answer in ReadMsgHeader\n%v", r)
390	}
391	if h.Ancount != 0 || h.Qdcount != 1 || h.Nscount != 0 || h.Arcount != 1 {
392		t.Errorf("expected to have question and additional in response; got something else: %+v", h)
393	}
394	if err = r.Unpack(buf); err != nil {
395		t.Errorf("unable to unpack message fully: %v", err)
396	}
397}
398
399func TestClientConnWriteSinglePacket(t *testing.T) {
400	c := &countingConn{}
401	conn := Conn{
402		Conn: c,
403	}
404	m := new(Msg)
405	m.SetQuestion("miek.nl.", TypeTXT)
406	err := conn.WriteMsg(m)
407
408	if err != nil {
409		t.Fatalf("failed to write: %v", err)
410	}
411
412	if c.writes != 1 {
413		t.Fatalf("incorrect number of Write calls")
414	}
415}
416
417func TestTruncatedMsg(t *testing.T) {
418	m := new(Msg)
419	m.SetQuestion("miek.nl.", TypeSRV)
420	cnt := 10
421	for i := 0; i < cnt; i++ {
422		r := &SRV{
423			Hdr:    RR_Header{Name: m.Question[0].Name, Rrtype: TypeSRV, Class: ClassINET, Ttl: 0},
424			Port:   uint16(i + 8000),
425			Target: "target.miek.nl.",
426		}
427		m.Answer = append(m.Answer, r)
428
429		re := &A{
430			Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeA, Class: ClassINET, Ttl: 0},
431			A:   net.ParseIP(fmt.Sprintf("127.0.0.%d", i)).To4(),
432		}
433		m.Extra = append(m.Extra, re)
434	}
435	buf, err := m.Pack()
436	if err != nil {
437		t.Errorf("failed to pack: %v", err)
438	}
439
440	r := new(Msg)
441	if err = r.Unpack(buf); err != nil {
442		t.Errorf("unable to unpack message: %v", err)
443	}
444	if len(r.Answer) != cnt {
445		t.Errorf("answer count after regular unpack doesn't match: %d", len(r.Answer))
446	}
447	if len(r.Extra) != cnt {
448		t.Errorf("extra count after regular unpack doesn't match: %d", len(r.Extra))
449	}
450
451	m.Truncated = true
452	buf, err = m.Pack()
453	if err != nil {
454		t.Errorf("failed to pack truncated message: %v", err)
455	}
456
457	r = new(Msg)
458	if err = r.Unpack(buf); err != nil {
459		t.Errorf("failed to unpack truncated message: %v", err)
460	}
461	if !r.Truncated {
462		t.Errorf("truncated message wasn't unpacked as truncated")
463	}
464	if len(r.Answer) != cnt {
465		t.Errorf("answer count after truncated unpack doesn't match: %d", len(r.Answer))
466	}
467	if len(r.Extra) != cnt {
468		t.Errorf("extra count after truncated unpack doesn't match: %d", len(r.Extra))
469	}
470
471	// Now we want to remove almost all of the extra records
472	// We're going to loop over the extra to get the count of the size of all
473	// of them
474	off := 0
475	buf1 := make([]byte, m.Len())
476	for i := 0; i < len(m.Extra); i++ {
477		off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
478		if err != nil {
479			t.Errorf("failed to pack extra: %v", err)
480		}
481	}
482
483	// Remove all of the extra bytes but 10 bytes from the end of buf
484	off -= 10
485	buf1 = buf[:len(buf)-off]
486
487	r = new(Msg)
488	if err = r.Unpack(buf1); err == nil {
489		t.Error("cutoff message should have failed to unpack")
490	}
491	// r's header might be still usable.
492	if !r.Truncated {
493		t.Error("truncated cutoff message wasn't unpacked as truncated")
494	}
495	if len(r.Answer) != cnt {
496		t.Errorf("answer count after cutoff unpack doesn't match: %d", len(r.Answer))
497	}
498	if len(r.Extra) != 0 {
499		t.Errorf("extra count after cutoff unpack is not zero: %d", len(r.Extra))
500	}
501
502	// Now we want to remove almost all of the answer records too
503	buf1 = make([]byte, m.Len())
504	as := 0
505	for i := 0; i < len(m.Extra); i++ {
506		off1 := off
507		off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
508		as = off - off1
509		if err != nil {
510			t.Errorf("failed to pack extra: %v", err)
511		}
512	}
513
514	// Keep exactly one answer left
515	// This should still cause Answer to be nil
516	off -= as
517	buf1 = buf[:len(buf)-off]
518
519	r = new(Msg)
520	if err = r.Unpack(buf1); err == nil {
521		t.Error("cutoff message should have failed to unpack")
522	}
523	if !r.Truncated {
524		t.Error("truncated cutoff message wasn't unpacked as truncated")
525	}
526	if len(r.Answer) != 0 {
527		t.Errorf("answer count after second cutoff unpack is not zero: %d", len(r.Answer))
528	}
529
530	// Now leave only 1 byte of the question
531	// Since the header is always 12 bytes, we just need to keep 13
532	buf1 = buf[:13]
533
534	r = new(Msg)
535	err = r.Unpack(buf1)
536	if err == nil {
537		t.Errorf("error should be nil after question cutoff unpack: %v", err)
538	}
539
540	// Finally, if we only have the header, we don't return an error.
541	buf1 = buf[:12]
542
543	r = new(Msg)
544	if err = r.Unpack(buf1); err != nil {
545		t.Errorf("from header-only unpack should not return an error: %v", err)
546	}
547}
548
549func TestTimeout(t *testing.T) {
550	// Set up a dummy UDP server that won't respond
551	addr, err := net.ResolveUDPAddr("udp", ":0")
552	if err != nil {
553		t.Fatalf("unable to resolve local udp address: %v", err)
554	}
555	conn, err := net.ListenUDP("udp", addr)
556	if err != nil {
557		t.Fatalf("unable to run test server: %v", err)
558	}
559	defer conn.Close()
560	addrstr := conn.LocalAddr().String()
561
562	// Message to send
563	m := new(Msg)
564	m.SetQuestion("miek.nl.", TypeTXT)
565
566	runTest := func(name string, exchange func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error)) {
567		t.Run(name, func(t *testing.T) {
568			start := time.Now()
569
570			timeout := time.Millisecond
571			allowable := timeout + 10*time.Millisecond
572
573			_, _, err := exchange(m, addrstr, timeout)
574			if err == nil {
575				t.Errorf("no timeout using Client.%s", name)
576			}
577
578			length := time.Since(start)
579			if length > allowable {
580				t.Errorf("exchange took longer %v than specified Timeout %v", length, allowable)
581			}
582		})
583	}
584	runTest("Exchange", func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error) {
585		c := &Client{Timeout: timeout}
586		return c.Exchange(m, addr)
587	})
588	runTest("ExchangeContext", func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error) {
589		ctx, cancel := context.WithTimeout(context.Background(), timeout)
590		defer cancel()
591
592		return new(Client).ExchangeContext(ctx, m, addrstr)
593	})
594}
595
596// Check that responses from deduplicated requests aren't shared between callers
597func TestConcurrentExchanges(t *testing.T) {
598	cases := make([]*Msg, 2)
599	cases[0] = new(Msg)
600	cases[1] = new(Msg)
601	cases[1].Truncated = true
602
603	for _, m := range cases {
604		mm := m // redeclare m so as not to trip the race detector
605		handler := func(w ResponseWriter, req *Msg) {
606			r := mm.Copy()
607			r.SetReply(req)
608
609			w.WriteMsg(r)
610		}
611
612		HandleFunc("miek.nl.", handler)
613		defer HandleRemove("miek.nl.")
614
615		s, addrstr, _, err := RunLocalUDPServer(":0")
616		if err != nil {
617			t.Fatalf("unable to run test server: %s", err)
618		}
619		defer s.Shutdown()
620
621		m := new(Msg)
622		m.SetQuestion("miek.nl.", TypeSRV)
623
624		c := &Client{
625			SingleInflight: true,
626		}
627		// Force this client to always return the same request,
628		// even though we're querying sequentially. Running the
629		// Exchange calls below concurrently can fail due to
630		// goroutine scheduling, but this simulates the same
631		// outcome.
632		c.group.dontDeleteForTesting = true
633
634		r := make([]*Msg, 2)
635		for i := range r {
636			r[i], _, _ = c.Exchange(m.Copy(), addrstr)
637			if r[i] == nil {
638				t.Errorf("response %d is nil", i)
639			}
640		}
641
642		if r[0] == r[1] {
643			t.Errorf("got same response, expected non-shared responses")
644		}
645	}
646}
647
648func TestExchangeWithConn(t *testing.T) {
649	HandleFunc("miek.nl.", HelloServer)
650	defer HandleRemove("miek.nl.")
651
652	s, addrstr, _, err := RunLocalUDPServer(":0")
653	if err != nil {
654		t.Fatalf("unable to run test server: %v", err)
655	}
656	defer s.Shutdown()
657
658	m := new(Msg)
659	m.SetQuestion("miek.nl.", TypeSOA)
660
661	c := new(Client)
662	conn, err := c.Dial(addrstr)
663	if err != nil {
664		t.Fatalf("failed to dial: %v", err)
665	}
666
667	r, _, err := c.ExchangeWithConn(m, conn)
668	if err != nil {
669		t.Fatalf("failed to exchange: %v", err)
670	}
671	if r == nil {
672		t.Fatal("response is nil")
673	}
674	if r.Rcode != RcodeSuccess {
675		t.Errorf("failed to get an valid answer\n%v", r)
676	}
677}
678