1package dns 2 3import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 "net" 9 "runtime" 10 "strings" 11 "sync" 12 "sync/atomic" 13 "testing" 14 "time" 15 16 "golang.org/x/sync/errgroup" 17) 18 19func HelloServer(w ResponseWriter, req *Msg) { 20 m := new(Msg) 21 m.SetReply(req) 22 23 m.Extra = make([]RR, 1) 24 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 25 w.WriteMsg(m) 26} 27 28func HelloServerBadID(w ResponseWriter, req *Msg) { 29 m := new(Msg) 30 m.SetReply(req) 31 m.Id++ 32 33 m.Extra = make([]RR, 1) 34 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 35 w.WriteMsg(m) 36} 37 38func HelloServerEchoAddrPort(w ResponseWriter, req *Msg) { 39 m := new(Msg) 40 m.SetReply(req) 41 42 remoteAddr := w.RemoteAddr().String() 43 m.Extra = make([]RR, 1) 44 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{remoteAddr}} 45 w.WriteMsg(m) 46} 47 48func AnotherHelloServer(w ResponseWriter, req *Msg) { 49 m := new(Msg) 50 m.SetReply(req) 51 52 m.Extra = make([]RR, 1) 53 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello example"}} 54 w.WriteMsg(m) 55} 56 57func RunLocalUDPServer(laddr string) (*Server, string, error) { 58 server, l, _, err := RunLocalUDPServerWithFinChan(laddr) 59 60 return server, l, err 61} 62 63func RunLocalUDPServerWithFinChan(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { 64 pc, err := net.ListenPacket("udp", laddr) 65 if err != nil { 66 return nil, "", nil, err 67 } 68 server := &Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour} 69 70 waitLock := sync.Mutex{} 71 waitLock.Lock() 72 server.NotifyStartedFunc = waitLock.Unlock 73 74 // fin must be buffered so the goroutine below won't block 75 // forever if fin is never read from. This always happens 76 // in RunLocalUDPServer and can happen in TestShutdownUDP. 77 fin := make(chan error, 1) 78 79 for _, opt := range opts { 80 opt(server) 81 } 82 83 go func() { 84 fin <- server.ActivateAndServe() 85 pc.Close() 86 }() 87 88 waitLock.Lock() 89 return server, pc.LocalAddr().String(), fin, nil 90} 91 92func RunLocalTCPServer(laddr string) (*Server, string, error) { 93 server, l, _, err := RunLocalTCPServerWithFinChan(laddr) 94 95 return server, l, err 96} 97 98func RunLocalTCPServerWithFinChan(laddr string) (*Server, string, chan error, error) { 99 l, err := net.Listen("tcp", laddr) 100 if err != nil { 101 return nil, "", nil, err 102 } 103 104 server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} 105 106 waitLock := sync.Mutex{} 107 waitLock.Lock() 108 server.NotifyStartedFunc = waitLock.Unlock 109 110 // See the comment in RunLocalUDPServerWithFinChan as to 111 // why fin must be buffered. 112 fin := make(chan error, 1) 113 114 go func() { 115 fin <- server.ActivateAndServe() 116 l.Close() 117 }() 118 119 waitLock.Lock() 120 return server, l.Addr().String(), fin, nil 121} 122 123func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, error) { 124 l, err := tls.Listen("tcp", laddr, config) 125 if err != nil { 126 return nil, "", err 127 } 128 129 server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} 130 131 waitLock := sync.Mutex{} 132 waitLock.Lock() 133 server.NotifyStartedFunc = waitLock.Unlock 134 135 go func() { 136 server.ActivateAndServe() 137 l.Close() 138 }() 139 140 waitLock.Lock() 141 return server, l.Addr().String(), nil 142} 143 144func TestServing(t *testing.T) { 145 HandleFunc("miek.nl.", HelloServer) 146 HandleFunc("example.com.", AnotherHelloServer) 147 defer HandleRemove("miek.nl.") 148 defer HandleRemove("example.com.") 149 150 s, addrstr, err := RunLocalUDPServer(":0") 151 if err != nil { 152 t.Fatalf("unable to run test server: %v", err) 153 } 154 defer s.Shutdown() 155 156 c := new(Client) 157 m := new(Msg) 158 m.SetQuestion("miek.nl.", TypeTXT) 159 r, _, err := c.Exchange(m, addrstr) 160 if err != nil || len(r.Extra) == 0 { 161 t.Fatal("failed to exchange miek.nl", err) 162 } 163 txt := r.Extra[0].(*TXT).Txt[0] 164 if txt != "Hello world" { 165 t.Error("unexpected result for miek.nl", txt, "!= Hello world") 166 } 167 168 m.SetQuestion("example.com.", TypeTXT) 169 r, _, err = c.Exchange(m, addrstr) 170 if err != nil { 171 t.Fatal("failed to exchange example.com", err) 172 } 173 txt = r.Extra[0].(*TXT).Txt[0] 174 if txt != "Hello example" { 175 t.Error("unexpected result for example.com", txt, "!= Hello example") 176 } 177 178 // Test Mixes cased as noticed by Ask. 179 m.SetQuestion("eXaMplE.cOm.", TypeTXT) 180 r, _, err = c.Exchange(m, addrstr) 181 if err != nil { 182 t.Error("failed to exchange eXaMplE.cOm", err) 183 } 184 txt = r.Extra[0].(*TXT).Txt[0] 185 if txt != "Hello example" { 186 t.Error("unexpected result for example.com", txt, "!= Hello example") 187 } 188} 189 190// Verify that the server responds to a query with Z flag on, ignoring the flag, and does not echoes it back 191func TestServeIgnoresZFlag(t *testing.T) { 192 HandleFunc("example.com.", AnotherHelloServer) 193 194 s, addrstr, err := RunLocalUDPServer(":0") 195 if err != nil { 196 t.Fatalf("unable to run test server: %v", err) 197 } 198 defer s.Shutdown() 199 200 c := new(Client) 201 m := new(Msg) 202 203 // Test the Z flag is not echoed 204 m.SetQuestion("example.com.", TypeTXT) 205 m.Zero = true 206 r, _, err := c.Exchange(m, addrstr) 207 if err != nil { 208 t.Fatal("failed to exchange example.com with +zflag", err) 209 } 210 if r.Zero { 211 t.Error("the response should not have Z flag set - even for a query which does") 212 } 213 if r.Rcode != RcodeSuccess { 214 t.Errorf("expected rcode %v, got %v", RcodeSuccess, r.Rcode) 215 } 216} 217 218// Verify that the server responds to a query with unsupported Opcode with a NotImplemented error and that Opcode is unchanged. 219func TestServeNotImplemented(t *testing.T) { 220 HandleFunc("example.com.", AnotherHelloServer) 221 opcode := 15 222 223 s, addrstr, err := RunLocalUDPServer(":0") 224 if err != nil { 225 t.Fatalf("unable to run test server: %v", err) 226 } 227 defer s.Shutdown() 228 229 c := new(Client) 230 m := new(Msg) 231 232 // Test that Opcode is like the unchanged from request Opcode and that Rcode is set to NotImplemnented 233 m.SetQuestion("example.com.", TypeTXT) 234 m.Opcode = opcode 235 r, _, err := c.Exchange(m, addrstr) 236 if err != nil { 237 t.Fatal("failed to exchange example.com with +zflag", err) 238 } 239 if r.Opcode != opcode { 240 t.Errorf("expected opcode %v, got %v", opcode, r.Opcode) 241 } 242 if r.Rcode != RcodeNotImplemented { 243 t.Errorf("expected rcode %v, got %v", RcodeNotImplemented, r.Rcode) 244 } 245} 246 247func TestServingTLS(t *testing.T) { 248 HandleFunc("miek.nl.", HelloServer) 249 HandleFunc("example.com.", AnotherHelloServer) 250 defer HandleRemove("miek.nl.") 251 defer HandleRemove("example.com.") 252 253 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 254 if err != nil { 255 t.Fatalf("unable to build certificate: %v", err) 256 } 257 258 config := tls.Config{ 259 Certificates: []tls.Certificate{cert}, 260 } 261 262 s, addrstr, err := RunLocalTLSServer(":0", &config) 263 if err != nil { 264 t.Fatalf("unable to run test server: %v", err) 265 } 266 defer s.Shutdown() 267 268 c := new(Client) 269 c.Net = "tcp-tls" 270 c.TLSConfig = &tls.Config{ 271 InsecureSkipVerify: true, 272 } 273 274 m := new(Msg) 275 m.SetQuestion("miek.nl.", TypeTXT) 276 r, _, err := c.Exchange(m, addrstr) 277 if err != nil || len(r.Extra) == 0 { 278 t.Fatal("failed to exchange miek.nl", err) 279 } 280 txt := r.Extra[0].(*TXT).Txt[0] 281 if txt != "Hello world" { 282 t.Error("unexpected result for miek.nl", txt, "!= Hello world") 283 } 284 285 m.SetQuestion("example.com.", TypeTXT) 286 r, _, err = c.Exchange(m, addrstr) 287 if err != nil { 288 t.Fatal("failed to exchange example.com", err) 289 } 290 txt = r.Extra[0].(*TXT).Txt[0] 291 if txt != "Hello example" { 292 t.Error("unexpected result for example.com", txt, "!= Hello example") 293 } 294 295 // Test Mixes cased as noticed by Ask. 296 m.SetQuestion("eXaMplE.cOm.", TypeTXT) 297 r, _, err = c.Exchange(m, addrstr) 298 if err != nil { 299 t.Error("failed to exchange eXaMplE.cOm", err) 300 } 301 txt = r.Extra[0].(*TXT).Txt[0] 302 if txt != "Hello example" { 303 t.Error("unexpected result for example.com", txt, "!= Hello example") 304 } 305} 306 307// TestServingTLSConnectionState tests that we only can access 308// tls.ConnectionState under a DNS query handled by a TLS DNS server. 309// This test will sequentially create a TLS, UDP and TCP server, attach a custom 310// handler which will set a testing error if tls.ConnectionState is available 311// when it is not expected, or the other way around. 312func TestServingTLSConnectionState(t *testing.T) { 313 handlerResponse := "Hello example" 314 // tlsHandlerTLS is a HandlerFunc that can be set to expect or not TLS 315 // connection state. 316 tlsHandlerTLS := func(tlsExpected bool) func(ResponseWriter, *Msg) { 317 return func(w ResponseWriter, req *Msg) { 318 m := new(Msg) 319 m.SetReply(req) 320 tlsFound := true 321 if connState := w.(ConnectionStater).ConnectionState(); connState == nil { 322 tlsFound = false 323 } 324 if tlsFound != tlsExpected { 325 t.Errorf("TLS connection state available: %t, expected: %t", tlsFound, tlsExpected) 326 } 327 m.Extra = make([]RR, 1) 328 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{handlerResponse}} 329 w.WriteMsg(m) 330 } 331 } 332 333 // Question used in tests 334 m := new(Msg) 335 m.SetQuestion("tlsstate.example.net.", TypeTXT) 336 337 // TLS DNS server 338 HandleFunc(".", tlsHandlerTLS(true)) 339 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 340 if err != nil { 341 t.Fatalf("unable to build certificate: %v", err) 342 } 343 344 config := tls.Config{ 345 Certificates: []tls.Certificate{cert}, 346 } 347 348 s, addrstr, err := RunLocalTLSServer(":0", &config) 349 if err != nil { 350 t.Fatalf("unable to run test server: %v", err) 351 } 352 defer s.Shutdown() 353 354 // TLS DNS query 355 c := &Client{ 356 Net: "tcp-tls", 357 TLSConfig: &tls.Config{ 358 InsecureSkipVerify: true, 359 }, 360 } 361 362 _, _, err = c.Exchange(m, addrstr) 363 if err != nil { 364 t.Error("failed to exchange tlsstate.example.net", err) 365 } 366 367 HandleRemove(".") 368 // UDP DNS Server 369 HandleFunc(".", tlsHandlerTLS(false)) 370 defer HandleRemove(".") 371 s, addrstr, err = RunLocalUDPServer(":0") 372 if err != nil { 373 t.Fatalf("unable to run test server: %v", err) 374 } 375 defer s.Shutdown() 376 377 // UDP DNS query 378 c = new(Client) 379 _, _, err = c.Exchange(m, addrstr) 380 if err != nil { 381 t.Error("failed to exchange tlsstate.example.net", err) 382 } 383 384 // TCP DNS Server 385 s, addrstr, err = RunLocalTCPServer(":0") 386 if err != nil { 387 t.Fatalf("unable to run test server: %v", err) 388 } 389 defer s.Shutdown() 390 391 // TCP DNS query 392 c = &Client{Net: "tcp"} 393 _, _, err = c.Exchange(m, addrstr) 394 if err != nil { 395 t.Error("failed to exchange tlsstate.example.net", err) 396 } 397} 398 399func TestServingListenAndServe(t *testing.T) { 400 HandleFunc("example.com.", AnotherHelloServer) 401 defer HandleRemove("example.com.") 402 403 waitLock := sync.Mutex{} 404 server := &Server{Addr: ":0", Net: "udp", ReadTimeout: time.Hour, WriteTimeout: time.Hour, NotifyStartedFunc: waitLock.Unlock} 405 waitLock.Lock() 406 407 go func() { 408 server.ListenAndServe() 409 }() 410 waitLock.Lock() 411 412 c, m := new(Client), new(Msg) 413 m.SetQuestion("example.com.", TypeTXT) 414 addr := server.PacketConn.LocalAddr().String() // Get address via the PacketConn that gets set. 415 r, _, err := c.Exchange(m, addr) 416 if err != nil { 417 t.Fatal("failed to exchange example.com", err) 418 } 419 txt := r.Extra[0].(*TXT).Txt[0] 420 if txt != "Hello example" { 421 t.Error("unexpected result for example.com", txt, "!= Hello example") 422 } 423 server.Shutdown() 424} 425 426func TestServingListenAndServeTLS(t *testing.T) { 427 HandleFunc("example.com.", AnotherHelloServer) 428 defer HandleRemove("example.com.") 429 430 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 431 if err != nil { 432 t.Fatalf("unable to build certificate: %v", err) 433 } 434 435 config := &tls.Config{ 436 Certificates: []tls.Certificate{cert}, 437 } 438 439 waitLock := sync.Mutex{} 440 server := &Server{Addr: ":0", Net: "tcp", TLSConfig: config, ReadTimeout: time.Hour, WriteTimeout: time.Hour, NotifyStartedFunc: waitLock.Unlock} 441 waitLock.Lock() 442 443 go func() { 444 server.ListenAndServe() 445 }() 446 waitLock.Lock() 447 448 c, m := new(Client), new(Msg) 449 c.Net = "tcp" 450 m.SetQuestion("example.com.", TypeTXT) 451 addr := server.Listener.Addr().String() // Get address via the Listener that gets set. 452 r, _, err := c.Exchange(m, addr) 453 if err != nil { 454 t.Fatal(err) 455 } 456 txt := r.Extra[0].(*TXT).Txt[0] 457 if txt != "Hello example" { 458 t.Error("unexpected result for example.com", txt, "!= Hello example") 459 } 460 server.Shutdown() 461} 462 463func BenchmarkServe(b *testing.B) { 464 b.StopTimer() 465 HandleFunc("miek.nl.", HelloServer) 466 defer HandleRemove("miek.nl.") 467 a := runtime.GOMAXPROCS(4) 468 469 s, addrstr, err := RunLocalUDPServer(":0") 470 if err != nil { 471 b.Fatalf("unable to run test server: %v", err) 472 } 473 defer s.Shutdown() 474 475 c := new(Client) 476 m := new(Msg) 477 m.SetQuestion("miek.nl.", TypeSOA) 478 479 b.StartTimer() 480 for i := 0; i < b.N; i++ { 481 _, _, err := c.Exchange(m, addrstr) 482 if err != nil { 483 b.Fatalf("Exchange failed: %v", err) 484 } 485 } 486 runtime.GOMAXPROCS(a) 487} 488 489func BenchmarkServe6(b *testing.B) { 490 b.StopTimer() 491 HandleFunc("miek.nl.", HelloServer) 492 defer HandleRemove("miek.nl.") 493 a := runtime.GOMAXPROCS(4) 494 s, addrstr, err := RunLocalUDPServer("[::1]:0") 495 if err != nil { 496 if strings.Contains(err.Error(), "bind: cannot assign requested address") { 497 b.Skip("missing IPv6 support") 498 } 499 b.Fatalf("unable to run test server: %v", err) 500 } 501 defer s.Shutdown() 502 503 c := new(Client) 504 m := new(Msg) 505 m.SetQuestion("miek.nl.", TypeSOA) 506 507 b.StartTimer() 508 for i := 0; i < b.N; i++ { 509 _, _, err := c.Exchange(m, addrstr) 510 if err != nil { 511 b.Fatalf("Exchange failed: %v", err) 512 } 513 } 514 runtime.GOMAXPROCS(a) 515} 516 517func HelloServerCompress(w ResponseWriter, req *Msg) { 518 m := new(Msg) 519 m.SetReply(req) 520 m.Extra = make([]RR, 1) 521 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 522 m.Compress = true 523 w.WriteMsg(m) 524} 525 526func BenchmarkServeCompress(b *testing.B) { 527 b.StopTimer() 528 HandleFunc("miek.nl.", HelloServerCompress) 529 defer HandleRemove("miek.nl.") 530 a := runtime.GOMAXPROCS(4) 531 s, addrstr, err := RunLocalUDPServer(":0") 532 if err != nil { 533 b.Fatalf("unable to run test server: %v", err) 534 } 535 defer s.Shutdown() 536 537 c := new(Client) 538 m := new(Msg) 539 m.SetQuestion("miek.nl.", TypeSOA) 540 b.StartTimer() 541 for i := 0; i < b.N; i++ { 542 _, _, err := c.Exchange(m, addrstr) 543 if err != nil { 544 b.Fatalf("Exchange failed: %v", err) 545 } 546 } 547 runtime.GOMAXPROCS(a) 548} 549 550type maxRec struct { 551 max int 552 sync.RWMutex 553} 554 555var M = new(maxRec) 556 557func HelloServerLargeResponse(resp ResponseWriter, req *Msg) { 558 m := new(Msg) 559 m.SetReply(req) 560 m.Authoritative = true 561 m1 := 0 562 M.RLock() 563 m1 = M.max 564 M.RUnlock() 565 for i := 0; i < m1; i++ { 566 aRec := &A{ 567 Hdr: RR_Header{ 568 Name: req.Question[0].Name, 569 Rrtype: TypeA, 570 Class: ClassINET, 571 Ttl: 0, 572 }, 573 A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)).To4(), 574 } 575 m.Answer = append(m.Answer, aRec) 576 } 577 resp.WriteMsg(m) 578} 579 580func TestServingLargeResponses(t *testing.T) { 581 HandleFunc("example.", HelloServerLargeResponse) 582 defer HandleRemove("example.") 583 584 s, addrstr, err := RunLocalUDPServer(":0") 585 if err != nil { 586 t.Fatalf("unable to run test server: %v", err) 587 } 588 defer s.Shutdown() 589 590 // Create request 591 m := new(Msg) 592 m.SetQuestion("web.service.example.", TypeANY) 593 594 c := new(Client) 595 c.Net = "udp" 596 M.Lock() 597 M.max = 2 598 M.Unlock() 599 _, _, err = c.Exchange(m, addrstr) 600 if err != nil { 601 t.Errorf("failed to exchange: %v", err) 602 } 603 // This must fail 604 M.Lock() 605 M.max = 20 606 M.Unlock() 607 _, _, err = c.Exchange(m, addrstr) 608 if err == nil { 609 t.Error("failed to fail exchange, this should generate packet error") 610 } 611 // But this must work again 612 c.UDPSize = 7000 613 _, _, err = c.Exchange(m, addrstr) 614 if err != nil { 615 t.Errorf("failed to exchange: %v", err) 616 } 617} 618 619func TestServingResponse(t *testing.T) { 620 if testing.Short() { 621 t.Skip("skipping test in short mode.") 622 } 623 HandleFunc("miek.nl.", HelloServer) 624 s, addrstr, err := RunLocalUDPServer(":0") 625 if err != nil { 626 t.Fatalf("unable to run test server: %v", err) 627 } 628 defer s.Shutdown() 629 630 c := new(Client) 631 m := new(Msg) 632 m.SetQuestion("miek.nl.", TypeTXT) 633 m.Response = false 634 _, _, err = c.Exchange(m, addrstr) 635 if err != nil { 636 t.Fatal("failed to exchange", err) 637 } 638 m.Response = true 639 _, _, err = c.Exchange(m, addrstr) 640 if err == nil { 641 t.Fatal("exchanged response message") 642 } 643} 644 645func TestShutdownTCP(t *testing.T) { 646 s, _, fin, err := RunLocalTCPServerWithFinChan(":0") 647 if err != nil { 648 t.Fatalf("unable to run test server: %v", err) 649 } 650 err = s.Shutdown() 651 if err != nil { 652 t.Fatalf("could not shutdown test TCP server, %v", err) 653 } 654 select { 655 case err := <-fin: 656 if err != nil { 657 t.Errorf("error returned from ActivateAndServe, %v", err) 658 } 659 case <-time.After(2 * time.Second): 660 t.Error("could not shutdown test TCP server. Gave up waiting") 661 } 662} 663 664func init() { 665 testShutdownNotify = &sync.Cond{ 666 L: new(sync.Mutex), 667 } 668} 669 670func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr string, client *Client) { 671 const requests = 100 672 673 var errOnce sync.Once 674 // t.Fail will panic if it's called after the test function has 675 // finished. Burning the sync.Once with a defer will prevent the 676 // handler from calling t.Errorf after we've returned. 677 defer errOnce.Do(func() {}) 678 679 toHandle := int32(requests) 680 HandleFunc("example.com.", func(w ResponseWriter, req *Msg) { 681 defer atomic.AddInt32(&toHandle, -1) 682 683 // Wait until ShutdownContext is called before replying. 684 testShutdownNotify.L.Lock() 685 testShutdownNotify.Wait() 686 testShutdownNotify.L.Unlock() 687 688 m := new(Msg) 689 m.SetReply(req) 690 m.Extra = make([]RR, 1) 691 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 692 693 if err := w.WriteMsg(m); err != nil { 694 errOnce.Do(func() { 695 t.Errorf("ResponseWriter.WriteMsg error: %s", err) 696 }) 697 } 698 }) 699 defer HandleRemove("example.com.") 700 701 client.Timeout = 10 * time.Second 702 703 conns := make([]*Conn, requests) 704 eg := new(errgroup.Group) 705 706 for i := range conns { 707 conn := &conns[i] 708 eg.Go(func() error { 709 var err error 710 *conn, err = client.Dial(addr) 711 return err 712 }) 713 } 714 715 if eg.Wait() != nil { 716 t.Fatalf("client.Dial error: %v", eg.Wait()) 717 } 718 719 m := new(Msg) 720 m.SetQuestion("example.com.", TypeTXT) 721 eg = new(errgroup.Group) 722 723 for _, conn := range conns { 724 conn := conn 725 eg.Go(func() error { 726 conn.SetWriteDeadline(time.Now().Add(client.Timeout)) 727 728 return conn.WriteMsg(m) 729 }) 730 } 731 732 if eg.Wait() != nil { 733 t.Fatalf("conn.WriteMsg error: %v", eg.Wait()) 734 } 735 736 // This sleep is needed to allow time for the requests to 737 // pass from the client through the kernel and back into 738 // the server. Without it, some requests may still be in 739 // the kernel's buffer when ShutdownContext is called. 740 time.Sleep(100 * time.Millisecond) 741 742 eg = new(errgroup.Group) 743 744 for _, conn := range conns { 745 conn := conn 746 eg.Go(func() error { 747 conn.SetReadDeadline(time.Now().Add(client.Timeout)) 748 749 _, err := conn.ReadMsg() 750 return err 751 }) 752 } 753 754 ctx, cancel := context.WithTimeout(context.Background(), client.Timeout) 755 defer cancel() 756 757 if err := srv.ShutdownContext(ctx); err != nil { 758 t.Errorf("could not shutdown test server: %v", err) 759 } 760 761 if left := atomic.LoadInt32(&toHandle); left != 0 { 762 t.Errorf("ShutdownContext returned before %d replies", left) 763 } 764 765 if eg.Wait() != nil { 766 t.Errorf("conn.ReadMsg error: %v", eg.Wait()) 767 } 768 769 srv.lock.RLock() 770 defer srv.lock.RUnlock() 771 if len(srv.conns) != 0 { 772 t.Errorf("TCP connection tracking map not empty after ShutdownContext; map still contains %d connections", len(srv.conns)) 773 } 774} 775 776func TestInProgressQueriesAtShutdownTCP(t *testing.T) { 777 s, addr, _, err := RunLocalTCPServerWithFinChan(":0") 778 if err != nil { 779 t.Fatalf("unable to run test server: %v", err) 780 } 781 782 c := &Client{Net: "tcp"} 783 checkInProgressQueriesAtShutdownServer(t, s, addr, c) 784} 785 786func TestShutdownTLS(t *testing.T) { 787 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 788 if err != nil { 789 t.Fatalf("unable to build certificate: %v", err) 790 } 791 792 config := tls.Config{ 793 Certificates: []tls.Certificate{cert}, 794 } 795 796 s, _, err := RunLocalTLSServer(":0", &config) 797 if err != nil { 798 t.Fatalf("unable to run test server: %v", err) 799 } 800 err = s.Shutdown() 801 if err != nil { 802 t.Errorf("could not shutdown test TLS server, %v", err) 803 } 804} 805 806func TestInProgressQueriesAtShutdownTLS(t *testing.T) { 807 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 808 if err != nil { 809 t.Fatalf("unable to build certificate: %v", err) 810 } 811 812 config := tls.Config{ 813 Certificates: []tls.Certificate{cert}, 814 } 815 816 s, addr, err := RunLocalTLSServer(":0", &config) 817 if err != nil { 818 t.Fatalf("unable to run test server: %v", err) 819 } 820 821 c := &Client{ 822 Net: "tcp-tls", 823 TLSConfig: &tls.Config{ 824 InsecureSkipVerify: true, 825 }, 826 } 827 checkInProgressQueriesAtShutdownServer(t, s, addr, c) 828} 829 830func TestHandlerCloseTCP(t *testing.T) { 831 832 ln, err := net.Listen("tcp", ":0") 833 if err != nil { 834 panic(err) 835 } 836 addr := ln.Addr().String() 837 838 server := &Server{Addr: addr, Net: "tcp", Listener: ln} 839 840 hname := "testhandlerclosetcp." 841 triggered := make(chan struct{}) 842 HandleFunc(hname, func(w ResponseWriter, r *Msg) { 843 close(triggered) 844 w.Close() 845 }) 846 defer HandleRemove(hname) 847 848 go func() { 849 defer server.Shutdown() 850 c := &Client{Net: "tcp"} 851 m := new(Msg).SetQuestion(hname, 1) 852 tries := 0 853 exchange: 854 _, _, err := c.Exchange(m, addr) 855 if err != nil && err != io.EOF { 856 t.Errorf("exchange failed: %v", err) 857 if tries == 3 { 858 return 859 } 860 time.Sleep(time.Second / 10) 861 tries++ 862 goto exchange 863 } 864 }() 865 if err := server.ActivateAndServe(); err != nil { 866 t.Fatalf("ActivateAndServe failed: %v", err) 867 } 868 select { 869 case <-triggered: 870 default: 871 t.Fatalf("handler never called") 872 } 873} 874 875func TestShutdownUDP(t *testing.T) { 876 s, _, fin, err := RunLocalUDPServerWithFinChan(":0") 877 if err != nil { 878 t.Fatalf("unable to run test server: %v", err) 879 } 880 err = s.Shutdown() 881 if err != nil { 882 t.Errorf("could not shutdown test UDP server, %v", err) 883 } 884 select { 885 case err := <-fin: 886 if err != nil { 887 t.Errorf("error returned from ActivateAndServe, %v", err) 888 } 889 case <-time.After(2 * time.Second): 890 t.Error("could not shutdown test UDP server. Gave up waiting") 891 } 892} 893 894func TestInProgressQueriesAtShutdownUDP(t *testing.T) { 895 s, addr, _, err := RunLocalUDPServerWithFinChan(":0") 896 if err != nil { 897 t.Fatalf("unable to run test server: %v", err) 898 } 899 900 c := &Client{Net: "udp"} 901 checkInProgressQueriesAtShutdownServer(t, s, addr, c) 902} 903 904func TestServerStartStopRace(t *testing.T) { 905 var wg sync.WaitGroup 906 for i := 0; i < 10; i++ { 907 wg.Add(1) 908 s, _, _, err := RunLocalUDPServerWithFinChan(":0") 909 if err != nil { 910 t.Fatalf("could not start server: %s", err) 911 } 912 go func() { 913 defer wg.Done() 914 if err := s.Shutdown(); err != nil { 915 t.Errorf("could not stop server: %s", err) 916 } 917 }() 918 } 919 wg.Wait() 920} 921 922func TestServerReuseport(t *testing.T) { 923 if !supportsReusePort { 924 t.Skip("reuseport is not supported") 925 } 926 927 startServer := func(addr string) (*Server, chan error) { 928 wait := make(chan struct{}) 929 srv := &Server{ 930 Net: "udp", 931 Addr: addr, 932 NotifyStartedFunc: func() { close(wait) }, 933 ReusePort: true, 934 } 935 936 fin := make(chan error, 1) 937 go func() { 938 fin <- srv.ListenAndServe() 939 }() 940 941 select { 942 case <-wait: 943 case err := <-fin: 944 t.Fatalf("failed to start server: %v", err) 945 } 946 947 return srv, fin 948 } 949 950 srv1, fin1 := startServer(":0") // :0 is resolved to a random free port by the kernel 951 srv2, fin2 := startServer(srv1.PacketConn.LocalAddr().String()) 952 953 if err := srv1.Shutdown(); err != nil { 954 t.Fatalf("failed to shutdown first server: %v", err) 955 } 956 if err := srv2.Shutdown(); err != nil { 957 t.Fatalf("failed to shutdown second server: %v", err) 958 } 959 960 if err := <-fin1; err != nil { 961 t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err) 962 } 963 if err := <-fin2; err != nil { 964 t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err) 965 } 966} 967 968func TestServerRoundtripTsig(t *testing.T) { 969 secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="} 970 971 s, addrstr, _, err := RunLocalUDPServerWithFinChan(":0", func(srv *Server) { 972 srv.TsigSecret = secret 973 }) 974 if err != nil { 975 t.Fatalf("unable to run test server: %v", err) 976 } 977 defer s.Shutdown() 978 979 HandleFunc("example.com.", func(w ResponseWriter, r *Msg) { 980 m := new(Msg) 981 m.SetReply(r) 982 if r.IsTsig() != nil { 983 status := w.TsigStatus() 984 if status == nil { 985 // *Msg r has an TSIG record and it was validated 986 m.SetTsig("test.", HmacMD5, 300, time.Now().Unix()) 987 } else { 988 // *Msg r has an TSIG records and it was not valided 989 t.Errorf("invalid TSIG: %v", status) 990 } 991 } else { 992 t.Error("missing TSIG") 993 } 994 w.WriteMsg(m) 995 }) 996 997 c := new(Client) 998 m := new(Msg) 999 m.Opcode = OpcodeUpdate 1000 m.SetQuestion("example.com.", TypeSOA) 1001 m.Ns = []RR{&CNAME{ 1002 Hdr: RR_Header{ 1003 Name: "foo.example.com.", 1004 Rrtype: TypeCNAME, 1005 Class: ClassINET, 1006 Ttl: 300, 1007 }, 1008 Target: "bar.example.com.", 1009 }} 1010 c.TsigSecret = secret 1011 m.SetTsig("test.", HmacMD5, 300, time.Now().Unix()) 1012 _, _, err = c.Exchange(m, addrstr) 1013 if err != nil { 1014 t.Fatal("failed to exchange", err) 1015 } 1016} 1017 1018func TestResponseAfterClose(t *testing.T) { 1019 testError := func(name string, err error) { 1020 t.Helper() 1021 1022 expect := fmt.Sprintf("dns: %s called after Close", name) 1023 if err == nil { 1024 t.Errorf("expected error from %s after Close", name) 1025 } else if err.Error() != expect { 1026 t.Errorf("expected explicit error from %s after Close, expected %q, got %q", name, expect, err) 1027 } 1028 } 1029 1030 rw := &response{ 1031 closed: true, 1032 } 1033 1034 _, err := rw.Write(make([]byte, 2)) 1035 testError("Write", err) 1036 1037 testError("WriteMsg", rw.WriteMsg(new(Msg))) 1038} 1039 1040func TestResponseDoubleClose(t *testing.T) { 1041 rw := &response{ 1042 closed: true, 1043 } 1044 if err, expect := rw.Close(), "dns: connection already closed"; err == nil || err.Error() != expect { 1045 t.Errorf("Close did not return expected: error %q, got: %v", expect, err) 1046 } 1047} 1048 1049type ExampleFrameLengthWriter struct { 1050 Writer 1051} 1052 1053func (e *ExampleFrameLengthWriter) Write(m []byte) (int, error) { 1054 fmt.Println("writing raw DNS message of length", len(m)) 1055 return e.Writer.Write(m) 1056} 1057 1058func ExampleDecorateWriter() { 1059 // instrument raw DNS message writing 1060 wf := DecorateWriter(func(w Writer) Writer { 1061 return &ExampleFrameLengthWriter{w} 1062 }) 1063 1064 // simple UDP server 1065 pc, err := net.ListenPacket("udp", ":0") 1066 if err != nil { 1067 fmt.Println(err.Error()) 1068 return 1069 } 1070 server := &Server{ 1071 PacketConn: pc, 1072 DecorateWriter: wf, 1073 ReadTimeout: time.Hour, WriteTimeout: time.Hour, 1074 } 1075 1076 waitLock := sync.Mutex{} 1077 waitLock.Lock() 1078 server.NotifyStartedFunc = waitLock.Unlock 1079 defer server.Shutdown() 1080 1081 go func() { 1082 server.ActivateAndServe() 1083 pc.Close() 1084 }() 1085 1086 waitLock.Lock() 1087 1088 HandleFunc("miek.nl.", HelloServer) 1089 1090 c := new(Client) 1091 m := new(Msg) 1092 m.SetQuestion("miek.nl.", TypeTXT) 1093 _, _, err = c.Exchange(m, pc.LocalAddr().String()) 1094 if err != nil { 1095 fmt.Println("failed to exchange", err.Error()) 1096 return 1097 } 1098 // Output: writing raw DNS message of length 56 1099} 1100 1101var ( 1102 // CertPEMBlock is a X509 data used to test TLS servers (used with tls.X509KeyPair) 1103 CertPEMBlock = []byte(`-----BEGIN CERTIFICATE----- 1104MIIDAzCCAeugAwIBAgIRAJFYMkcn+b8dpU15wjf++GgwDQYJKoZIhvcNAQELBQAw 1105EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjAxMDgxMjAzNTNaFw0xNzAxMDcxMjAz 1106NTNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 1107ggEKAoIBAQDXjqO6skvP03k58CNjQggd9G/mt+Wa+xRU+WXiKCCHttawM8x+slq5 1108yfsHCwxlwsGn79HmJqecNqgHb2GWBXAvVVokFDTcC1hUP4+gp2gu9Ny27UHTjlLm 1109O0l/xZ5MN8tfKyYlFw18tXu3fkaPyHj8v/D1RDkuo4ARdFvGSe8TqisbhLk2+9ow 1110xfIGbEM9Fdiw8qByC2+d+FfvzIKz3GfQVwn0VoRom8L6NBIANq1IGrB5JefZB6nv 1111DnfuxkBmY7F1513HKuEJ8KsLWWZWV9OPU4j4I4Rt+WJNlKjbD2srHxyrS2RDsr91 11128nCkNoWVNO3sZq0XkWKecdc921vL4ginAgMBAAGjVDBSMA4GA1UdDwEB/wQEAwIC 1113pDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MBoGA1UdEQQT 1114MBGCCWxvY2FsaG9zdIcEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAGcU3iyLBIVZj 1115aDzSvEDHUd1bnLBl1C58Xu/CyKlPqVU7mLfK0JcgEaYQTSX6fCJVNLbbCrcGLsPJ 1116fbjlBbyeLjTV413fxPVuona62pBFjqdtbli2Qe8FRH2KBdm41JUJGdo+SdsFu7nc 1117BFOcubdw6LLIXvsTvwndKcHWx1rMX709QU1Vn1GAIsbJV/DWI231Jyyb+lxAUx/C 11188vce5uVxiKcGS+g6OjsN3D3TtiEQGSXLh013W6Wsih8td8yMCMZ3w8LQ38br1GUe 1119ahLIgUJ9l6HDguM17R7kGqxNvbElsMUHfTtXXP7UDQUiYXDakg8xDP6n9DCDhJ8Y 1120bSt7OLB7NQ== 1121-----END CERTIFICATE-----`) 1122 1123 // KeyPEMBlock is a X509 data used to test TLS servers (used with tls.X509KeyPair) 1124 KeyPEMBlock = []byte(`-----BEGIN RSA PRIVATE KEY----- 1125MIIEpQIBAAKCAQEA146jurJLz9N5OfAjY0IIHfRv5rflmvsUVPll4iggh7bWsDPM 1126frJaucn7BwsMZcLBp+/R5iannDaoB29hlgVwL1VaJBQ03AtYVD+PoKdoLvTctu1B 1127045S5jtJf8WeTDfLXysmJRcNfLV7t35Gj8h4/L/w9UQ5LqOAEXRbxknvE6orG4S5 1128NvvaMMXyBmxDPRXYsPKgcgtvnfhX78yCs9xn0FcJ9FaEaJvC+jQSADatSBqweSXn 11292Qep7w537sZAZmOxdeddxyrhCfCrC1lmVlfTj1OI+COEbfliTZSo2w9rKx8cq0tk 1130Q7K/dfJwpDaFlTTt7GatF5FinnHXPdtby+IIpwIDAQABAoIBAAJK4RDmPooqTJrC 1131JA41MJLo+5uvjwCT9QZmVKAQHzByUFw1YNJkITTiognUI0CdzqNzmH7jIFs39ZeG 1132proKusO2G6xQjrNcZ4cV2fgyb5g4QHStl0qhs94A+WojduiGm2IaumAgm6Mc5wDv 1133ld6HmknN3Mku/ZCyanVFEIjOVn2WB7ZQLTBs6ZYaebTJG2Xv6p9t2YJW7pPQ9Xce 1134s9ohAWohyM4X/OvfnfnLtQp2YLw/BxwehBsCR5SXM3ibTKpFNtxJC8hIfTuWtxZu 11352ywrmXShYBRB1WgtZt5k04bY/HFncvvcHK3YfI1+w4URKtwdaQgPUQRbVwDwuyBn 1136flfkCJECgYEA/eWt01iEyE/lXkGn6V9lCocUU7lCU6yk5UT8VXVUc5If4KZKPfCk 1137p4zJDOqwn2eM673aWz/mG9mtvAvmnugaGjcaVCyXOp/D/GDmKSoYcvW5B/yjfkLy 1138dK6Yaa5LDRVYlYgyzcdCT5/9Qc626NzFwKCZNI4ncIU8g7ViATRxWJ8CgYEA2Ver 1139vZ0M606sfgC0H3NtwNBxmuJ+lIF5LNp/wDi07lDfxRR1rnZMX5dnxjcpDr/zvm8J 1140WtJJX3xMgqjtHuWKL3yKKony9J5ZPjichSbSbhrzfovgYIRZLxLLDy4MP9L3+CX/ 1141yBXnqMWuSnFX+M5fVGxdDWiYF3V+wmeOv9JvavkCgYEAiXAPDFzaY+R78O3xiu7M 1142r0o3wqqCMPE/wav6O/hrYrQy9VSO08C0IM6g9pEEUwWmzuXSkZqhYWoQFb8Lc/GI 1143T7CMXAxXQLDDUpbRgG79FR3Wr3AewHZU8LyiXHKwxcBMV4WGmsXGK3wbh8fyU1NO 11446NsGk+BvkQVOoK1LBAPzZ1kCgYEAsBSmD8U33T9s4dxiEYTrqyV0lH3g/SFz8ZHH 1145pAyNEPI2iC1ONhyjPWKlcWHpAokiyOqeUpVBWnmSZtzC1qAydsxYB6ShT+sl9BHb 1146RMix/QAauzBJhQhUVJ3OIys0Q1UBDmqCsjCE8SfOT4NKOUnA093C+YT+iyrmmktZ 1147zDCJkckCgYEAndqM5KXGk5xYo+MAA1paZcbTUXwaWwjLU+XSRSSoyBEi5xMtfvUb 11487+a1OMhLwWbuz+pl64wFKrbSUyimMOYQpjVE/1vk/kb99pxbgol27hdKyTH1d+ov 1149kFsxKCqxAnBVGEWAvVZAiiTOxleQFjz5RnL0BQp9Lg2cQe+dvuUmIAA= 1150-----END RSA PRIVATE KEY-----`) 1151) 1152