1package dns
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"io"
8	"net"
9	"runtime"
10	"strings"
11	"sync"
12	"sync/atomic"
13	"testing"
14	"time"
15
16	"golang.org/x/sync/errgroup"
17)
18
19func HelloServer(w ResponseWriter, req *Msg) {
20	m := new(Msg)
21	m.SetReply(req)
22
23	m.Extra = make([]RR, 1)
24	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
25	w.WriteMsg(m)
26}
27
28func HelloServerBadID(w ResponseWriter, req *Msg) {
29	m := new(Msg)
30	m.SetReply(req)
31	m.Id++
32
33	m.Extra = make([]RR, 1)
34	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
35	w.WriteMsg(m)
36}
37
38func HelloServerBadThenGoodID(w ResponseWriter, req *Msg) {
39	m := new(Msg)
40	m.SetReply(req)
41	m.Id++
42
43	m.Extra = make([]RR, 1)
44	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
45	w.WriteMsg(m)
46
47	m.Id--
48	w.WriteMsg(m)
49}
50
51func HelloServerEchoAddrPort(w ResponseWriter, req *Msg) {
52	m := new(Msg)
53	m.SetReply(req)
54
55	remoteAddr := w.RemoteAddr().String()
56	m.Extra = make([]RR, 1)
57	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{remoteAddr}}
58	w.WriteMsg(m)
59}
60
61func AnotherHelloServer(w ResponseWriter, req *Msg) {
62	m := new(Msg)
63	m.SetReply(req)
64
65	m.Extra = make([]RR, 1)
66	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello example"}}
67	w.WriteMsg(m)
68}
69
70func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
71	pc, err := net.ListenPacket("udp", laddr)
72	if err != nil {
73		return nil, "", nil, err
74	}
75	server := &Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
76
77	waitLock := sync.Mutex{}
78	waitLock.Lock()
79	server.NotifyStartedFunc = waitLock.Unlock
80
81	for _, opt := range opts {
82		opt(server)
83	}
84
85	// fin must be buffered so the goroutine below won't block
86	// forever if fin is never read from. This always happens
87	// if the channel is discarded and can happen in TestShutdownUDP.
88	fin := make(chan error, 1)
89
90	go func() {
91		fin <- server.ActivateAndServe()
92		pc.Close()
93	}()
94
95	waitLock.Lock()
96	return server, pc.LocalAddr().String(), fin, nil
97}
98
99func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
100	return RunLocalUDPServer(laddr, append(opts, func(srv *Server) {
101		// Make srv.PacketConn opaque to trigger the generic code paths.
102		srv.PacketConn = struct{ net.PacketConn }{srv.PacketConn}
103	})...)
104}
105
106func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
107	l, err := net.Listen("tcp", laddr)
108	if err != nil {
109		return nil, "", nil, err
110	}
111
112	server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
113
114	waitLock := sync.Mutex{}
115	waitLock.Lock()
116	server.NotifyStartedFunc = waitLock.Unlock
117
118	for _, opt := range opts {
119		opt(server)
120	}
121
122	// See the comment in RunLocalUDPServer as to why fin must be buffered.
123	fin := make(chan error, 1)
124
125	go func() {
126		fin <- server.ActivateAndServe()
127		l.Close()
128	}()
129
130	waitLock.Lock()
131	return server, l.Addr().String(), fin, nil
132}
133
134func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) {
135	return RunLocalTCPServer(laddr, func(srv *Server) {
136		srv.Listener = tls.NewListener(srv.Listener, config)
137	})
138}
139
140func TestServing(t *testing.T) {
141	for _, tc := range []struct {
142		name      string
143		network   string
144		runServer func(laddr string, opts ...func(*Server)) (*Server, string, chan error, error)
145	}{
146		{"udp", "udp", RunLocalUDPServer},
147		{"tcp", "tcp", RunLocalTCPServer},
148		{"PacketConn", "udp", RunLocalPacketConnServer},
149	} {
150		t.Run(tc.name, func(t *testing.T) {
151			HandleFunc("miek.nl.", HelloServer)
152			HandleFunc("example.com.", AnotherHelloServer)
153			defer HandleRemove("miek.nl.")
154			defer HandleRemove("example.com.")
155
156			s, addrstr, _, err := tc.runServer(":0")
157			if err != nil {
158				t.Fatalf("unable to run test server: %v", err)
159			}
160			defer s.Shutdown()
161
162			c := &Client{
163				Net: tc.network,
164			}
165			m := new(Msg)
166			m.SetQuestion("miek.nl.", TypeTXT)
167			r, _, err := c.Exchange(m, addrstr)
168			if err != nil || len(r.Extra) == 0 {
169				t.Fatal("failed to exchange miek.nl", err)
170			}
171			txt := r.Extra[0].(*TXT).Txt[0]
172			if txt != "Hello world" {
173				t.Error("unexpected result for miek.nl", txt, "!= Hello world")
174			}
175
176			m.SetQuestion("example.com.", TypeTXT)
177			r, _, err = c.Exchange(m, addrstr)
178			if err != nil {
179				t.Fatal("failed to exchange example.com", err)
180			}
181			txt = r.Extra[0].(*TXT).Txt[0]
182			if txt != "Hello example" {
183				t.Error("unexpected result for example.com", txt, "!= Hello example")
184			}
185
186			// Test Mixes cased as noticed by Ask.
187			m.SetQuestion("eXaMplE.cOm.", TypeTXT)
188			r, _, err = c.Exchange(m, addrstr)
189			if err != nil {
190				t.Error("failed to exchange eXaMplE.cOm", err)
191			}
192			txt = r.Extra[0].(*TXT).Txt[0]
193			if txt != "Hello example" {
194				t.Error("unexpected result for example.com", txt, "!= Hello example")
195			}
196		})
197	}
198}
199
200// Verify that the server responds to a query with Z flag on, ignoring the flag, and does not echoes it back
201func TestServeIgnoresZFlag(t *testing.T) {
202	HandleFunc("example.com.", AnotherHelloServer)
203
204	s, addrstr, _, err := RunLocalUDPServer(":0")
205	if err != nil {
206		t.Fatalf("unable to run test server: %v", err)
207	}
208	defer s.Shutdown()
209
210	c := new(Client)
211	m := new(Msg)
212
213	// Test the Z flag is not echoed
214	m.SetQuestion("example.com.", TypeTXT)
215	m.Zero = true
216	r, _, err := c.Exchange(m, addrstr)
217	if err != nil {
218		t.Fatal("failed to exchange example.com with +zflag", err)
219	}
220	if r.Zero {
221		t.Error("the response should not have Z flag set - even for a query which does")
222	}
223	if r.Rcode != RcodeSuccess {
224		t.Errorf("expected rcode %v, got %v", RcodeSuccess, r.Rcode)
225	}
226}
227
228// Verify that the server responds to a query with unsupported Opcode with a NotImplemented error and that Opcode is unchanged.
229func TestServeNotImplemented(t *testing.T) {
230	HandleFunc("example.com.", AnotherHelloServer)
231	opcode := 15
232
233	s, addrstr, _, err := RunLocalUDPServer(":0")
234	if err != nil {
235		t.Fatalf("unable to run test server: %v", err)
236	}
237	defer s.Shutdown()
238
239	c := new(Client)
240	m := new(Msg)
241
242	// Test that Opcode is like the unchanged from request Opcode and that Rcode is set to NotImplemented
243	m.SetQuestion("example.com.", TypeTXT)
244	m.Opcode = opcode
245	r, _, err := c.Exchange(m, addrstr)
246	if err != nil {
247		t.Fatal("failed to exchange example.com with +zflag", err)
248	}
249	if r.Opcode != opcode {
250		t.Errorf("expected opcode %v, got %v", opcode, r.Opcode)
251	}
252	if r.Rcode != RcodeNotImplemented {
253		t.Errorf("expected rcode %v, got %v", RcodeNotImplemented, r.Rcode)
254	}
255}
256
257func TestServingTLS(t *testing.T) {
258	HandleFunc("miek.nl.", HelloServer)
259	HandleFunc("example.com.", AnotherHelloServer)
260	defer HandleRemove("miek.nl.")
261	defer HandleRemove("example.com.")
262
263	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
264	if err != nil {
265		t.Fatalf("unable to build certificate: %v", err)
266	}
267
268	config := tls.Config{
269		Certificates: []tls.Certificate{cert},
270	}
271
272	s, addrstr, _, err := RunLocalTLSServer(":0", &config)
273	if err != nil {
274		t.Fatalf("unable to run test server: %v", err)
275	}
276	defer s.Shutdown()
277
278	c := new(Client)
279	c.Net = "tcp-tls"
280	c.TLSConfig = &tls.Config{
281		InsecureSkipVerify: true,
282	}
283
284	m := new(Msg)
285	m.SetQuestion("miek.nl.", TypeTXT)
286	r, _, err := c.Exchange(m, addrstr)
287	if err != nil || len(r.Extra) == 0 {
288		t.Fatal("failed to exchange miek.nl", err)
289	}
290	txt := r.Extra[0].(*TXT).Txt[0]
291	if txt != "Hello world" {
292		t.Error("unexpected result for miek.nl", txt, "!= Hello world")
293	}
294
295	m.SetQuestion("example.com.", TypeTXT)
296	r, _, err = c.Exchange(m, addrstr)
297	if err != nil {
298		t.Fatal("failed to exchange example.com", err)
299	}
300	txt = r.Extra[0].(*TXT).Txt[0]
301	if txt != "Hello example" {
302		t.Error("unexpected result for example.com", txt, "!= Hello example")
303	}
304
305	// Test Mixes cased as noticed by Ask.
306	m.SetQuestion("eXaMplE.cOm.", TypeTXT)
307	r, _, err = c.Exchange(m, addrstr)
308	if err != nil {
309		t.Error("failed to exchange eXaMplE.cOm", err)
310	}
311	txt = r.Extra[0].(*TXT).Txt[0]
312	if txt != "Hello example" {
313		t.Error("unexpected result for example.com", txt, "!= Hello example")
314	}
315}
316
317// TestServingTLSConnectionState tests that we only can access
318// tls.ConnectionState under a DNS query handled by a TLS DNS server.
319// This test will sequentially create a TLS, UDP and TCP server, attach a custom
320// handler which will set a testing error if tls.ConnectionState is available
321// when it is not expected, or the other way around.
322func TestServingTLSConnectionState(t *testing.T) {
323	handlerResponse := "Hello example"
324	// tlsHandlerTLS is a HandlerFunc that can be set to expect or not TLS
325	// connection state.
326	tlsHandlerTLS := func(tlsExpected bool) func(ResponseWriter, *Msg) {
327		return func(w ResponseWriter, req *Msg) {
328			m := new(Msg)
329			m.SetReply(req)
330			tlsFound := true
331			if connState := w.(ConnectionStater).ConnectionState(); connState == nil {
332				tlsFound = false
333			}
334			if tlsFound != tlsExpected {
335				t.Errorf("TLS connection state available: %t, expected: %t", tlsFound, tlsExpected)
336			}
337			m.Extra = make([]RR, 1)
338			m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{handlerResponse}}
339			w.WriteMsg(m)
340		}
341	}
342
343	// Question used in tests
344	m := new(Msg)
345	m.SetQuestion("tlsstate.example.net.", TypeTXT)
346
347	// TLS DNS server
348	HandleFunc(".", tlsHandlerTLS(true))
349	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
350	if err != nil {
351		t.Fatalf("unable to build certificate: %v", err)
352	}
353
354	config := tls.Config{
355		Certificates: []tls.Certificate{cert},
356	}
357
358	s, addrstr, _, err := RunLocalTLSServer(":0", &config)
359	if err != nil {
360		t.Fatalf("unable to run test server: %v", err)
361	}
362	defer s.Shutdown()
363
364	// TLS DNS query
365	c := &Client{
366		Net: "tcp-tls",
367		TLSConfig: &tls.Config{
368			InsecureSkipVerify: true,
369		},
370	}
371
372	_, _, err = c.Exchange(m, addrstr)
373	if err != nil {
374		t.Error("failed to exchange tlsstate.example.net", err)
375	}
376
377	HandleRemove(".")
378	// UDP DNS Server
379	HandleFunc(".", tlsHandlerTLS(false))
380	defer HandleRemove(".")
381	s, addrstr, _, err = RunLocalUDPServer(":0")
382	if err != nil {
383		t.Fatalf("unable to run test server: %v", err)
384	}
385	defer s.Shutdown()
386
387	// UDP DNS query
388	c = new(Client)
389	_, _, err = c.Exchange(m, addrstr)
390	if err != nil {
391		t.Error("failed to exchange tlsstate.example.net", err)
392	}
393
394	// TCP DNS Server
395	s, addrstr, _, err = RunLocalTCPServer(":0")
396	if err != nil {
397		t.Fatalf("unable to run test server: %v", err)
398	}
399	defer s.Shutdown()
400
401	// TCP DNS query
402	c = &Client{Net: "tcp"}
403	_, _, err = c.Exchange(m, addrstr)
404	if err != nil {
405		t.Error("failed to exchange tlsstate.example.net", err)
406	}
407}
408
409func TestServingListenAndServe(t *testing.T) {
410	HandleFunc("example.com.", AnotherHelloServer)
411	defer HandleRemove("example.com.")
412
413	waitLock := sync.Mutex{}
414	server := &Server{Addr: ":0", Net: "udp", ReadTimeout: time.Hour, WriteTimeout: time.Hour, NotifyStartedFunc: waitLock.Unlock}
415	waitLock.Lock()
416
417	go func() {
418		server.ListenAndServe()
419	}()
420	waitLock.Lock()
421
422	c, m := new(Client), new(Msg)
423	m.SetQuestion("example.com.", TypeTXT)
424	addr := server.PacketConn.LocalAddr().String() // Get address via the PacketConn that gets set.
425	r, _, err := c.Exchange(m, addr)
426	if err != nil {
427		t.Fatal("failed to exchange example.com", err)
428	}
429	txt := r.Extra[0].(*TXT).Txt[0]
430	if txt != "Hello example" {
431		t.Error("unexpected result for example.com", txt, "!= Hello example")
432	}
433	server.Shutdown()
434}
435
436func TestServingListenAndServeTLS(t *testing.T) {
437	HandleFunc("example.com.", AnotherHelloServer)
438	defer HandleRemove("example.com.")
439
440	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
441	if err != nil {
442		t.Fatalf("unable to build certificate: %v", err)
443	}
444
445	config := &tls.Config{
446		Certificates: []tls.Certificate{cert},
447	}
448
449	waitLock := sync.Mutex{}
450	server := &Server{Addr: ":0", Net: "tcp", TLSConfig: config, ReadTimeout: time.Hour, WriteTimeout: time.Hour, NotifyStartedFunc: waitLock.Unlock}
451	waitLock.Lock()
452
453	go func() {
454		server.ListenAndServe()
455	}()
456	waitLock.Lock()
457
458	c, m := new(Client), new(Msg)
459	c.Net = "tcp"
460	m.SetQuestion("example.com.", TypeTXT)
461	addr := server.Listener.Addr().String() // Get address via the Listener that gets set.
462	r, _, err := c.Exchange(m, addr)
463	if err != nil {
464		t.Fatal(err)
465	}
466	txt := r.Extra[0].(*TXT).Txt[0]
467	if txt != "Hello example" {
468		t.Error("unexpected result for example.com", txt, "!= Hello example")
469	}
470	server.Shutdown()
471}
472
473func BenchmarkServe(b *testing.B) {
474	b.StopTimer()
475	HandleFunc("miek.nl.", HelloServer)
476	defer HandleRemove("miek.nl.")
477	a := runtime.GOMAXPROCS(4)
478
479	s, addrstr, _, err := RunLocalUDPServer(":0")
480	if err != nil {
481		b.Fatalf("unable to run test server: %v", err)
482	}
483	defer s.Shutdown()
484
485	c := new(Client)
486	m := new(Msg)
487	m.SetQuestion("miek.nl.", TypeSOA)
488
489	b.StartTimer()
490	for i := 0; i < b.N; i++ {
491		_, _, err := c.Exchange(m, addrstr)
492		if err != nil {
493			b.Fatalf("Exchange failed: %v", err)
494		}
495	}
496	runtime.GOMAXPROCS(a)
497}
498
499func BenchmarkServe6(b *testing.B) {
500	b.StopTimer()
501	HandleFunc("miek.nl.", HelloServer)
502	defer HandleRemove("miek.nl.")
503	a := runtime.GOMAXPROCS(4)
504	s, addrstr, _, err := RunLocalUDPServer("[::1]:0")
505	if err != nil {
506		if strings.Contains(err.Error(), "bind: cannot assign requested address") {
507			b.Skip("missing IPv6 support")
508		}
509		b.Fatalf("unable to run test server: %v", err)
510	}
511	defer s.Shutdown()
512
513	c := new(Client)
514	m := new(Msg)
515	m.SetQuestion("miek.nl.", TypeSOA)
516
517	b.StartTimer()
518	for i := 0; i < b.N; i++ {
519		_, _, err := c.Exchange(m, addrstr)
520		if err != nil {
521			b.Fatalf("Exchange failed: %v", err)
522		}
523	}
524	runtime.GOMAXPROCS(a)
525}
526
527func HelloServerCompress(w ResponseWriter, req *Msg) {
528	m := new(Msg)
529	m.SetReply(req)
530	m.Extra = make([]RR, 1)
531	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
532	m.Compress = true
533	w.WriteMsg(m)
534}
535
536func BenchmarkServeCompress(b *testing.B) {
537	b.StopTimer()
538	HandleFunc("miek.nl.", HelloServerCompress)
539	defer HandleRemove("miek.nl.")
540	a := runtime.GOMAXPROCS(4)
541	s, addrstr, _, err := RunLocalUDPServer(":0")
542	if err != nil {
543		b.Fatalf("unable to run test server: %v", err)
544	}
545	defer s.Shutdown()
546
547	c := new(Client)
548	m := new(Msg)
549	m.SetQuestion("miek.nl.", TypeSOA)
550	b.StartTimer()
551	for i := 0; i < b.N; i++ {
552		_, _, err := c.Exchange(m, addrstr)
553		if err != nil {
554			b.Fatalf("Exchange failed: %v", err)
555		}
556	}
557	runtime.GOMAXPROCS(a)
558}
559
// maxRec pairs an integer with the read/write lock that callers take around
// every access to it.
type maxRec struct {
	sync.RWMutex
	max int
}

