1package dns
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"net"
8	"strconv"
9	"strings"
10	"testing"
11	"time"
12)
13
14func TestDialUDP(t *testing.T) {
15	HandleFunc("miek.nl.", HelloServer)
16	defer HandleRemove("miek.nl.")
17
18	s, addrstr, err := RunLocalUDPServer(":0")
19	if err != nil {
20		t.Fatalf("unable to run test server: %v", err)
21	}
22	defer s.Shutdown()
23
24	m := new(Msg)
25	m.SetQuestion("miek.nl.", TypeSOA)
26
27	c := new(Client)
28	conn, err := c.Dial(addrstr)
29	if err != nil {
30		t.Fatalf("failed to dial: %v", err)
31	}
32	if conn == nil {
33		t.Fatalf("conn is nil")
34	}
35}
36
37func TestClientSync(t *testing.T) {
38	HandleFunc("miek.nl.", HelloServer)
39	defer HandleRemove("miek.nl.")
40
41	s, addrstr, err := RunLocalUDPServer(":0")
42	if err != nil {
43		t.Fatalf("unable to run test server: %v", err)
44	}
45	defer s.Shutdown()
46
47	m := new(Msg)
48	m.SetQuestion("miek.nl.", TypeSOA)
49
50	c := new(Client)
51	r, _, err := c.Exchange(m, addrstr)
52	if err != nil {
53		t.Fatalf("failed to exchange: %v", err)
54	}
55	if r == nil {
56		t.Fatal("response is nil")
57	}
58	if r.Rcode != RcodeSuccess {
59		t.Errorf("failed to get an valid answer\n%v", r)
60	}
61	// And now with plain Exchange().
62	r, err = Exchange(m, addrstr)
63	if err != nil {
64		t.Errorf("failed to exchange: %v", err)
65	}
66	if r == nil || r.Rcode != RcodeSuccess {
67		t.Errorf("failed to get an valid answer\n%v", r)
68	}
69}
70
71func TestClientLocalAddress(t *testing.T) {
72	HandleFunc("miek.nl.", HelloServerEchoAddrPort)
73	defer HandleRemove("miek.nl.")
74
75	s, addrstr, err := RunLocalUDPServer(":0")
76	if err != nil {
77		t.Fatalf("unable to run test server: %v", err)
78	}
79	defer s.Shutdown()
80
81	m := new(Msg)
82	m.SetQuestion("miek.nl.", TypeSOA)
83
84	c := new(Client)
85	laddr := net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 12345, Zone: ""}
86	c.Dialer = &net.Dialer{LocalAddr: &laddr}
87	r, _, err := c.Exchange(m, addrstr)
88	if err != nil {
89		t.Fatalf("failed to exchange: %v", err)
90	}
91	if r != nil && r.Rcode != RcodeSuccess {
92		t.Errorf("failed to get an valid answer\n%v", r)
93	}
94	if len(r.Extra) != 1 {
95		t.Fatalf("failed to get additional answers\n%v", r)
96	}
97	txt := r.Extra[0].(*TXT)
98	if txt == nil {
99		t.Errorf("invalid TXT response\n%v", txt)
100	}
101	if len(txt.Txt) != 1 || !strings.Contains(txt.Txt[0], ":12345") {
102		t.Errorf("invalid TXT response\n%v", txt.Txt)
103	}
104}
105
106func TestClientTLSSyncV4(t *testing.T) {
107	HandleFunc("miek.nl.", HelloServer)
108	defer HandleRemove("miek.nl.")
109
110	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
111	if err != nil {
112		t.Fatalf("unable to build certificate: %v", err)
113	}
114
115	config := tls.Config{
116		Certificates: []tls.Certificate{cert},
117	}
118
119	s, addrstr, err := RunLocalTLSServer(":0", &config)
120	if err != nil {
121		t.Fatalf("unable to run test server: %v", err)
122	}
123	defer s.Shutdown()
124
125	m := new(Msg)
126	m.SetQuestion("miek.nl.", TypeSOA)
127
128	c := new(Client)
129
130	// test tcp-tls
131	c.Net = "tcp-tls"
132	c.TLSConfig = &tls.Config{
133		InsecureSkipVerify: true,
134	}
135
136	r, _, err := c.Exchange(m, addrstr)
137	if err != nil {
138		t.Fatalf("failed to exchange: %v", err)
139	}
140	if r == nil {
141		t.Fatal("response is nil")
142	}
143	if r.Rcode != RcodeSuccess {
144		t.Errorf("failed to get an valid answer\n%v", r)
145	}
146
147	// test tcp4-tls
148	c.Net = "tcp4-tls"
149	c.TLSConfig = &tls.Config{
150		InsecureSkipVerify: true,
151	}
152
153	r, _, err = c.Exchange(m, addrstr)
154	if err != nil {
155		t.Fatalf("failed to exchange: %v", err)
156	}
157	if r == nil {
158		t.Fatal("response is nil")
159	}
160	if r.Rcode != RcodeSuccess {
161		t.Errorf("failed to get an valid answer\n%v", r)
162	}
163}
164
165func TestClientSyncBadID(t *testing.T) {
166	HandleFunc("miek.nl.", HelloServerBadID)
167	defer HandleRemove("miek.nl.")
168
169	s, addrstr, err := RunLocalUDPServer(":0")
170	if err != nil {
171		t.Fatalf("unable to run test server: %v", err)
172	}
173	defer s.Shutdown()
174
175	m := new(Msg)
176	m.SetQuestion("miek.nl.", TypeSOA)
177
178	c := new(Client)
179	if _, _, err := c.Exchange(m, addrstr); err != ErrId {
180		t.Errorf("did not find a bad Id")
181	}
182	// And now with plain Exchange().
183	if _, err := Exchange(m, addrstr); err != ErrId {
184		t.Errorf("did not find a bad Id")
185	}
186}
187
188func TestClientEDNS0(t *testing.T) {
189	HandleFunc("miek.nl.", HelloServer)
190	defer HandleRemove("miek.nl.")
191
192	s, addrstr, err := RunLocalUDPServer(":0")
193	if err != nil {
194		t.Fatalf("unable to run test server: %v", err)
195	}
196	defer s.Shutdown()
197
198	m := new(Msg)
199	m.SetQuestion("miek.nl.", TypeDNSKEY)
200
201	m.SetEdns0(2048, true)
202
203	c := new(Client)
204	r, _, err := c.Exchange(m, addrstr)
205	if err != nil {
206		t.Fatalf("failed to exchange: %v", err)
207	}
208
209	if r != nil && r.Rcode != RcodeSuccess {
210		t.Errorf("failed to get a valid answer\n%v", r)
211	}
212}
213
214// Validates the transmission and parsing of local EDNS0 options.
215func TestClientEDNS0Local(t *testing.T) {
216	optStr1 := "1979:0x0707"
217	optStr2 := strconv.Itoa(EDNS0LOCALSTART) + ":0x0601"
218
219	handler := func(w ResponseWriter, req *Msg) {
220		m := new(Msg)
221		m.SetReply(req)
222
223		m.Extra = make([]RR, 1, 2)
224		m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello local edns"}}
225
226		// If the local options are what we expect, then reflect them back.
227		ec1 := req.Extra[0].(*OPT).Option[0].(*EDNS0_LOCAL).String()
228		ec2 := req.Extra[0].(*OPT).Option[1].(*EDNS0_LOCAL).String()
229		if ec1 == optStr1 && ec2 == optStr2 {
230			m.Extra = append(m.Extra, req.Extra[0])
231		}
232
233		w.WriteMsg(m)
234	}
235
236	HandleFunc("miek.nl.", handler)
237	defer HandleRemove("miek.nl.")
238
239	s, addrstr, err := RunLocalUDPServer(":0")
240	if err != nil {
241		t.Fatalf("unable to run test server: %s", err)
242	}
243	defer s.Shutdown()
244
245	m := new(Msg)
246	m.SetQuestion("miek.nl.", TypeTXT)
247
248	// Add two local edns options to the query.
249	ec1 := &EDNS0_LOCAL{Code: 1979, Data: []byte{7, 7}}
250	ec2 := &EDNS0_LOCAL{Code: EDNS0LOCALSTART, Data: []byte{6, 1}}
251	o := &OPT{Hdr: RR_Header{Name: ".", Rrtype: TypeOPT}, Option: []EDNS0{ec1, ec2}}
252	m.Extra = append(m.Extra, o)
253
254	c := new(Client)
255	r, _, err := c.Exchange(m, addrstr)
256	if err != nil {
257		t.Fatalf("failed to exchange: %s", err)
258	}
259
260	if r == nil {
261		t.Fatal("response is nil")
262	}
263	if r.Rcode != RcodeSuccess {
264		t.Fatal("failed to get a valid answer")
265	}
266
267	txt := r.Extra[0].(*TXT).Txt[0]
268	if txt != "Hello local edns" {
269		t.Error("Unexpected result for miek.nl", txt, "!= Hello local edns")
270	}
271
272	// Validate the local options in the reply.
273	got := r.Extra[1].(*OPT).Option[0].(*EDNS0_LOCAL).String()
274	if got != optStr1 {
275		t.Errorf("failed to get local edns0 answer; got %s, expected %s", got, optStr1)
276	}
277
278	got = r.Extra[1].(*OPT).Option[1].(*EDNS0_LOCAL).String()
279	if got != optStr2 {
280		t.Errorf("failed to get local edns0 answer; got %s, expected %s", got, optStr2)
281	}
282}
283
284func TestClientConn(t *testing.T) {
285	HandleFunc("miek.nl.", HelloServer)
286	defer HandleRemove("miek.nl.")
287
288	// This uses TCP just to make it slightly different than TestClientSync
289	s, addrstr, err := RunLocalTCPServer(":0")
290	if err != nil {
291		t.Fatalf("unable to run test server: %v", err)
292	}
293	defer s.Shutdown()
294
295	m := new(Msg)
296	m.SetQuestion("miek.nl.", TypeSOA)
297
298	cn, err := Dial("tcp", addrstr)
299	if err != nil {
300		t.Errorf("failed to dial %s: %v", addrstr, err)
301	}
302
303	err = cn.WriteMsg(m)
304	if err != nil {
305		t.Errorf("failed to exchange: %v", err)
306	}
307	r, err := cn.ReadMsg()
308	if err != nil {
309		t.Errorf("failed to get a valid answer: %v", err)
310	}
311	if r == nil || r.Rcode != RcodeSuccess {
312		t.Errorf("failed to get an valid answer\n%v", r)
313	}
314
315	err = cn.WriteMsg(m)
316	if err != nil {
317		t.Errorf("failed to exchange: %v", err)
318	}
319	h := new(Header)
320	buf, err := cn.ReadMsgHeader(h)
321	if buf == nil {
322		t.Errorf("failed to get an valid answer\n%v", r)
323	}
324	if err != nil {
325		t.Errorf("failed to get a valid answer: %v", err)
326	}
327	if int(h.Bits&0xF) != RcodeSuccess {
328		t.Errorf("failed to get an valid answer in ReadMsgHeader\n%v", r)
329	}
330	if h.Ancount != 0 || h.Qdcount != 1 || h.Nscount != 0 || h.Arcount != 1 {
331		t.Errorf("expected to have question and additional in response; got something else: %+v", h)
332	}
333	if err = r.Unpack(buf); err != nil {
334		t.Errorf("unable to unpack message fully: %v", err)
335	}
336}
337
338func TestTruncatedMsg(t *testing.T) {
339	m := new(Msg)
340	m.SetQuestion("miek.nl.", TypeSRV)
341	cnt := 10
342	for i := 0; i < cnt; i++ {
343		r := &SRV{
344			Hdr:    RR_Header{Name: m.Question[0].Name, Rrtype: TypeSRV, Class: ClassINET, Ttl: 0},
345			Port:   uint16(i + 8000),
346			Target: "target.miek.nl.",
347		}
348		m.Answer = append(m.Answer, r)
349
350		re := &A{
351			Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeA, Class: ClassINET, Ttl: 0},
352			A:   net.ParseIP(fmt.Sprintf("127.0.0.%d", i)).To4(),
353		}
354		m.Extra = append(m.Extra, re)
355	}
356	buf, err := m.Pack()
357	if err != nil {
358		t.Errorf("failed to pack: %v", err)
359	}
360
361	r := new(Msg)
362	if err = r.Unpack(buf); err != nil {
363		t.Errorf("unable to unpack message: %v", err)
364	}
365	if len(r.Answer) != cnt {
366		t.Errorf("answer count after regular unpack doesn't match: %d", len(r.Answer))
367	}
368	if len(r.Extra) != cnt {
369		t.Errorf("extra count after regular unpack doesn't match: %d", len(r.Extra))
370	}
371
372	m.Truncated = true
373	buf, err = m.Pack()
374	if err != nil {
375		t.Errorf("failed to pack truncated message: %v", err)
376	}
377
378	r = new(Msg)
379	if err = r.Unpack(buf); err != nil {
380		t.Errorf("failed to unpack truncated message: %v", err)
381	}
382	if !r.Truncated {
383		t.Errorf("truncated message wasn't unpacked as truncated")
384	}
385	if len(r.Answer) != cnt {
386		t.Errorf("answer count after truncated unpack doesn't match: %d", len(r.Answer))
387	}
388	if len(r.Extra) != cnt {
389		t.Errorf("extra count after truncated unpack doesn't match: %d", len(r.Extra))
390	}
391
392	// Now we want to remove almost all of the extra records
393	// We're going to loop over the extra to get the count of the size of all
394	// of them
395	off := 0
396	buf1 := make([]byte, m.Len())
397	for i := 0; i < len(m.Extra); i++ {
398		off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
399		if err != nil {
400			t.Errorf("failed to pack extra: %v", err)
401		}
402	}
403
404	// Remove all of the extra bytes but 10 bytes from the end of buf
405	off -= 10
406	buf1 = buf[:len(buf)-off]
407
408	r = new(Msg)
409	if err = r.Unpack(buf1); err == nil {
410		t.Error("cutoff message should have failed to unpack")
411	}
412	// r's header might be still usable.
413	if !r.Truncated {
414		t.Error("truncated cutoff message wasn't unpacked as truncated")
415	}
416	if len(r.Answer) != cnt {
417		t.Errorf("answer count after cutoff unpack doesn't match: %d", len(r.Answer))
418	}
419	if len(r.Extra) != 0 {
420		t.Errorf("extra count after cutoff unpack is not zero: %d", len(r.Extra))
421	}
422
423	// Now we want to remove almost all of the answer records too
424	buf1 = make([]byte, m.Len())
425	as := 0
426	for i := 0; i < len(m.Extra); i++ {
427		off1 := off
428		off, err = PackRR(m.Extra[i], buf1, off, nil, m.Compress)
429		as = off - off1
430		if err != nil {
431			t.Errorf("failed to pack extra: %v", err)
432		}
433	}
434
435	// Keep exactly one answer left
436	// This should still cause Answer to be nil
437	off -= as
438	buf1 = buf[:len(buf)-off]
439
440	r = new(Msg)
441	if err = r.Unpack(buf1); err == nil {
442		t.Error("cutoff message should have failed to unpack")
443	}
444	if !r.Truncated {
445		t.Error("truncated cutoff message wasn't unpacked as truncated")
446	}
447	if len(r.Answer) != 0 {
448		t.Errorf("answer count after second cutoff unpack is not zero: %d", len(r.Answer))
449	}
450
451	// Now leave only 1 byte of the question
452	// Since the header is always 12 bytes, we just need to keep 13
453	buf1 = buf[:13]
454
455	r = new(Msg)
456	err = r.Unpack(buf1)
457	if err == nil {
458		t.Errorf("error should be nil after question cutoff unpack: %v", err)
459	}
460
461	// Finally, if we only have the header, we don't return an error.
462	buf1 = buf[:12]
463
464	r = new(Msg)
465	if err = r.Unpack(buf1); err != nil {
466		t.Errorf("from header-only unpack should not return an error: %v", err)
467	}
468}
469
470func TestTimeout(t *testing.T) {
471	// Set up a dummy UDP server that won't respond
472	addr, err := net.ResolveUDPAddr("udp", ":0")
473	if err != nil {
474		t.Fatalf("unable to resolve local udp address: %v", err)
475	}
476	conn, err := net.ListenUDP("udp", addr)
477	if err != nil {
478		t.Fatalf("unable to run test server: %v", err)
479	}
480	defer conn.Close()
481	addrstr := conn.LocalAddr().String()
482
483	// Message to send
484	m := new(Msg)
485	m.SetQuestion("miek.nl.", TypeTXT)
486
487	runTest := func(name string, exchange func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error)) {
488		t.Run(name, func(t *testing.T) {
489			start := time.Now()
490
491			timeout := time.Millisecond
492			allowable := timeout + 10*time.Millisecond
493
494			_, _, err := exchange(m, addrstr, timeout)
495			if err == nil {
496				t.Errorf("no timeout using Client.%s", name)
497			}
498
499			length := time.Since(start)
500			if length > allowable {
501				t.Errorf("exchange took longer %v than specified Timeout %v", length, allowable)
502			}
503		})
504	}
505	runTest("Exchange", func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error) {
506		c := &Client{Timeout: timeout}
507		return c.Exchange(m, addr)
508	})
509	runTest("ExchangeContext", func(m *Msg, addr string, timeout time.Duration) (*Msg, time.Duration, error) {
510		ctx, cancel := context.WithTimeout(context.Background(), timeout)
511		defer cancel()
512
513		return new(Client).ExchangeContext(ctx, m, addrstr)
514	})
515}
516
517// Check that responses from deduplicated requests aren't shared between callers
518func TestConcurrentExchanges(t *testing.T) {
519	cases := make([]*Msg, 2)
520	cases[0] = new(Msg)
521	cases[1] = new(Msg)
522	cases[1].Truncated = true
523
524	for _, m := range cases {
525		mm := m // redeclare m so as not to trip the race detector
526		handler := func(w ResponseWriter, req *Msg) {
527			r := mm.Copy()
528			r.SetReply(req)
529
530			w.WriteMsg(r)
531		}
532
533		HandleFunc("miek.nl.", handler)
534		defer HandleRemove("miek.nl.")
535
536		s, addrstr, err := RunLocalUDPServer(":0")
537		if err != nil {
538			t.Fatalf("unable to run test server: %s", err)
539		}
540		defer s.Shutdown()
541
542		m := new(Msg)
543		m.SetQuestion("miek.nl.", TypeSRV)
544
545		c := &Client{
546			SingleInflight: true,
547		}
548		// Force this client to always return the same request,
549		// even though we're querying sequentially. Running the
550		// Exchange calls below concurrently can fail due to
551		// goroutine scheduling, but this simulates the same
552		// outcome.
553		c.group.dontDeleteForTesting = true
554
555		r := make([]*Msg, 2)
556		for i := range r {
557			r[i], _, _ = c.Exchange(m.Copy(), addrstr)
558			if r[i] == nil {
559				t.Errorf("response %d is nil", i)
560			}
561		}
562
563		if r[0] == r[1] {
564			t.Errorf("got same response, expected non-shared responses")
565		}
566	}
567}
568