1package dns
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"io"
8	"net"
9	"runtime"
10	"strings"
11	"sync"
12	"sync/atomic"
13	"testing"
14	"time"
15
16	"golang.org/x/sync/errgroup"
17)
18
19func HelloServer(w ResponseWriter, req *Msg) {
20	m := new(Msg)
21	m.SetReply(req)
22
23	m.Extra = make([]RR, 1)
24	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
25	w.WriteMsg(m)
26}
27
28func HelloServerBadID(w ResponseWriter, req *Msg) {
29	m := new(Msg)
30	m.SetReply(req)
31	m.Id++
32
33	m.Extra = make([]RR, 1)
34	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
35	w.WriteMsg(m)
36}
37
38func HelloServerEchoAddrPort(w ResponseWriter, req *Msg) {
39	m := new(Msg)
40	m.SetReply(req)
41
42	remoteAddr := w.RemoteAddr().String()
43	m.Extra = make([]RR, 1)
44	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{remoteAddr}}
45	w.WriteMsg(m)
46}
47
48func AnotherHelloServer(w ResponseWriter, req *Msg) {
49	m := new(Msg)
50	m.SetReply(req)
51
52	m.Extra = make([]RR, 1)
53	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello example"}}
54	w.WriteMsg(m)
55}
56
57func RunLocalUDPServer(laddr string) (*Server, string, error) {
58	server, l, _, err := RunLocalUDPServerWithFinChan(laddr)
59
60	return server, l, err
61}
62
63func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
64	pc, err := net.ListenPacket("udp", laddr)
65	if err != nil {
66		return nil, "", nil, err
67	}
68	server := &Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
69
70	waitLock := sync.Mutex{}
71	waitLock.Lock()
72	server.NotifyStartedFunc = waitLock.Unlock
73
74	// fin must be buffered so the goroutine below won't block
75	// forever if fin is never read from. This always happens
76	// in RunLocalUDPServer and can happen in TestShutdownUDP.
77	fin := make(chan error, 1)
78
79	for _, opt := range opts {
80		opt(server)
81	}
82
83	go func() {
84		fin <- server.ActivateAndServe()
85		pc.Close()
86	}()
87
88	waitLock.Lock()
89	return server, pc.LocalAddr().String(), fin, nil
90}
91
92func RunLocalTCPServer(laddr string) (*Server, string, error) {
93	server, l, _, err := RunLocalTCPServerWithFinChan(laddr)
94
95	return server, l, err
96}
97
98func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, error) {
99	l, err := net.Listen("tcp", laddr)
100	if err != nil {
101		return nil, "", nil, err
102	}
103
104	server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
105
106	waitLock := sync.Mutex{}
107	waitLock.Lock()
108	server.NotifyStartedFunc = waitLock.Unlock
109
110	// See the comment in RunLocalUDPServerWithFinChan as to
111	// why fin must be buffered.
112	fin := make(chan error, 1)
113
114	go func() {
115		fin <- server.ActivateAndServe()
116		l.Close()
117	}()
118
119	waitLock.Lock()
120	return server, l.Addr().String(), fin, nil
121}
122
123func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, error) {
124	l, err := tls.Listen("tcp", laddr, config)
125	if err != nil {
126		return nil, "", err
127	}
128
129	server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
130
131	waitLock := sync.Mutex{}
132	waitLock.Lock()
133	server.NotifyStartedFunc = waitLock.Unlock
134
135	go func() {
136		server.ActivateAndServe()
137		l.Close()
138	}()
139
140	waitLock.Lock()
141	return server, l.Addr().String(), nil
142}
143
144func TestServing(t *testing.T) {
145	HandleFunc("miek.nl.", HelloServer)
146	HandleFunc("example.com.", AnotherHelloServer)
147	defer HandleRemove("miek.nl.")
148	defer HandleRemove("example.com.")
149
150	s, addrstr, err := RunLocalUDPServer(":0")
151	if err != nil {
152		t.Fatalf("unable to run test server: %v", err)
153	}
154	defer s.Shutdown()
155
156	c := new(Client)
157	m := new(Msg)
158	m.SetQuestion("miek.nl.", TypeTXT)
159	r, _, err := c.Exchange(m, addrstr)
160	if err != nil || len(r.Extra) == 0 {
161		t.Fatal("failed to exchange miek.nl", err)
162	}
163	txt := r.Extra[0].(*TXT).Txt[0]
164	if txt != "Hello world" {
165		t.Error("unexpected result for miek.nl", txt, "!= Hello world")
166	}
167
168	m.SetQuestion("example.com.", TypeTXT)
169	r, _, err = c.Exchange(m, addrstr)
170	if err != nil {
171		t.Fatal("failed to exchange example.com", err)
172	}
173	txt = r.Extra[0].(*TXT).Txt[0]
174	if txt != "Hello example" {
175		t.Error("unexpected result for example.com", txt, "!= Hello example")
176	}
177
178	// Test Mixes cased as noticed by Ask.
179	m.SetQuestion("eXaMplE.cOm.", TypeTXT)
180	r, _, err = c.Exchange(m, addrstr)
181	if err != nil {
182		t.Error("failed to exchange eXaMplE.cOm", err)
183	}
184	txt = r.Extra[0].(*TXT).Txt[0]
185	if txt != "Hello example" {
186		t.Error("unexpected result for example.com", txt, "!= Hello example")
187	}
188}
189
190// Verify that the server responds to a query with Z flag on, ignoring the flag, and does not echoes it back
191func TestServeIgnoresZFlag(t *testing.T) {
192	HandleFunc("example.com.", AnotherHelloServer)
193
194	s, addrstr, err := RunLocalUDPServer(":0")
195	if err != nil {
196		t.Fatalf("unable to run test server: %v", err)
197	}
198	defer s.Shutdown()
199
200	c := new(Client)
201	m := new(Msg)
202
203	// Test the Z flag is not echoed
204	m.SetQuestion("example.com.", TypeTXT)
205	m.Zero = true
206	r, _, err := c.Exchange(m, addrstr)
207	if err != nil {
208		t.Fatal("failed to exchange example.com with +zflag", err)
209	}
210	if r.Zero {
211		t.Error("the response should not have Z flag set - even for a query which does")
212	}
213	if r.Rcode != RcodeSuccess {
214		t.Errorf("expected rcode %v, got %v", RcodeSuccess, r.Rcode)
215	}
216}
217
218// Verify that the server responds to a query with unsupported Opcode with a NotImplemented error and that Opcode is unchanged.
219func TestServeNotImplemented(t *testing.T) {
220	HandleFunc("example.com.", AnotherHelloServer)
221	opcode := 15
222
223	s, addrstr, err := RunLocalUDPServer(":0")
224	if err != nil {
225		t.Fatalf("unable to run test server: %v", err)
226	}
227	defer s.Shutdown()
228
229	c := new(Client)
230	m := new(Msg)
231
232	// Test that Opcode is like the unchanged from request Opcode and that Rcode is set to NotImplemnented
233	m.SetQuestion("example.com.", TypeTXT)
234	m.Opcode = opcode
235	r, _, err := c.Exchange(m, addrstr)
236	if err != nil {
237		t.Fatal("failed to exchange example.com with +zflag", err)
238	}
239	if r.Opcode != opcode {
240		t.Errorf("expected opcode %v, got %v", opcode, r.Opcode)
241	}
242	if r.Rcode != RcodeNotImplemented {
243		t.Errorf("expected rcode %v, got %v", RcodeNotImplemented, r.Rcode)
244	}
245}
246
247func TestServingTLS(t *testing.T) {
248	HandleFunc("miek.nl.", HelloServer)
249	HandleFunc("example.com.", AnotherHelloServer)
250	defer HandleRemove("miek.nl.")
251	defer HandleRemove("example.com.")
252
253	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
254	if err != nil {
255		t.Fatalf("unable to build certificate: %v", err)
256	}
257
258	config := tls.Config{
259		Certificates: []tls.Certificate{cert},
260	}
261
262	s, addrstr, err := RunLocalTLSServer(":0", &config)
263	if err != nil {
264		t.Fatalf("unable to run test server: %v", err)
265	}
266	defer s.Shutdown()
267
268	c := new(Client)
269	c.Net = "tcp-tls"
270	c.TLSConfig = &tls.Config{
271		InsecureSkipVerify: true,
272	}
273
274	m := new(Msg)
275	m.SetQuestion("miek.nl.", TypeTXT)
276	r, _, err := c.Exchange(m, addrstr)
277	if err != nil || len(r.Extra) == 0 {
278		t.Fatal("failed to exchange miek.nl", err)
279	}
280	txt := r.Extra[0].(*TXT).Txt[0]
281	if txt != "Hello world" {
282		t.Error("unexpected result for miek.nl", txt, "!= Hello world")
283	}
284
285	m.SetQuestion("example.com.", TypeTXT)
286	r, _, err = c.Exchange(m, addrstr)
287	if err != nil {
288		t.Fatal("failed to exchange example.com", err)
289	}
290	txt = r.Extra[0].(*TXT).Txt[0]
291	if txt != "Hello example" {
292		t.Error("unexpected result for example.com", txt, "!= Hello example")
293	}
294
295	// Test Mixes cased as noticed by Ask.
296	m.SetQuestion("eXaMplE.cOm.", TypeTXT)
297	r, _, err = c.Exchange(m, addrstr)
298	if err != nil {
299		t.Error("failed to exchange eXaMplE.cOm", err)
300	}
301	txt = r.Extra[0].(*TXT).Txt[0]
302	if txt != "Hello example" {
303		t.Error("unexpected result for example.com", txt, "!= Hello example")
304	}
305}
306
307// TestServingTLSConnectionState tests that we only can access
308// tls.ConnectionState under a DNS query handled by a TLS DNS server.
309// This test will sequentially create a TLS, UDP and TCP server, attach a custom
310// handler which will set a testing error if tls.ConnectionState is available
311// when it is not expected, or the other way around.
312func TestServingTLSConnectionState(t *testing.T) {
313	handlerResponse := "Hello example"
314	// tlsHandlerTLS is a HandlerFunc that can be set to expect or not TLS
315	// connection state.
316	tlsHandlerTLS := func(tlsExpected bool) func(ResponseWriter, *Msg) {
317		return func(w ResponseWriter, req *Msg) {
318			m := new(Msg)
319			m.SetReply(req)
320			tlsFound := true
321			if connState := w.(ConnectionStater).ConnectionState(); connState == nil {
322				tlsFound = false
323			}
324			if tlsFound != tlsExpected {
325				t.Errorf("TLS connection state available: %t, expected: %t", tlsFound, tlsExpected)
326			}
327			m.Extra = make([]RR, 1)
328			m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{handlerResponse}}
329			w.WriteMsg(m)
330		}
331	}
332
333	// Question used in tests
334	m := new(Msg)
335	m.SetQuestion("tlsstate.example.net.", TypeTXT)
336
337	// TLS DNS server
338	HandleFunc(".", tlsHandlerTLS(true))
339	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
340	if err != nil {
341		t.Fatalf("unable to build certificate: %v", err)
342	}
343
344	config := tls.Config{
345		Certificates: []tls.Certificate{cert},
346	}
347
348	s, addrstr, err := RunLocalTLSServer(":0", &config)
349	if err != nil {
350		t.Fatalf("unable to run test server: %v", err)
351	}
352	defer s.Shutdown()
353
354	// TLS DNS query
355	c := &Client{
356		Net: "tcp-tls",
357		TLSConfig: &tls.Config{
358			InsecureSkipVerify: true,
359		},
360	}
361
362	_, _, err = c.Exchange(m, addrstr)
363	if err != nil {
364		t.Error("failed to exchange tlsstate.example.net", err)
365	}
366
367	HandleRemove(".")
368	// UDP DNS Server
369	HandleFunc(".", tlsHandlerTLS(false))
370	defer HandleRemove(".")
371	s, addrstr, err = RunLocalUDPServer(":0")
372	if err != nil {
373		t.Fatalf("unable to run test server: %v", err)
374	}
375	defer s.Shutdown()
376
377	// UDP DNS query
378	c = new(Client)
379	_, _, err = c.Exchange(m, addrstr)
380	if err != nil {
381		t.Error("failed to exchange tlsstate.example.net", err)
382	}
383
384	// TCP DNS Server
385	s, addrstr, err = RunLocalTCPServer(":0")
386	if err != nil {
387		t.Fatalf("unable to run test server: %v", err)
388	}
389	defer s.Shutdown()
390
391	// TCP DNS query
392	c = &Client{Net: "tcp"}
393	_, _, err = c.Exchange(m, addrstr)
394	if err != nil {
395		t.Error("failed to exchange tlsstate.example.net", err)
396	}
397}
398
399func TestServingListenAndServe(t *testing.T) {
400	HandleFunc("example.com.", AnotherHelloServer)
401	defer HandleRemove("example.com.")
402
403	waitLock := sync.Mutex{}
404	server := &Server{Addr: ":0", Net: "udp", ReadTimeout: time.Hour, WriteTimeout: time.Hour, NotifyStartedFunc: waitLock.Unlock}
405	waitLock.Lock()
406
407	go func() {
408		server.ListenAndServe()
409	}()
410	waitLock.Lock()
411
412	c, m := new(Client), new(Msg)
413	m.SetQuestion("example.com.", TypeTXT)
414	addr := server.PacketConn.LocalAddr().String() // Get address via the PacketConn that gets set.
415	r, _, err := c.Exchange(m, addr)
416	if err != nil {
417		t.Fatal("failed to exchange example.com", err)
418	}
419	txt := r.Extra[0].(*TXT).Txt[0]
420	if txt != "Hello example" {
421		t.Error("unexpected result for example.com", txt, "!= Hello example")
422	}
423	server.Shutdown()
424}
425
426func TestServingListenAndServeTLS(t *testing.T) {
427	HandleFunc("example.com.", AnotherHelloServer)
428	defer HandleRemove("example.com.")
429
430	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
431	if err != nil {
432		t.Fatalf("unable to build certificate: %v", err)
433	}
434
435	config := &tls.Config{
436		Certificates: []tls.Certificate{cert},
437	}
438
439	waitLock := sync.Mutex{}
440	server := &Server{Addr: ":0", Net: "tcp", TLSConfig: config, ReadTimeout: time.Hour, WriteTimeout: time.Hour, NotifyStartedFunc: waitLock.Unlock}
441	waitLock.Lock()
442
443	go func() {
444		server.ListenAndServe()
445	}()
446	waitLock.Lock()
447
448	c, m := new(Client), new(Msg)
449	c.Net = "tcp"
450	m.SetQuestion("example.com.", TypeTXT)
451	addr := server.Listener.Addr().String() // Get address via the Listener that gets set.
452	r, _, err := c.Exchange(m, addr)
453	if err != nil {
454		t.Fatal(err)
455	}
456	txt := r.Extra[0].(*TXT).Txt[0]
457	if txt != "Hello example" {
458		t.Error("unexpected result for example.com", txt, "!= Hello example")
459	}
460	server.Shutdown()
461}
462
463func BenchmarkServe(b *testing.B) {
464	b.StopTimer()
465	HandleFunc("miek.nl.", HelloServer)
466	defer HandleRemove("miek.nl.")
467	a := runtime.GOMAXPROCS(4)
468
469	s, addrstr, err := RunLocalUDPServer(":0")
470	if err != nil {
471		b.Fatalf("unable to run test server: %v", err)
472	}
473	defer s.Shutdown()
474
475	c := new(Client)
476	m := new(Msg)
477	m.SetQuestion("miek.nl.", TypeSOA)
478
479	b.StartTimer()
480	for i := 0; i < b.N; i++ {
481		_, _, err := c.Exchange(m, addrstr)
482		if err != nil {
483			b.Fatalf("Exchange failed: %v", err)
484		}
485	}
486	runtime.GOMAXPROCS(a)
487}
488
489func BenchmarkServe6(b *testing.B) {
490	b.StopTimer()
491	HandleFunc("miek.nl.", HelloServer)
492	defer HandleRemove("miek.nl.")
493	a := runtime.GOMAXPROCS(4)
494	s, addrstr, err := RunLocalUDPServer("[::1]:0")
495	if err != nil {
496		if strings.Contains(err.Error(), "bind: cannot assign requested address") {
497			b.Skip("missing IPv6 support")
498		}
499		b.Fatalf("unable to run test server: %v", err)
500	}
501	defer s.Shutdown()
502
503	c := new(Client)
504	m := new(Msg)
505	m.SetQuestion("miek.nl.", TypeSOA)
506
507	b.StartTimer()
508	for i := 0; i < b.N; i++ {
509		_, _, err := c.Exchange(m, addrstr)
510		if err != nil {
511			b.Fatalf("Exchange failed: %v", err)
512		}
513	}
514	runtime.GOMAXPROCS(a)
515}
516
517func HelloServerCompress(w ResponseWriter, req *Msg) {
518	m := new(Msg)
519	m.SetReply(req)
520	m.Extra = make([]RR, 1)
521	m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
522	m.Compress = true
523	w.WriteMsg(m)
524}
525
526func BenchmarkServeCompress(b *testing.B) {
527	b.StopTimer()
528	HandleFunc("miek.nl.", HelloServerCompress)
529	defer HandleRemove("miek.nl.")
530	a := runtime.GOMAXPROCS(4)
531	s, addrstr, err := RunLocalUDPServer(":0")
532	if err != nil {
533		b.Fatalf("unable to run test server: %v", err)
534	}
535	defer s.Shutdown()
536
537	c := new(Client)
538	m := new(Msg)
539	m.SetQuestion("miek.nl.", TypeSOA)
540	b.StartTimer()
541	for i := 0; i < b.N; i++ {
542		_, _, err := c.Exchange(m, addrstr)
543		if err != nil {
544			b.Fatalf("Exchange failed: %v", err)
545		}
546	}
547	runtime.GOMAXPROCS(a)
548}
549
// maxRec holds the number of A records HelloServerLargeResponse puts in its
// answer. The embedded RWMutex guards max, because tests change it while
// the server's handler goroutine reads it.
type maxRec struct {
	max int
	sync.RWMutex
}