// M holds the number of A records HelloServerLargeResponse adds to a reply.
var M = &maxRec{}
566
567func HelloServerLargeResponse(resp ResponseWriter, req *Msg) {
568	m := new(Msg)
569	m.SetReply(req)
570	m.Authoritative = true
571	m1 := 0
572	M.RLock()
573	m1 = M.max
574	M.RUnlock()
575	for i := 0; i < m1; i++ {
576		aRec := &A{
577			Hdr: RR_Header{
578				Name:   req.Question[0].Name,
579				Rrtype: TypeA,
580				Class:  ClassINET,
581				Ttl:    0,
582			},
583			A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)).To4(),
584		}
585		m.Answer = append(m.Answer, aRec)
586	}
587	resp.WriteMsg(m)
588}
589
590func TestServingLargeResponses(t *testing.T) {
591	HandleFunc("example.", HelloServerLargeResponse)
592	defer HandleRemove("example.")
593
594	s, addrstr, _, err := RunLocalUDPServer(":0")
595	if err != nil {
596		t.Fatalf("unable to run test server: %v", err)
597	}
598	defer s.Shutdown()
599
600	// Create request
601	m := new(Msg)
602	m.SetQuestion("web.service.example.", TypeANY)
603
604	c := new(Client)
605	c.Net = "udp"
606	M.Lock()
607	M.max = 2
608	M.Unlock()
609	_, _, err = c.Exchange(m, addrstr)
610	if err != nil {
611		t.Errorf("failed to exchange: %v", err)
612	}
613	// This must fail
614	M.Lock()
615	M.max = 20
616	M.Unlock()
617	_, _, err = c.Exchange(m, addrstr)
618	if err == nil {
619		t.Error("failed to fail exchange, this should generate packet error")
620	}
621	// But this must work again
622	c.UDPSize = 7000
623	_, _, err = c.Exchange(m, addrstr)
624	if err != nil {
625		t.Errorf("failed to exchange: %v", err)
626	}
627}
628
629func TestServingResponse(t *testing.T) {
630	if testing.Short() {
631		t.Skip("skipping test in short mode.")
632	}
633	HandleFunc("miek.nl.", HelloServer)
634	s, addrstr, _, err := RunLocalUDPServer(":0")
635	if err != nil {
636		t.Fatalf("unable to run test server: %v", err)
637	}
638	defer s.Shutdown()
639
640	c := new(Client)
641	m := new(Msg)
642	m.SetQuestion("miek.nl.", TypeTXT)
643	m.Response = false
644	_, _, err = c.Exchange(m, addrstr)
645	if err != nil {
646		t.Fatal("failed to exchange", err)
647	}
648	m.Response = true // this holds up the reply, set short read time out to avoid waiting too long
649	c.ReadTimeout = 100 * time.Millisecond
650	_, _, err = c.Exchange(m, addrstr)
651	if err == nil {
652		t.Fatal("exchanged response message")
653	}
654}
655
656func TestShutdownTCP(t *testing.T) {
657	s, _, fin, err := RunLocalTCPServer(":0")
658	if err != nil {
659		t.Fatalf("unable to run test server: %v", err)
660	}
661	err = s.Shutdown()
662	if err != nil {
663		t.Fatalf("could not shutdown test TCP server, %v", err)
664	}
665	select {
666	case err := <-fin:
667		if err != nil {
668			t.Errorf("error returned from ActivateAndServe, %v", err)
669		}
670	case <-time.After(2 * time.Second):
671		t.Error("could not shutdown test TCP server. Gave up waiting")
672	}
673}
674
675func init() {
676	testShutdownNotify = &sync.Cond{
677		L: new(sync.Mutex),
678	}
679}
680
681func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr string, client *Client) {
682	const requests = 15 // enough to make this interesting? TODO: find a proper value
683
684	var errOnce sync.Once
685	// t.Fail will panic if it's called after the test function has
686	// finished. Burning the sync.Once with a defer will prevent the
687	// handler from calling t.Errorf after we've returned.
688	defer errOnce.Do(func() {})
689
690	toHandle := int32(requests)
691	HandleFunc("example.com.", func(w ResponseWriter, req *Msg) {
692		defer atomic.AddInt32(&toHandle, -1)
693
694		// Wait until ShutdownContext is called before replying.
695		testShutdownNotify.L.Lock()
696		testShutdownNotify.Wait()
697		testShutdownNotify.L.Unlock()
698
699		m := new(Msg)
700		m.SetReply(req)
701		m.Extra = make([]RR, 1)
702		m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
703
704		if err := w.WriteMsg(m); err != nil {
705			errOnce.Do(func() {
706				t.Errorf("ResponseWriter.WriteMsg error: %s", err)
707			})
708		}
709	})
710	defer HandleRemove("example.com.")
711
712	client.Timeout = 1 * time.Second
713
714	conns := make([]*Conn, requests)
715	eg := new(errgroup.Group)
716
717	for i := range conns {
718		conn := &conns[i]
719		eg.Go(func() error {
720			var err error
721			*conn, err = client.Dial(addr)
722			return err
723		})
724	}
725
726	if eg.Wait() != nil {
727		t.Fatalf("client.Dial error: %v", eg.Wait())
728	}
729
730	m := new(Msg)
731	m.SetQuestion("example.com.", TypeTXT)
732	eg = new(errgroup.Group)
733
734	for _, conn := range conns {
735		conn := conn
736		eg.Go(func() error {
737			conn.SetWriteDeadline(time.Now().Add(client.Timeout))
738
739			return conn.WriteMsg(m)
740		})
741	}
742
743	if eg.Wait() != nil {
744		t.Fatalf("conn.WriteMsg error: %v", eg.Wait())
745	}
746
747	// This sleep is needed to allow time for the requests to
748	// pass from the client through the kernel and back into
749	// the server. Without it, some requests may still be in
750	// the kernel's buffer when ShutdownContext is called.
751	time.Sleep(100 * time.Millisecond)
752
753	eg = new(errgroup.Group)
754
755	for _, conn := range conns {
756		conn := conn
757		eg.Go(func() error {
758			conn.SetReadDeadline(time.Now().Add(client.Timeout))
759
760			_, err := conn.ReadMsg()
761			return err
762		})
763	}
764
765	ctx, cancel := context.WithTimeout(context.Background(), client.Timeout)
766	defer cancel()
767
768	if err := srv.ShutdownContext(ctx); err != nil {
769		t.Errorf("could not shutdown test server: %v", err)
770	}
771
772	if left := atomic.LoadInt32(&toHandle); left != 0 {
773		t.Errorf("ShutdownContext returned before %d replies", left)
774	}
775
776	if eg.Wait() != nil {
777		t.Errorf("conn.ReadMsg error: %v", eg.Wait())
778	}
779
780	srv.lock.RLock()
781	defer srv.lock.RUnlock()
782	if len(srv.conns) != 0 {
783		t.Errorf("TCP connection tracking map not empty after ShutdownContext; map still contains %d connections", len(srv.conns))
784	}
785}
786
787func TestInProgressQueriesAtShutdownTCP(t *testing.T) {
788	s, addr, _, err := RunLocalTCPServer(":0")
789	if err != nil {
790		t.Fatalf("unable to run test server: %v", err)
791	}
792
793	c := &Client{Net: "tcp"}
794	checkInProgressQueriesAtShutdownServer(t, s, addr, c)
795}
796
797func TestShutdownTLS(t *testing.T) {
798	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
799	if err != nil {
800		t.Fatalf("unable to build certificate: %v", err)
801	}
802
803	config := tls.Config{
804		Certificates: []tls.Certificate{cert},
805	}
806
807	s, _, _, err := RunLocalTLSServer(":0", &config)
808	if err != nil {
809		t.Fatalf("unable to run test server: %v", err)
810	}
811	err = s.Shutdown()
812	if err != nil {
813		t.Errorf("could not shutdown test TLS server, %v", err)
814	}
815}
816
817func TestInProgressQueriesAtShutdownTLS(t *testing.T) {
818	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
819	if err != nil {
820		t.Fatalf("unable to build certificate: %v", err)
821	}
822
823	config := tls.Config{
824		Certificates: []tls.Certificate{cert},
825	}
826
827	s, addr, _, err := RunLocalTLSServer(":0", &config)
828	if err != nil {
829		t.Fatalf("unable to run test server: %v", err)
830	}
831
832	c := &Client{
833		Net: "tcp-tls",
834		TLSConfig: &tls.Config{
835			InsecureSkipVerify: true,
836		},
837	}
838	checkInProgressQueriesAtShutdownServer(t, s, addr, c)
839}
840
841func TestHandlerCloseTCP(t *testing.T) {
842	ln, err := net.Listen("tcp", ":0")
843	if err != nil {
844		panic(err)
845	}
846	addr := ln.Addr().String()
847
848	server := &Server{Addr: addr, Net: "tcp", Listener: ln}
849
850	hname := "testhandlerclosetcp."
851	triggered := make(chan struct{})
852	HandleFunc(hname, func(w ResponseWriter, r *Msg) {
853		close(triggered)
854		w.Close()
855	})
856	defer HandleRemove(hname)
857
858	go func() {
859		defer server.Shutdown()
860		c := &Client{Net: "tcp"}
861		m := new(Msg).SetQuestion(hname, 1)
862		tries := 0
863	exchange:
864		_, _, err := c.Exchange(m, addr)
865		if err != nil && err != io.EOF {
866			t.Errorf("exchange failed: %v", err)
867			if tries == 3 {
868				return
869			}
870			time.Sleep(time.Second / 10)
871			tries++
872			goto exchange
873		}
874	}()
875	if err := server.ActivateAndServe(); err != nil {
876		t.Fatalf("ActivateAndServe failed: %v", err)
877	}
878	select {
879	case <-triggered:
880	default:
881		t.Fatalf("handler never called")
882	}
883}
884
885func TestShutdownUDP(t *testing.T) {
886	s, _, fin, err := RunLocalUDPServer(":0")
887	if err != nil {
888		t.Fatalf("unable to run test server: %v", err)
889	}
890	err = s.Shutdown()
891	if err != nil {
892		t.Errorf("could not shutdown test UDP server, %v", err)
893	}
894	select {
895	case err := <-fin:
896		if err != nil {
897			t.Errorf("error returned from ActivateAndServe, %v", err)
898		}
899	case <-time.After(2 * time.Second):
900		t.Error("could not shutdown test UDP server. Gave up waiting")
901	}
902}
903
904func TestShutdownPacketConn(t *testing.T) {
905	s, _, fin, err := RunLocalPacketConnServer(":0")
906	if err != nil {
907		t.Fatalf("unable to run test server: %v", err)
908	}
909	err = s.Shutdown()
910	if err != nil {
911		t.Errorf("could not shutdown test UDP server, %v", err)
912	}
913	select {
914	case err := <-fin:
915		if err != nil {
916			t.Errorf("error returned from ActivateAndServe, %v", err)
917		}
918	case <-time.After(2 * time.Second):
919		t.Error("could not shutdown test UDP server. Gave up waiting")
920	}
921}
922
923func TestInProgressQueriesAtShutdownUDP(t *testing.T) {
924	s, addr, _, err := RunLocalUDPServer(":0")
925	if err != nil {
926		t.Fatalf("unable to run test server: %v", err)
927	}
928
929	c := &Client{Net: "udp"}
930	checkInProgressQueriesAtShutdownServer(t, s, addr, c)
931}
932
933func TestInProgressQueriesAtShutdownPacketConn(t *testing.T) {
934	s, addr, _, err := RunLocalPacketConnServer(":0")
935	if err != nil {
936		t.Fatalf("unable to run test server: %v", err)
937	}
938
939	c := &Client{Net: "udp"}
940	checkInProgressQueriesAtShutdownServer(t, s, addr, c)
941}
942
943func TestServerStartStopRace(t *testing.T) {
944	var wg sync.WaitGroup
945	for i := 0; i < 10; i++ {
946		wg.Add(1)
947		s, _, _, err := RunLocalUDPServer(":0")
948		if err != nil {
949			t.Fatalf("could not start server: %s", err)
950		}
951		go func() {
952			defer wg.Done()
953			if err := s.Shutdown(); err != nil {
954				t.Errorf("could not stop server: %s", err)
955			}
956		}()
957	}
958	wg.Wait()
959}
960
961func TestServerReuseport(t *testing.T) {
962	if !supportsReusePort {
963		t.Skip("reuseport is not supported")
964	}
965
966	startServer := func(addr string) (*Server, chan error) {
967		wait := make(chan struct{})
968		srv := &Server{
969			Net:               "udp",
970			Addr:              addr,
971			NotifyStartedFunc: func() { close(wait) },
972			ReusePort:         true,
973		}
974
975		fin := make(chan error, 1)
976		go func() {
977			fin <- srv.ListenAndServe()
978		}()
979
980		select {
981		case <-wait:
982		case err := <-fin:
983			t.Fatalf("failed to start server: %v", err)
984		}
985
986		return srv, fin
987	}
988
989	srv1, fin1 := startServer(":0") // :0 is resolved to a random free port by the kernel
990	srv2, fin2 := startServer(srv1.PacketConn.LocalAddr().String())
991
992	if err := srv1.Shutdown(); err != nil {
993		t.Fatalf("failed to shutdown first server: %v", err)
994	}
995	if err := srv2.Shutdown(); err != nil {
996		t.Fatalf("failed to shutdown second server: %v", err)
997	}
998
999	if err := <-fin1; err != nil {
1000		t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
1001	}
1002	if err := <-fin2; err != nil {
1003		t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err)
1004	}
1005}
1006
1007func TestServerRoundtripTsig(t *testing.T) {
1008	secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}
1009
1010	s, addrstr, _, err := RunLocalUDPServer(":0", func(srv *Server) {
1011		srv.TsigSecret = secret
1012		srv.MsgAcceptFunc = func(dh Header) MsgAcceptAction {
1013			// defaultMsgAcceptFunc does reject UPDATE queries
1014			return MsgAccept
1015		}
1016	})
1017	if err != nil {
1018		t.Fatalf("unable to run test server: %v", err)
1019	}
1020	defer s.Shutdown()
1021
1022	handlerFired := make(chan struct{})
1023	HandleFunc("example.com.", func(w ResponseWriter, r *Msg) {
1024		close(handlerFired)
1025
1026		m := new(Msg)
1027		m.SetReply(r)
1028		if r.IsTsig() != nil {
1029			status := w.TsigStatus()
1030			if status == nil {
1031				// *Msg r has an TSIG record and it was validated
1032				m.SetTsig("test.", HmacSHA256, 300, time.Now().Unix())
1033			} else {
1034				// *Msg r has an TSIG records and it was not validated
1035				t.Errorf("invalid TSIG: %v", status)
1036			}
1037		} else {
1038			t.Error("missing TSIG")
1039		}
1040		if err := w.WriteMsg(m); err != nil {
1041			t.Error("writemsg failed", err)
1042		}
1043	})
1044
1045	c := new(Client)
1046	m := new(Msg)
1047	m.Opcode = OpcodeUpdate
1048	m.SetQuestion("example.com.", TypeSOA)
1049	m.Ns = []RR{&CNAME{
1050		Hdr: RR_Header{
1051			Name:   "foo.example.com.",
1052			Rrtype: TypeCNAME,
1053			Class:  ClassINET,
1054			Ttl:    300,
1055		},
1056		Target: "bar.example.com.",
1057	}}
1058	c.TsigSecret = secret
1059	m.SetTsig("test.", HmacSHA256, 300, time.Now().Unix())
1060	_, _, err = c.Exchange(m, addrstr)
1061	if err != nil {
1062		t.Fatal("failed to exchange", err)
1063	}
1064	select {
1065	case <-handlerFired:
1066		// ok, handler was actually called
1067	default:
1068		t.Error("handler was not called")
1069	}
1070}
1071
1072func TestResponseAfterClose(t *testing.T) {
1073	testError := func(name string, err error) {
1074		t.Helper()
1075
1076		expect := fmt.Sprintf("dns: %s called after Close", name)
1077		if err == nil {
1078			t.Errorf("expected error from %s after Close", name)
1079		} else if err.Error() != expect {
1080			t.Errorf("expected explicit error from %s after Close, expected %q, got %q", name, expect, err)
1081		}
1082	}
1083
1084	rw := &response{
1085		closed: true,
1086	}
1087
1088	_, err := rw.Write(make([]byte, 2))
1089	testError("Write", err)
1090
1091	testError("WriteMsg", rw.WriteMsg(new(Msg)))
1092}
1093
1094func TestResponseDoubleClose(t *testing.T) {
1095	rw := &response{
1096		closed: true,
1097	}
1098	if err, expect := rw.Close(), "dns: connection already closed"; err == nil || err.Error() != expect {
1099		t.Errorf("Close did not return expected: error %q, got: %v", expect, err)
1100	}
1101}
1102
// countingConn is a net.Conn stand-in that counts calls to Write.
type countingConn struct {
	net.Conn
	writes int
}

