1package dns
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"net"
8	"strconv"
9	"strings"
10	"sync"
11	"testing"
12	"time"
13)
14
15func TestDialUDP(t *testing.T) {
16	HandleFunc("miek.nl.", HelloServer)
17	defer HandleRemove("miek.nl.")
18
19	s, addrstr, err := RunLocalUDPServer(":0")
20	if err != nil {
21		t.Fatalf("unable to run test server: %v", err)
22	}
23	defer s.Shutdown()
24
25	m := new(Msg)
26	m.SetQuestion("miek.nl.", TypeSOA)
27
28	c := new(Client)
29	conn, err := c.Dial(addrstr)
30	if err != nil {
31		t.Fatalf("failed to dial: %v", err)
32	}
33	if conn == nil {
34		t.Fatalf("conn is nil")
35	}
36}
37
38func TestClientSync(t *testing.T) {
39	HandleFunc("miek.nl.", HelloServer)
40	defer HandleRemove("miek.nl.")
41
42	s, addrstr, err := RunLocalUDPServer(":0")
43	if err != nil {
44		t.Fatalf("unable to run test server: %v", err)
45	}
46	defer s.Shutdown()
47
48	m := new(Msg)
49	m.SetQuestion("miek.nl.", TypeSOA)
50
51	c := new(Client)
52	r, _, err := c.Exchange(m, addrstr)
53	if err != nil {
54		t.Fatalf("failed to exchange: %v", err)
55	}
56	if r == nil {
57		t.Fatal("response is nil")
58	}
59	if r.Rcode != RcodeSuccess {
60		t.Errorf("failed to get an valid answer\n%v", r)
61	}
62	// And now with plain Exchange().
63	r, err = Exchange(m, addrstr)
64	if err != nil {
65		t.Errorf("failed to exchange: %v", err)
66	}
67	if r == nil || r.Rcode != RcodeSuccess {
68		t.Errorf("failed to get an valid answer\n%v", r)
69	}
70}
71
72func TestClientLocalAddress(t *testing.T) {
73	HandleFunc("miek.nl.", HelloServerEchoAddrPort)
74	defer HandleRemove("miek.nl.")
75
76	s, addrstr, err := RunLocalUDPServer(":0")
77	if err != nil {
78		t.Fatalf("unable to run test server: %v", err)
79	}
80	defer s.Shutdown()
81
82	m := new(Msg)
83	m.SetQuestion("miek.nl.", TypeSOA)
84
85	c := new(Client)
86	laddr := net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 12345, Zone: ""}
87	c.Dialer = &net.Dialer{LocalAddr: &laddr}
88	r, _, err := c.Exchange(m, addrstr)
89	if err != nil {
90		t.Fatalf("failed to exchange: %v", err)
91	}
92	if r != nil && r.Rcode != RcodeSuccess {
93		t.Errorf("failed to get an valid answer\n%v", r)
94	}
95	if len(r.Extra) != 1 {
96		t.Errorf("failed to get additional answers\n%v", r)
97	}
98	txt := r.Extra[0].(*TXT)
99	if txt == nil {
100		t.Errorf("invalid TXT response\n%v", txt)
101	}
102	if len(txt.Txt) != 1 || !strings.Contains(txt.Txt[0], ":12345") {
103		t.Errorf("invalid TXT response\n%v", txt.Txt)
104	}
105}
106
107func TestClientTLSSyncV4(t *testing.T) {
108	HandleFunc("miek.nl.", HelloServer)
109	defer HandleRemove("miek.nl.")
110
111	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
112	if err != nil {
113		t.Fatalf("unable to build certificate: %v", err)
114	}
115
116	config := tls.Config{
117		Certificates: []tls.Certificate{cert},
118	}
119
120	s, addrstr, err := RunLocalTLSServer(":0", &config)
121	if err != nil {
122		t.Fatalf("unable to run test server: %v", err)
123	}
124	defer s.Shutdown()
125
126	m := new(Msg)
127	m.SetQuestion("miek.nl.", TypeSOA)
128
129	c := new(Client)
130
131	// test tcp-tls
132	c.Net = "tcp-tls"
133	c.TLSConfig = &tls.Config{
134		InsecureSkipVerify: true,
135	}
136
137	r, _, err := c.Exchange(m, addrstr)
138	if err != nil {
139		t.Fatalf("failed to exchange: %v", err)
140	}
141	if r == nil {
142		t.Fatal("response is nil")
143	}
144	if r.Rcode != RcodeSuccess {
145		t.Errorf("failed to get an valid answer\n%v", r)
146	}
147
148	// test tcp4-tls
149	c.Net = "tcp4-tls"
150	c.TLSConfig = &tls.Config{
151		InsecureSkipVerify: true,
152	}
153
154	r, _, err = c.Exchange(m, addrstr)
155	if err != nil {
156		t.Fatalf("failed to exchange: %v", err)
157	}
158	if r == nil {
159		t.Fatal("response is nil")
160	}
161	if r.Rcode != RcodeSuccess {
162		t.Errorf("failed to get an valid answer\n%v", r)
163	}
164}
165
166func TestClientSyncBadID(t *testing.T) {
167	HandleFunc("miek.nl.", HelloServerBadID)
168	defer HandleRemove("miek.nl.")
169
170	s, addrstr, err := RunLocalUDPServer(":0")
171	if err != nil {
172		t.Fatalf("unable to run test server: %v", err)
173	}
174	defer s.Shutdown()
175
176	m := new(Msg)
177	m.SetQuestion("miek.nl.", TypeSOA)
178
179	c := new(Client)
180	if _, _, err := c.Exchange(m, addrstr); err != ErrId {
181		t.Errorf("did not find a bad Id")
182	}
183	// And now with plain Exchange().
184	if _, err := Exchange(m, addrstr); err != ErrId {
185		t.Errorf("did not find a bad Id")
186	}
187}
188
189func TestClientEDNS0(t *testing.T) {
190	HandleFunc("miek.nl.", HelloServer)
191	defer HandleRemove("miek.nl.")
192
193	s, addrstr, err := RunLocalUDPServer(":0")
194	if err != nil {
195		t.Fatalf("unable to run test server: %v", err)
196	}
197	defer s.Shutdown()
198
199	m := new(Msg)
200	m.SetQuestion("miek.nl.", TypeDNSKEY)
201
202	m.SetEdns0(2048, true)
203
204	c := new(Client)
205	r, _, err := c.Exchange(m, addrstr)
206	if err != nil {
207		t.Fatalf("failed to exchange: %v", err)
208	}
209
210	if r != nil && r.Rcode != RcodeSuccess {
211		t.Errorf("failed to get a valid answer\n%v", r)
212	}
213}
214
215// Validates the transmission and parsing of local EDNS0 options.
216func TestClientEDNS0Local(t *testing.T) {
217	optStr1 := "1979:0x0707"
218	optStr2 := strconv.Itoa(EDNS0LOCALSTART) + ":0x0601"
219
220	handler := func(w ResponseWriter, req *Msg) {
221		m := new(Msg)
222		m.SetReply(req)
223
224		m.Extra = make([]RR, 1, 2)
225		m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello local edns"}}
226
227		// If the local options are what we expect, then reflect them back.
228		ec1 := req.Extra[0].(*OPT).Option[0].(*EDNS0_LOCAL).String()
229		ec2 := req.Extra[0].(*OPT).Option[1].(*EDNS0_LOCAL).String()
230		if ec1 == optStr1 && ec2 == optStr2 {
231			m.Extra = append(m.Extra, req.Extra[0])
232		}
233
234		w.WriteMsg(m)
235	}
236
237	HandleFunc("miek.nl.", handler)
238	defer HandleRemove("miek.nl.")
239
240	s, addrstr, err := RunLocalUDPServer(":0")
241	if err != nil {
242		t.Fatalf("unable to run test server: %s", err)
243	}
244	defer s.Shutdown()
245
246	m := new(Msg)
247	m.SetQuestion("miek.nl.", TypeTXT)
248
249	// Add two local edns options to the query.
250	ec1 := &EDNS0_LOCAL{Code: 1979, Data: []byte{7, 7}}
251	ec2 := &EDNS0_LOCAL{Code: EDNS0LOCALSTART, Data: []byte{6, 1}}
252	o := &OPT{Hdr: RR_Header{Name: ".", Rrtype: TypeOPT}, Option: []EDNS0{ec1, ec2}}
253	m.Extra = append(m.Extra, o)
254
255	c := new(Client)
256	r, _, err := c.Exchange(m, addrstr)
257	if err != nil {
258		t.Fatalf("failed to exchange: %s", err)
259	}
260
261	if r == nil {
262		t.Fatal("response is nil")
263	}
264	if r.Rcode != RcodeSuccess {
265		t.Fatal("failed to get a valid answer")
266	}
267
268	txt := r.Extra[0].(*TXT).Txt[0]
269	if txt != "Hello local edns" {
270		t.Error("Unexpected result for miek.nl", txt, "!= Hello local edns")
271	}
272
273	// Validate the local options in the reply.
274	got := r.Extra[1].(*OPT).Option[0].(*EDNS0_LOCAL).String()
275	if got != optStr1 {
276		t.Errorf("failed to get local edns0 answer; got %s, expected %s", got, optStr1)
277	}
278
279	got = r.Extra[1].(*OPT).Option[1].(*EDNS0_LOCAL).String()
280	if got != optStr2 {
281		t.Errorf("failed to get local edns0 answer; got %s, expected %s", got, optStr2)
282	}
283}
284
285func TestClientConn(t *testing.T) {
286	HandleFunc("miek.nl.", HelloServer)
287	defer HandleRemove("miek.nl.")
288
289	// This uses TCP just to make it slightly different than TestClientSync
290	s, addrstr, err := RunLocalTCPServer(":0")
291	if err != nil {
292		t.Fatalf("unable to run test server: %v", err)
293	}
294	defer s.Shutdown()
295
296	m := new(Msg)
297	m.SetQuestion("miek.nl.", TypeSOA)
298
299	cn, err := Dial("tcp", addrstr)
300	if err != nil {
301		t.Errorf("failed to dial %s: %v", addrstr, err)
302	}
303
304	err = cn.WriteMsg(m)
305	if err != nil {
306		t.Errorf("failed to exchange: %v", err)
307	}
308	r, err := cn.ReadMsg()
309	if err != nil {
310		t.Errorf("failed to get a valid answer: %v", err)
311	}
312	if r == nil || r.Rcode != RcodeSuccess {
313		t.Errorf("failed to get an valid answer\n%v", r)
314	}
315
316	err = cn.WriteMsg(m)
317	if err != nil {
318		t.Errorf("failed to exchange: %v", err)
319	}
320	h := new(Header)
321	buf, err := cn.ReadMsgHeader(h)
322	if buf == nil {
323		t.Errorf("failed to get an valid answer\n%v", r)
324	}
325	if err != nil {
326		t.Errorf("failed to get a valid answer: %v", err)
327	}
328	if int(h.Bits&0xF) != RcodeSuccess {
329		t.Errorf("failed to get an valid answer in ReadMsgHeader\n%v", r)
330	}
331	if h.Ancount != 0 || h.Qdcount != 1 || h.Nscount != 0 || h.Arcount != 1 {
332		t.Errorf("expected to have question and additional in response; got something else: %+v", h)
333	}
334	if err = r.Unpack(buf); err != nil {
335		t.Errorf("unable to unpack message fully: %v", err)
336	}
337}
338
// TestTruncatedMsg exercises Msg.Unpack on messages with the TC bit set and on
// wire buffers that have been cut short. The test expects that:
//   - fully present sections are kept,
//   - a partially present section is dropped,
//   - a cut inside the question returns an error other than ErrTruncated,
//   - a header-only buffer unpacks without error.
func TestTruncatedMsg(t *testing.T) {
	// Build a message with cnt SRV answers and cnt A records in the
	// additional section.
	m := new(Msg)
	m.SetQuestion("miek.nl.", TypeSRV)
	cnt := 10
	for i := 0; i < cnt; i++ {
		r := &SRV{
			Hdr:    RR_Header{Name: m.Question[0].Name, Rrtype: TypeSRV, Class: ClassINET, Ttl: 0},
			Port:   uint16(i + 8000),
			Target: "target.miek.nl.",
		}
		m.Answer = append(m.Answer, r)

		re := &A{
			Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeA, Class: ClassINET, Ttl: 0},
			A:   net.ParseIP(fmt.Sprintf("127.0.0.%d", i)).To4(),
		}
		m.Extra = append(m.Extra, re)
	}
	// Baseline: an untruncated round trip keeps every record.
	buf, err := m.Pack()
	if err != nil {
		t.Errorf("failed to pack: %v", err)
	}

	r := new(Msg)
	if err = r.Unpack(buf); err != nil {
		t.Errorf("unable to unpack message: %v", err)
	}
	if len(r.Answer) != cnt {
		t.Errorf("answer count after regular unpack doesn't match: %d", len(r.Answer))
	}
	if len(r.Extra) != cnt {
		t.Errorf("extra count after regular unpack doesn't match: %d", len(r.Extra))
	}

	// Setting only the TC bit (no bytes removed) must keep every record.
	// ErrTruncated is tolerated as an Unpack result.
	m.Truncated = true
	buf, err = m.Pack()
	if err != nil {
		t.Errorf("failed to pack truncated: %v", err)
	}

	r = new(Msg)
	if err = r.Unpack(buf); err != nil && err != ErrTruncated {
		t.Errorf("unable to unpack truncated message: %v", err)
	}
	if !r.Truncated {
		t.Errorf("truncated message wasn't unpacked as truncated")
	}
	if len(r.Answer) != cnt {
		t.Errorf("answer count after truncated unpack doesn't match: %d", len(r.Answer))
	}
	if len(r.Extra) != cnt {
		t.Errorf("extra count after truncated unpack doesn't match: %d", len(r.Extra))
	}

	// Now we want to remove almost all of the extra records
	// We're going to loop over the extra to get the count of the size of all
	// of them
	// (PackRR returns the offset after the record it packed, so after this
	// loop off is the total packed size of m.Extra, packed one by one into
	// the scratch buffer buf1.)
	off := 0
	buf1 := make([]byte, m.Len())
	for i := 0; i < len(m.Extra); i++ {
		off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
		if err != nil {
			t.Errorf("failed to pack extra: %v", err)
		}
	}

	// Remove all of the extra bytes but 10 bytes from the end of buf
	// (buf is the TC-bit message packed above. Presumably its extra records
	// have the same sizes as here, since m.Compress is never set in this
	// test; verify if that changes.)
	off -= 10
	buf1 = buf[:len(buf)-off]

	r = new(Msg)
	if err = r.Unpack(buf1); err != nil && err != ErrTruncated {
		t.Errorf("unable to unpack cutoff message: %v", err)
	}
	if !r.Truncated {
		t.Error("truncated cutoff message wasn't unpacked as truncated")
	}
	if len(r.Answer) != cnt {
		t.Errorf("answer count after cutoff unpack doesn't match: %d", len(r.Answer))
	}
	if len(r.Extra) != 0 {
		t.Errorf("extra count after cutoff unpack is not zero: %d", len(r.Extra))
	}

	// Now we want to remove almost all of the answer records too
	// NOTE(review): despite the comment above, this loop repacks m.Extra (not
	// m.Answer), continuing from the current off (extra size - 10). After it,
	// off is about twice the extra-section size minus 10, and `as` holds only
	// the packed size of the last extra record. Confirm this is intended.
	buf1 = make([]byte, m.Len())
	as := 0
	for i := 0; i < len(m.Extra); i++ {
		off1 := off
		off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
		as = off - off1
		if err != nil {
			t.Errorf("failed to pack extra: %v", err)
		}
	}

	// Keep exactly one answer left
	// This should still cause Answer to be nil
	// NOTE(review): this cut removes the whole additional section plus
	// presumably part of the answer section (not necessarily down to exactly
	// one answer). The assertion below expects no answers to survive.
	off -= as
	buf1 = buf[:len(buf)-off]

	r = new(Msg)
	if err = r.Unpack(buf1); err != nil && err != ErrTruncated {
		t.Errorf("unable to unpack cutoff message: %v", err)
	}
	if !r.Truncated {
		t.Error("truncated cutoff message wasn't unpacked as truncated")
	}
	if len(r.Answer) != 0 {
		t.Errorf("answer count after second cutoff unpack is not zero: %d", len(r.Answer))
	}

	// Now leave only 1 byte of the question
	// Since the header is always 12 bytes, we just need to keep 13
	// A cut inside the question must produce a real error, not ErrTruncated.
	buf1 = buf[:13]

	r = new(Msg)
	err = r.Unpack(buf1)
	if err == nil || err == ErrTruncated {
		t.Errorf("error should not be ErrTruncated from question cutoff unpack: %v", err)
	}

	// Finally, if we only have the header, we don't return an error.
	buf1 = buf[:12]

	r = new(Msg)
	if err = r.Unpack(buf1); err != nil {
		t.Errorf("from header-only unpack should not return an error: %v", err)
	}
}
469
470func TestTimeout(t *testing.T) {
471	// Set up a dummy UDP server that won't respond
472	addr, err := net.ResolveUDPAddr("udp", ":0")
473	if err != nil {
474		t.Fatalf("unable to resolve local udp address: %v", err)
475	}
476	conn, err := net.ListenUDP("udp", addr)
477	if err != nil {
478		t.Fatalf("unable to run test server: %v", err)
479	}
480	defer conn.Close()
481	addrstr := conn.LocalAddr().String()
482
483	// Message to send
484	m := new(Msg)
485	m.SetQuestion("miek.nl.", TypeTXT)
486
487	// Use a channel + timeout to ensure we don't get stuck if the
488	// Client Timeout is not working properly
489	done := make(chan struct{}, 2)
490
491	timeout := time.Millisecond
492	allowable := timeout + 10*time.Millisecond
493	abortAfter := timeout + 100*time.Millisecond
494
495	start := time.Now()
496
497	go func() {
498		c := &Client{Timeout: timeout}
499		_, _, err := c.Exchange(m, addrstr)
500		if err == nil {
501			t.Error("no timeout using Client.Exchange")
502		}
503		done <- struct{}{}
504	}()
505
506	go func() {
507		ctx, cancel := context.WithTimeout(context.Background(), timeout)
508		defer cancel()
509		c := &Client{}
510		_, _, err := c.ExchangeContext(ctx, m, addrstr)
511		if err == nil {
512			t.Error("no timeout using Client.ExchangeContext")
513		}
514		done <- struct{}{}
515	}()
516
517	// Wait for both the Exchange and ExchangeContext tests to be done.
518	for i := 0; i < 2; i++ {
519		select {
520		case <-done:
521		case <-time.After(abortAfter):
522		}
523	}
524
525	length := time.Since(start)
526
527	if length > allowable {
528		t.Errorf("exchange took longer %v than specified Timeout %v", length, allowable)
529	}
530}
531
532// Check that responses from deduplicated requests aren't shared between callers
533func TestConcurrentExchanges(t *testing.T) {
534	cases := make([]*Msg, 2)
535	cases[0] = new(Msg)
536	cases[1] = new(Msg)
537	cases[1].Truncated = true
538	for _, m := range cases {
539		block := make(chan struct{})
540		waiting := make(chan struct{})
541
542		mm := m // redeclare m so as not to trip the race detector
543		handler := func(w ResponseWriter, req *Msg) {
544			r := mm.Copy()
545			r.SetReply(req)
546
547			waiting <- struct{}{}
548			<-block
549			w.WriteMsg(r)
550		}
551
552		HandleFunc("miek.nl.", handler)
553		defer HandleRemove("miek.nl.")
554
555		s, addrstr, err := RunLocalUDPServer(":0")
556		if err != nil {
557			t.Fatalf("unable to run test server: %s", err)
558		}
559		defer s.Shutdown()
560
561		m := new(Msg)
562		m.SetQuestion("miek.nl.", TypeSRV)
563		c := &Client{
564			SingleInflight: true,
565		}
566		r := make([]*Msg, 2)
567
568		var wg sync.WaitGroup
569		wg.Add(len(r))
570		for i := 0; i < len(r); i++ {
571			go func(i int) {
572				defer wg.Done()
573				r[i], _, _ = c.Exchange(m.Copy(), addrstr)
574				if r[i] == nil {
575					t.Errorf("response %d is nil", i)
576				}
577			}(i)
578		}
579		select {
580		case <-waiting:
581		case <-time.After(time.Second):
582			t.FailNow()
583		}
584		close(block)
585		wg.Wait()
586
587		if r[0] == r[1] {
588			t.Errorf("got same response, expected non-shared responses")
589		}
590	}
591}
592
593func TestDoHExchange(t *testing.T) {
594	const addrstr = "https://dns.cloudflare.com/dns-query"
595
596	m := new(Msg)
597	m.SetQuestion("miek.nl.", TypeSOA)
598
599	cl := &Client{Net: "https"}
600
601	r, _, err := cl.Exchange(m, addrstr)
602	if err != nil {
603		t.Fatalf("failed to exchange: %v", err)
604	}
605
606	if r == nil || r.Rcode != RcodeSuccess {
607		t.Errorf("failed to get an valid answer\n%v", r)
608	}
609
610	// TODO: proper tests for this
611}
612