// M is the shared record-count setting read by HelloServerLargeResponse.
var M = &maxRec{}
556
557func HelloServerLargeResponse(resp ResponseWriter, req *Msg) {
558	m := new(Msg)
559	m.SetReply(req)
560	m.Authoritative = true
561	m1 := 0
562	M.RLock()
563	m1 = M.max
564	M.RUnlock()
565	for i := 0; i < m1; i++ {
566		aRec := &A{
567			Hdr: RR_Header{
568				Name:   req.Question[0].Name,
569				Rrtype: TypeA,
570				Class:  ClassINET,
571				Ttl:    0,
572			},
573			A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)).To4(),
574		}
575		m.Answer = append(m.Answer, aRec)
576	}
577	resp.WriteMsg(m)
578}
579
580func TestServingLargeResponses(t *testing.T) {
581	HandleFunc("example.", HelloServerLargeResponse)
582	defer HandleRemove("example.")
583
584	s, addrstr, err := RunLocalUDPServer(":0")
585	if err != nil {
586		t.Fatalf("unable to run test server: %v", err)
587	}
588	defer s.Shutdown()
589
590	// Create request
591	m := new(Msg)
592	m.SetQuestion("web.service.example.", TypeANY)
593
594	c := new(Client)
595	c.Net = "udp"
596	M.Lock()
597	M.max = 2
598	M.Unlock()
599	_, _, err = c.Exchange(m, addrstr)
600	if err != nil {
601		t.Errorf("failed to exchange: %v", err)
602	}
603	// This must fail
604	M.Lock()
605	M.max = 20
606	M.Unlock()
607	_, _, err = c.Exchange(m, addrstr)
608	if err == nil {
609		t.Error("failed to fail exchange, this should generate packet error")
610	}
611	// But this must work again
612	c.UDPSize = 7000
613	_, _, err = c.Exchange(m, addrstr)
614	if err != nil {
615		t.Errorf("failed to exchange: %v", err)
616	}
617}
618
619func TestServingResponse(t *testing.T) {
620	if testing.Short() {
621		t.Skip("skipping test in short mode.")
622	}
623	HandleFunc("miek.nl.", HelloServer)
624	s, addrstr, err := RunLocalUDPServer(":0")
625	if err != nil {
626		t.Fatalf("unable to run test server: %v", err)
627	}
628	defer s.Shutdown()
629
630	c := new(Client)
631	m := new(Msg)
632	m.SetQuestion("miek.nl.", TypeTXT)
633	m.Response = false
634	_, _, err = c.Exchange(m, addrstr)
635	if err != nil {
636		t.Fatal("failed to exchange", err)
637	}
638	m.Response = true
639	_, _, err = c.Exchange(m, addrstr)
640	if err == nil {
641		t.Fatal("exchanged response message")
642	}
643}
644
645func TestShutdownTCP(t *testing.T) {
646	s, _, fin, err := RunLocalTCPServerWithFinChan(":0")
647	if err != nil {
648		t.Fatalf("unable to run test server: %v", err)
649	}
650	err = s.Shutdown()
651	if err != nil {
652		t.Fatalf("could not shutdown test TCP server, %v", err)
653	}
654	select {
655	case err := <-fin:
656		if err != nil {
657			t.Errorf("error returned from ActivateAndServe, %v", err)
658		}
659	case <-time.After(2 * time.Second):
660		t.Error("could not shutdown test TCP server. Gave up waiting")
661	}
662}
663
664func init() {
665	testShutdownNotify = &sync.Cond{
666		L: new(sync.Mutex),
667	}
668}
669
670func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr string, client *Client) {
671	const requests = 100
672
673	var errOnce sync.Once
674	// t.Fail will panic if it's called after the test function has
675	// finished. Burning the sync.Once with a defer will prevent the
676	// handler from calling t.Errorf after we've returned.
677	defer errOnce.Do(func() {})
678
679	toHandle := int32(requests)
680	HandleFunc("example.com.", func(w ResponseWriter, req *Msg) {
681		defer atomic.AddInt32(&toHandle, -1)
682
683		// Wait until ShutdownContext is called before replying.
684		testShutdownNotify.L.Lock()
685		testShutdownNotify.Wait()
686		testShutdownNotify.L.Unlock()
687
688		m := new(Msg)
689		m.SetReply(req)
690		m.Extra = make([]RR, 1)
691		m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}}
692
693		if err := w.WriteMsg(m); err != nil {
694			errOnce.Do(func() {
695				t.Errorf("ResponseWriter.WriteMsg error: %s", err)
696			})
697		}
698	})
699	defer HandleRemove("example.com.")
700
701	client.Timeout = 10 * time.Second
702
703	conns := make([]*Conn, requests)
704	eg := new(errgroup.Group)
705
706	for i := range conns {
707		conn := &conns[i]
708		eg.Go(func() error {
709			var err error
710			*conn, err = client.Dial(addr)
711			return err
712		})
713	}
714
715	if eg.Wait() != nil {
716		t.Fatalf("client.Dial error: %v", eg.Wait())
717	}
718
719	m := new(Msg)
720	m.SetQuestion("example.com.", TypeTXT)
721	eg = new(errgroup.Group)
722
723	for _, conn := range conns {
724		conn := conn
725		eg.Go(func() error {
726			conn.SetWriteDeadline(time.Now().Add(client.Timeout))
727
728			return conn.WriteMsg(m)
729		})
730	}
731
732	if eg.Wait() != nil {
733		t.Fatalf("conn.WriteMsg error: %v", eg.Wait())
734	}
735
736	// This sleep is needed to allow time for the requests to
737	// pass from the client through the kernel and back into
738	// the server. Without it, some requests may still be in
739	// the kernel's buffer when ShutdownContext is called.
740	time.Sleep(100 * time.Millisecond)
741
742	eg = new(errgroup.Group)
743
744	for _, conn := range conns {
745		conn := conn
746		eg.Go(func() error {
747			conn.SetReadDeadline(time.Now().Add(client.Timeout))
748
749			_, err := conn.ReadMsg()
750			return err
751		})
752	}
753
754	ctx, cancel := context.WithTimeout(context.Background(), client.Timeout)
755	defer cancel()
756
757	if err := srv.ShutdownContext(ctx); err != nil {
758		t.Errorf("could not shutdown test server: %v", err)
759	}
760
761	if left := atomic.LoadInt32(&toHandle); left != 0 {
762		t.Errorf("ShutdownContext returned before %d replies", left)
763	}
764
765	if eg.Wait() != nil {
766		t.Errorf("conn.ReadMsg error: %v", eg.Wait())
767	}
768
769	srv.lock.RLock()
770	defer srv.lock.RUnlock()
771	if len(srv.conns) != 0 {
772		t.Errorf("TCP connection tracking map not empty after ShutdownContext; map still contains %d connections", len(srv.conns))
773	}
774}
775
776func TestInProgressQueriesAtShutdownTCP(t *testing.T) {
777	s, addr, _, err := RunLocalTCPServerWithFinChan(":0")
778	if err != nil {
779		t.Fatalf("unable to run test server: %v", err)
780	}
781
782	c := &Client{Net: "tcp"}
783	checkInProgressQueriesAtShutdownServer(t, s, addr, c)
784}
785
786func TestShutdownTLS(t *testing.T) {
787	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
788	if err != nil {
789		t.Fatalf("unable to build certificate: %v", err)
790	}
791
792	config := tls.Config{
793		Certificates: []tls.Certificate{cert},
794	}
795
796	s, _, err := RunLocalTLSServer(":0", &config)
797	if err != nil {
798		t.Fatalf("unable to run test server: %v", err)
799	}
800	err = s.Shutdown()
801	if err != nil {
802		t.Errorf("could not shutdown test TLS server, %v", err)
803	}
804}
805
806func TestInProgressQueriesAtShutdownTLS(t *testing.T) {
807	cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
808	if err != nil {
809		t.Fatalf("unable to build certificate: %v", err)
810	}
811
812	config := tls.Config{
813		Certificates: []tls.Certificate{cert},
814	}
815
816	s, addr, err := RunLocalTLSServer(":0", &config)
817	if err != nil {
818		t.Fatalf("unable to run test server: %v", err)
819	}
820
821	c := &Client{
822		Net: "tcp-tls",
823		TLSConfig: &tls.Config{
824			InsecureSkipVerify: true,
825		},
826	}
827	checkInProgressQueriesAtShutdownServer(t, s, addr, c)
828}
829
830func TestHandlerCloseTCP(t *testing.T) {
831
832	ln, err := net.Listen("tcp", ":0")
833	if err != nil {
834		panic(err)
835	}
836	addr := ln.Addr().String()
837
838	server := &Server{Addr: addr, Net: "tcp", Listener: ln}
839
840	hname := "testhandlerclosetcp."
841	triggered := make(chan struct{})
842	HandleFunc(hname, func(w ResponseWriter, r *Msg) {
843		close(triggered)
844		w.Close()
845	})
846	defer HandleRemove(hname)
847
848	go func() {
849		defer server.Shutdown()
850		c := &Client{Net: "tcp"}
851		m := new(Msg).SetQuestion(hname, 1)
852		tries := 0
853	exchange:
854		_, _, err := c.Exchange(m, addr)
855		if err != nil && err != io.EOF {
856			t.Errorf("exchange failed: %v", err)
857			if tries == 3 {
858				return
859			}
860			time.Sleep(time.Second / 10)
861			tries++
862			goto exchange
863		}
864	}()
865	if err := server.ActivateAndServe(); err != nil {
866		t.Fatalf("ActivateAndServe failed: %v", err)
867	}
868	select {
869	case <-triggered:
870	default:
871		t.Fatalf("handler never called")
872	}
873}
874
875func TestShutdownUDP(t *testing.T) {
876	s, _, fin, err := RunLocalUDPServerWithFinChan(":0")
877	if err != nil {
878		t.Fatalf("unable to run test server: %v", err)
879	}
880	err = s.Shutdown()
881	if err != nil {
882		t.Errorf("could not shutdown test UDP server, %v", err)
883	}
884	select {
885	case err := <-fin:
886		if err != nil {
887			t.Errorf("error returned from ActivateAndServe, %v", err)
888		}
889	case <-time.After(2 * time.Second):
890		t.Error("could not shutdown test UDP server. Gave up waiting")
891	}
892}
893
894func TestInProgressQueriesAtShutdownUDP(t *testing.T) {
895	s, addr, _, err := RunLocalUDPServerWithFinChan(":0")
896	if err != nil {
897		t.Fatalf("unable to run test server: %v", err)
898	}
899
900	c := &Client{Net: "udp"}
901	checkInProgressQueriesAtShutdownServer(t, s, addr, c)
902}
903
904func TestServerStartStopRace(t *testing.T) {
905	var wg sync.WaitGroup
906	for i := 0; i < 10; i++ {
907		wg.Add(1)
908		s, _, _, err := RunLocalUDPServerWithFinChan(":0")
909		if err != nil {
910			t.Fatalf("could not start server: %s", err)
911		}
912		go func() {
913			defer wg.Done()
914			if err := s.Shutdown(); err != nil {
915				t.Errorf("could not stop server: %s", err)
916			}
917		}()
918	}
919	wg.Wait()
920}
921
922func TestServerReuseport(t *testing.T) {
923	if !supportsReusePort {
924		t.Skip("reuseport is not supported")
925	}
926
927	startServer := func(addr string) (*Server, chan error) {
928		wait := make(chan struct{})
929		srv := &Server{
930			Net:               "udp",
931			Addr:              addr,
932			NotifyStartedFunc: func() { close(wait) },
933			ReusePort:         true,
934		}
935
936		fin := make(chan error, 1)
937		go func() {
938			fin <- srv.ListenAndServe()
939		}()
940
941		select {
942		case <-wait:
943		case err := <-fin:
944			t.Fatalf("failed to start server: %v", err)
945		}
946
947		return srv, fin
948	}
949
950	srv1, fin1 := startServer(":0") // :0 is resolved to a random free port by the kernel
951	srv2, fin2 := startServer(srv1.PacketConn.LocalAddr().String())
952
953	if err := srv1.Shutdown(); err != nil {
954		t.Fatalf("failed to shutdown first server: %v", err)
955	}
956	if err := srv2.Shutdown(); err != nil {
957		t.Fatalf("failed to shutdown second server: %v", err)
958	}
959
960	if err := <-fin1; err != nil {
961		t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
962	}
963	if err := <-fin2; err != nil {
964		t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err)
965	}
966}
967
968func TestServerRoundtripTsig(t *testing.T) {
969	secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}
970
971	s, addrstr, _, err := RunLocalUDPServerWithFinChan(":0", func(srv *Server) {
972		srv.TsigSecret = secret
973	})
974	if err != nil {
975		t.Fatalf("unable to run test server: %v", err)
976	}
977	defer s.Shutdown()
978
979	HandleFunc("example.com.", func(w ResponseWriter, r *Msg) {
980		m := new(Msg)
981		m.SetReply(r)
982		if r.IsTsig() != nil {
983			status := w.TsigStatus()
984			if status == nil {
985				// *Msg r has an TSIG record and it was validated
986				m.SetTsig("test.", HmacMD5, 300, time.Now().Unix())
987			} else {
988				// *Msg r has an TSIG records and it was not valided
989				t.Errorf("invalid TSIG: %v", status)
990			}
991		} else {
992			t.Error("missing TSIG")
993		}
994		w.WriteMsg(m)
995	})
996
997	c := new(Client)
998	m := new(Msg)
999	m.Opcode = OpcodeUpdate
1000	m.SetQuestion("example.com.", TypeSOA)
1001	m.Ns = []RR{&CNAME{
1002		Hdr: RR_Header{
1003			Name:   "foo.example.com.",
1004			Rrtype: TypeCNAME,
1005			Class:  ClassINET,
1006			Ttl:    300,
1007		},
1008		Target: "bar.example.com.",
1009	}}
1010	c.TsigSecret = secret
1011	m.SetTsig("test.", HmacMD5, 300, time.Now().Unix())
1012	_, _, err = c.Exchange(m, addrstr)
1013	if err != nil {
1014		t.Fatal("failed to exchange", err)
1015	}
1016}
1017
1018func TestResponseAfterClose(t *testing.T) {
1019	testError := func(name string, err error) {
1020		t.Helper()
1021
1022		expect := fmt.Sprintf("dns: %s called after Close", name)
1023		if err == nil {
1024			t.Errorf("expected error from %s after Close", name)
1025		} else if err.Error() != expect {
1026			t.Errorf("expected explicit error from %s after Close, expected %q, got %q", name, expect, err)
1027		}
1028	}
1029
1030	rw := &response{
1031		closed: true,
1032	}
1033
1034	_, err := rw.Write(make([]byte, 2))
1035	testError("Write", err)
1036
1037	testError("WriteMsg", rw.WriteMsg(new(Msg)))
1038}
1039
1040func TestResponseDoubleClose(t *testing.T) {
1041	rw := &response{
1042		closed: true,
1043	}
1044	if err, expect := rw.Close(), "dns: connection already closed"; err == nil || err.Error() != expect {
1045		t.Errorf("Close did not return expected: error %q, got: %v", expect, err)
1046	}
1047}
1048
1049type ExampleFrameLengthWriter struct {
1050	Writer
1051}
1052
1053func (e *ExampleFrameLengthWriter) Write(m []byte) (int, error) {
1054	fmt.Println("writing raw DNS message of length", len(m))
1055	return e.Writer.Write(m)
1056}
1057
1058func ExampleDecorateWriter() {
1059	// instrument raw DNS message writing
1060	wf := DecorateWriter(func(w Writer) Writer {
1061		return &ExampleFrameLengthWriter{w}
1062	})
1063
1064	// simple UDP server
1065	pc, err := net.ListenPacket("udp", ":0")
1066	if err != nil {
1067		fmt.Println(err.Error())
1068		return
1069	}
1070	server := &Server{
1071		PacketConn:     pc,
1072		DecorateWriter: wf,
1073		ReadTimeout:    time.Hour, WriteTimeout: time.Hour,
1074	}
1075
1076	waitLock := sync.Mutex{}
1077	waitLock.Lock()
1078	server.NotifyStartedFunc = waitLock.Unlock
1079	defer server.Shutdown()
1080
1081	go func() {
1082		server.ActivateAndServe()
1083		pc.Close()
1084	}()
1085
1086	waitLock.Lock()
1087
1088	HandleFunc("miek.nl.", HelloServer)
1089
1090	c := new(Client)
1091	m := new(Msg)
1092	m.SetQuestion("miek.nl.", TypeTXT)
1093	_, _, err = c.Exchange(m, pc.LocalAddr().String())
1094	if err != nil {
1095		fmt.Println("failed to exchange", err.Error())
1096		return
1097	}
1098	// Output: writing raw DNS message of length 56
1099}
1100
// Test-only TLS credentials shared by the TLS server tests in this file.
// NOTE(review): decoding the certificate suggests it is self-signed for
// "Acme Co", has SANs localhost and 127.0.0.1, and was valid only from
// 2016-01-08 to 2017-01-07. Clients using it therefore need
// InsecureSkipVerify, as the tests here set. Regenerate it if verification
// is ever required.
var (
	// CertPEMBlock is a X509 data used to test TLS servers (used with tls.X509KeyPair)
	CertPEMBlock = []byte(`-----BEGIN CERTIFICATE-----
MIIDAzCCAeugAwIBAgIRAJFYMkcn+b8dpU15wjf++GgwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjAxMDgxMjAzNTNaFw0xNzAxMDcxMjAz
NTNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
ggEKAoIBAQDXjqO6skvP03k58CNjQggd9G/mt+Wa+xRU+WXiKCCHttawM8x+slq5
yfsHCwxlwsGn79HmJqecNqgHb2GWBXAvVVokFDTcC1hUP4+gp2gu9Ny27UHTjlLm
O0l/xZ5MN8tfKyYlFw18tXu3fkaPyHj8v/D1RDkuo4ARdFvGSe8TqisbhLk2+9ow
xfIGbEM9Fdiw8qByC2+d+FfvzIKz3GfQVwn0VoRom8L6NBIANq1IGrB5JefZB6nv
DnfuxkBmY7F1513HKuEJ8KsLWWZWV9OPU4j4I4Rt+WJNlKjbD2srHxyrS2RDsr91
8nCkNoWVNO3sZq0XkWKecdc921vL4ginAgMBAAGjVDBSMA4GA1UdDwEB/wQEAwIC
pDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MBoGA1UdEQQT
MBGCCWxvY2FsaG9zdIcEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAGcU3iyLBIVZj
aDzSvEDHUd1bnLBl1C58Xu/CyKlPqVU7mLfK0JcgEaYQTSX6fCJVNLbbCrcGLsPJ
fbjlBbyeLjTV413fxPVuona62pBFjqdtbli2Qe8FRH2KBdm41JUJGdo+SdsFu7nc
BFOcubdw6LLIXvsTvwndKcHWx1rMX709QU1Vn1GAIsbJV/DWI231Jyyb+lxAUx/C
8vce5uVxiKcGS+g6OjsN3D3TtiEQGSXLh013W6Wsih8td8yMCMZ3w8LQ38br1GUe
ahLIgUJ9l6HDguM17R7kGqxNvbElsMUHfTtXXP7UDQUiYXDakg8xDP6n9DCDhJ8Y
bSt7OLB7NQ==
-----END CERTIFICATE-----`)

	// KeyPEMBlock is a X509 data used to test TLS servers (used with tls.X509KeyPair)
	KeyPEMBlock = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEA146jurJLz9N5OfAjY0IIHfRv5rflmvsUVPll4iggh7bWsDPM
frJaucn7BwsMZcLBp+/R5iannDaoB29hlgVwL1VaJBQ03AtYVD+PoKdoLvTctu1B
045S5jtJf8WeTDfLXysmJRcNfLV7t35Gj8h4/L/w9UQ5LqOAEXRbxknvE6orG4S5
NvvaMMXyBmxDPRXYsPKgcgtvnfhX78yCs9xn0FcJ9FaEaJvC+jQSADatSBqweSXn
2Qep7w537sZAZmOxdeddxyrhCfCrC1lmVlfTj1OI+COEbfliTZSo2w9rKx8cq0tk
Q7K/dfJwpDaFlTTt7GatF5FinnHXPdtby+IIpwIDAQABAoIBAAJK4RDmPooqTJrC
JA41MJLo+5uvjwCT9QZmVKAQHzByUFw1YNJkITTiognUI0CdzqNzmH7jIFs39ZeG
proKusO2G6xQjrNcZ4cV2fgyb5g4QHStl0qhs94A+WojduiGm2IaumAgm6Mc5wDv
ld6HmknN3Mku/ZCyanVFEIjOVn2WB7ZQLTBs6ZYaebTJG2Xv6p9t2YJW7pPQ9Xce
s9ohAWohyM4X/OvfnfnLtQp2YLw/BxwehBsCR5SXM3ibTKpFNtxJC8hIfTuWtxZu
2ywrmXShYBRB1WgtZt5k04bY/HFncvvcHK3YfI1+w4URKtwdaQgPUQRbVwDwuyBn
flfkCJECgYEA/eWt01iEyE/lXkGn6V9lCocUU7lCU6yk5UT8VXVUc5If4KZKPfCk
p4zJDOqwn2eM673aWz/mG9mtvAvmnugaGjcaVCyXOp/D/GDmKSoYcvW5B/yjfkLy
dK6Yaa5LDRVYlYgyzcdCT5/9Qc626NzFwKCZNI4ncIU8g7ViATRxWJ8CgYEA2Ver
vZ0M606sfgC0H3NtwNBxmuJ+lIF5LNp/wDi07lDfxRR1rnZMX5dnxjcpDr/zvm8J
WtJJX3xMgqjtHuWKL3yKKony9J5ZPjichSbSbhrzfovgYIRZLxLLDy4MP9L3+CX/
yBXnqMWuSnFX+M5fVGxdDWiYF3V+wmeOv9JvavkCgYEAiXAPDFzaY+R78O3xiu7M
r0o3wqqCMPE/wav6O/hrYrQy9VSO08C0IM6g9pEEUwWmzuXSkZqhYWoQFb8Lc/GI
T7CMXAxXQLDDUpbRgG79FR3Wr3AewHZU8LyiXHKwxcBMV4WGmsXGK3wbh8fyU1NO
6NsGk+BvkQVOoK1LBAPzZ1kCgYEAsBSmD8U33T9s4dxiEYTrqyV0lH3g/SFz8ZHH
pAyNEPI2iC1ONhyjPWKlcWHpAokiyOqeUpVBWnmSZtzC1qAydsxYB6ShT+sl9BHb
RMix/QAauzBJhQhUVJ3OIys0Q1UBDmqCsjCE8SfOT4NKOUnA093C+YT+iyrmmktZ
zDCJkckCgYEAndqM5KXGk5xYo+MAA1paZcbTUXwaWwjLU+XSRSSoyBEi5xMtfvUb
7+a1OMhLwWbuz+pl64wFKrbSUyimMOYQpjVE/1vk/kb99pxbgol27hdKyTH1d+ov
kFsxKCqxAnBVGEWAvVZAiiTOxleQFjz5RnL0BQp9Lg2cQe+dvuUmIAA=
-----END RSA PRIVATE KEY-----`)
)
1152