// Write records the call and reports all of p as written; it never touches
// the embedded Conn.
func (c *countingConn) Write(p []byte) (int, error) {
	written := len(p)
	c.writes++
	return written, nil
}
1112
1113func TestResponseWriteSinglePacket(t *testing.T) {
1114	c := &countingConn{}
1115	rw := &response{
1116		tcp: c,
1117	}
1118	rw.writer = rw
1119
1120	m := new(Msg)
1121	m.SetQuestion("miek.nl.", TypeTXT)
1122	m.Response = true
1123	err := rw.WriteMsg(m)
1124
1125	if err != nil {
1126		t.Fatalf("failed to write: %v", err)
1127	}
1128
1129	if c.writes != 1 {
1130		t.Fatalf("incorrect number of Write calls")
1131	}
1132}
1133
1134type ExampleFrameLengthWriter struct {
1135	Writer
1136}
1137
1138func (e *ExampleFrameLengthWriter) Write(m []byte) (int, error) {
1139	fmt.Println("writing raw DNS message of length", len(m))
1140	return e.Writer.Write(m)
1141}
1142
1143func ExampleDecorateWriter() {
1144	// instrument raw DNS message writing
1145	wf := DecorateWriter(func(w Writer) Writer {
1146		return &ExampleFrameLengthWriter{w}
1147	})
1148
1149	// simple UDP server
1150	pc, err := net.ListenPacket("udp", ":0")
1151	if err != nil {
1152		fmt.Println(err.Error())
1153		return
1154	}
1155	server := &Server{
1156		PacketConn:     pc,
1157		DecorateWriter: wf,
1158		ReadTimeout:    time.Hour, WriteTimeout: time.Hour,
1159	}
1160
1161	waitLock := sync.Mutex{}
1162	waitLock.Lock()
1163	server.NotifyStartedFunc = waitLock.Unlock
1164	defer server.Shutdown()
1165
1166	go func() {
1167		server.ActivateAndServe()
1168		pc.Close()
1169	}()
1170
1171	waitLock.Lock()
1172
1173	HandleFunc("miek.nl.", HelloServer)
1174
1175	c := new(Client)
1176	m := new(Msg)
1177	m.SetQuestion("miek.nl.", TypeTXT)
1178	_, _, err = c.Exchange(m, pc.LocalAddr().String())
1179	if err != nil {
1180		fmt.Println("failed to exchange", err.Error())
1181		return
1182	}
1183	// Output: writing raw DNS message of length 56
1184}
1185
// CertPEMBlock is the PEM-encoded X.509 certificate used to test TLS servers.
// It is passed to tls.X509KeyPair together with KeyPEMBlock.
var CertPEMBlock = []byte(`-----BEGIN CERTIFICATE-----
MIIDAzCCAeugAwIBAgIRAJFYMkcn+b8dpU15wjf++GgwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjAxMDgxMjAzNTNaFw0xNzAxMDcxMjAz
NTNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
ggEKAoIBAQDXjqO6skvP03k58CNjQggd9G/mt+Wa+xRU+WXiKCCHttawM8x+slq5
yfsHCwxlwsGn79HmJqecNqgHb2GWBXAvVVokFDTcC1hUP4+gp2gu9Ny27UHTjlLm
O0l/xZ5MN8tfKyYlFw18tXu3fkaPyHj8v/D1RDkuo4ARdFvGSe8TqisbhLk2+9ow
xfIGbEM9Fdiw8qByC2+d+FfvzIKz3GfQVwn0VoRom8L6NBIANq1IGrB5JefZB6nv
DnfuxkBmY7F1513HKuEJ8KsLWWZWV9OPU4j4I4Rt+WJNlKjbD2srHxyrS2RDsr91
8nCkNoWVNO3sZq0XkWKecdc921vL4ginAgMBAAGjVDBSMA4GA1UdDwEB/wQEAwIC
pDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MBoGA1UdEQQT
MBGCCWxvY2FsaG9zdIcEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAGcU3iyLBIVZj
aDzSvEDHUd1bnLBl1C58Xu/CyKlPqVU7mLfK0JcgEaYQTSX6fCJVNLbbCrcGLsPJ
fbjlBbyeLjTV413fxPVuona62pBFjqdtbli2Qe8FRH2KBdm41JUJGdo+SdsFu7nc
BFOcubdw6LLIXvsTvwndKcHWx1rMX709QU1Vn1GAIsbJV/DWI231Jyyb+lxAUx/C
8vce5uVxiKcGS+g6OjsN3D3TtiEQGSXLh013W6Wsih8td8yMCMZ3w8LQ38br1GUe
ahLIgUJ9l6HDguM17R7kGqxNvbElsMUHfTtXXP7UDQUiYXDakg8xDP6n9DCDhJ8Y
bSt7OLB7NQ==
-----END CERTIFICATE-----`)

// KeyPEMBlock is the PEM-encoded RSA private key used to test TLS servers.
// It is passed to tls.X509KeyPair together with CertPEMBlock.
var KeyPEMBlock = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEA146jurJLz9N5OfAjY0IIHfRv5rflmvsUVPll4iggh7bWsDPM
frJaucn7BwsMZcLBp+/R5iannDaoB29hlgVwL1VaJBQ03AtYVD+PoKdoLvTctu1B
045S5jtJf8WeTDfLXysmJRcNfLV7t35Gj8h4/L/w9UQ5LqOAEXRbxknvE6orG4S5
NvvaMMXyBmxDPRXYsPKgcgtvnfhX78yCs9xn0FcJ9FaEaJvC+jQSADatSBqweSXn
2Qep7w537sZAZmOxdeddxyrhCfCrC1lmVlfTj1OI+COEbfliTZSo2w9rKx8cq0tk
Q7K/dfJwpDaFlTTt7GatF5FinnHXPdtby+IIpwIDAQABAoIBAAJK4RDmPooqTJrC
JA41MJLo+5uvjwCT9QZmVKAQHzByUFw1YNJkITTiognUI0CdzqNzmH7jIFs39ZeG
proKusO2G6xQjrNcZ4cV2fgyb5g4QHStl0qhs94A+WojduiGm2IaumAgm6Mc5wDv
ld6HmknN3Mku/ZCyanVFEIjOVn2WB7ZQLTBs6ZYaebTJG2Xv6p9t2YJW7pPQ9Xce
s9ohAWohyM4X/OvfnfnLtQp2YLw/BxwehBsCR5SXM3ibTKpFNtxJC8hIfTuWtxZu
2ywrmXShYBRB1WgtZt5k04bY/HFncvvcHK3YfI1+w4URKtwdaQgPUQRbVwDwuyBn
flfkCJECgYEA/eWt01iEyE/lXkGn6V9lCocUU7lCU6yk5UT8VXVUc5If4KZKPfCk
p4zJDOqwn2eM673aWz/mG9mtvAvmnugaGjcaVCyXOp/D/GDmKSoYcvW5B/yjfkLy
dK6Yaa5LDRVYlYgyzcdCT5/9Qc626NzFwKCZNI4ncIU8g7ViATRxWJ8CgYEA2Ver
vZ0M606sfgC0H3NtwNBxmuJ+lIF5LNp/wDi07lDfxRR1rnZMX5dnxjcpDr/zvm8J
WtJJX3xMgqjtHuWKL3yKKony9J5ZPjichSbSbhrzfovgYIRZLxLLDy4MP9L3+CX/
yBXnqMWuSnFX+M5fVGxdDWiYF3V+wmeOv9JvavkCgYEAiXAPDFzaY+R78O3xiu7M
r0o3wqqCMPE/wav6O/hrYrQy9VSO08C0IM6g9pEEUwWmzuXSkZqhYWoQFb8Lc/GI
T7CMXAxXQLDDUpbRgG79FR3Wr3AewHZU8LyiXHKwxcBMV4WGmsXGK3wbh8fyU1NO
6NsGk+BvkQVOoK1LBAPzZ1kCgYEAsBSmD8U33T9s4dxiEYTrqyV0lH3g/SFz8ZHH
pAyNEPI2iC1ONhyjPWKlcWHpAokiyOqeUpVBWnmSZtzC1qAydsxYB6ShT+sl9BHb
RMix/QAauzBJhQhUVJ3OIys0Q1UBDmqCsjCE8SfOT4NKOUnA093C+YT+iyrmmktZ
zDCJkckCgYEAndqM5KXGk5xYo+MAA1paZcbTUXwaWwjLU+XSRSSoyBEi5xMtfvUb
7+a1OMhLwWbuz+pl64wFKrbSUyimMOYQpjVE/1vk/kb99pxbgol27hdKyTH1d+ov
kFsxKCqxAnBVGEWAvVZAiiTOxleQFjz5RnL0BQp9Lg2cQe+dvuUmIAA=
-----END RSA PRIVATE KEY-----`)
1237