1package dns 2 3import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 "net" 9 "runtime" 10 "strings" 11 "sync" 12 "sync/atomic" 13 "testing" 14 "time" 15 16 "golang.org/x/sync/errgroup" 17) 18 19func HelloServer(w ResponseWriter, req *Msg) { 20 m := new(Msg) 21 m.SetReply(req) 22 23 m.Extra = make([]RR, 1) 24 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 25 w.WriteMsg(m) 26} 27 28func HelloServerBadID(w ResponseWriter, req *Msg) { 29 m := new(Msg) 30 m.SetReply(req) 31 m.Id++ 32 33 m.Extra = make([]RR, 1) 34 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 35 w.WriteMsg(m) 36} 37 38func HelloServerBadThenGoodID(w ResponseWriter, req *Msg) { 39 m := new(Msg) 40 m.SetReply(req) 41 m.Id++ 42 43 m.Extra = make([]RR, 1) 44 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 45 w.WriteMsg(m) 46 47 m.Id-- 48 w.WriteMsg(m) 49} 50 51func HelloServerEchoAddrPort(w ResponseWriter, req *Msg) { 52 m := new(Msg) 53 m.SetReply(req) 54 55 remoteAddr := w.RemoteAddr().String() 56 m.Extra = make([]RR, 1) 57 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{remoteAddr}} 58 w.WriteMsg(m) 59} 60 61func AnotherHelloServer(w ResponseWriter, req *Msg) { 62 m := new(Msg) 63 m.SetReply(req) 64 65 m.Extra = make([]RR, 1) 66 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello example"}} 67 w.WriteMsg(m) 68} 69 70func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { 71 pc, err := net.ListenPacket("udp", laddr) 72 if err != nil { 73 return nil, "", nil, err 74 } 75 server := &Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour} 76 77 waitLock := sync.Mutex{} 78 waitLock.Lock() 79 server.NotifyStartedFunc = waitLock.Unlock 80 81 for _, opt := range opts { 82 opt(server) 83 } 84 85 // fin must be buffered so the goroutine below won't block 86 // forever if fin is never read from. This always happens 87 // if the channel is discarded and can happen in TestShutdownUDP. 88 fin := make(chan error, 1) 89 90 go func() { 91 fin <- server.ActivateAndServe() 92 pc.Close() 93 }() 94 95 waitLock.Lock() 96 return server, pc.LocalAddr().String(), fin, nil 97} 98 99func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { 100 return RunLocalUDPServer(laddr, append(opts, func(srv *Server) { 101 // Make srv.PacketConn opaque to trigger the generic code paths. 102 srv.PacketConn = struct{ net.PacketConn }{srv.PacketConn} 103 })...) 104} 105 106func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) { 107 l, err := net.Listen("tcp", laddr) 108 if err != nil { 109 return nil, "", nil, err 110 } 111 112 server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} 113 114 waitLock := sync.Mutex{} 115 waitLock.Lock() 116 server.NotifyStartedFunc = waitLock.Unlock 117 118 for _, opt := range opts { 119 opt(server) 120 } 121 122 // See the comment in RunLocalUDPServer as to why fin must be buffered. 123 fin := make(chan error, 1) 124 125 go func() { 126 fin <- server.ActivateAndServe() 127 l.Close() 128 }() 129 130 waitLock.Lock() 131 return server, l.Addr().String(), fin, nil 132} 133 134func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) { 135 return RunLocalTCPServer(laddr, func(srv *Server) { 136 srv.Listener = tls.NewListener(srv.Listener, config) 137 }) 138} 139 140func TestServing(t *testing.T) { 141 for _, tc := range []struct { 142 name string 143 network string 144 runServer func(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) 145 }{ 146 {"udp", "udp", RunLocalUDPServer}, 147 {"tcp", "tcp", RunLocalTCPServer}, 148 {"PacketConn", "udp", RunLocalPacketConnServer}, 149 } { 150 t.Run(tc.name, func(t *testing.T) { 151 HandleFunc("miek.nl.", HelloServer) 152 HandleFunc("example.com.", AnotherHelloServer) 153 defer HandleRemove("miek.nl.") 154 defer HandleRemove("example.com.") 155 156 s, addrstr, _, err := tc.runServer(":0") 157 if err != nil { 158 t.Fatalf("unable to run test server: %v", err) 159 } 160 defer s.Shutdown() 161 162 c := &Client{ 163 Net: tc.network, 164 } 165 m := new(Msg) 166 m.SetQuestion("miek.nl.", TypeTXT) 167 r, _, err := c.Exchange(m, addrstr) 168 if err != nil || len(r.Extra) == 0 { 169 t.Fatal("failed to exchange miek.nl", err) 170 } 171 txt := r.Extra[0].(*TXT).Txt[0] 172 if txt != "Hello world" { 173 t.Error("unexpected result for miek.nl", txt, "!= Hello world") 174 } 175 176 m.SetQuestion("example.com.", TypeTXT) 177 r, _, err = c.Exchange(m, addrstr) 178 if err != nil { 179 t.Fatal("failed to exchange example.com", err) 180 } 181 txt = r.Extra[0].(*TXT).Txt[0] 182 if txt != "Hello example" { 183 t.Error("unexpected result for example.com", txt, "!= Hello example") 184 } 185 186 // Test Mixes cased as noticed by Ask. 187 m.SetQuestion("eXaMplE.cOm.", TypeTXT) 188 r, _, err = c.Exchange(m, addrstr) 189 if err != nil { 190 t.Error("failed to exchange eXaMplE.cOm", err) 191 } 192 txt = r.Extra[0].(*TXT).Txt[0] 193 if txt != "Hello example" { 194 t.Error("unexpected result for example.com", txt, "!= Hello example") 195 } 196 }) 197 } 198} 199 200// Verify that the server responds to a query with Z flag on, ignoring the flag, and does not echoes it back 201func TestServeIgnoresZFlag(t *testing.T) { 202 HandleFunc("example.com.", AnotherHelloServer) 203 204 s, addrstr, _, err := RunLocalUDPServer(":0") 205 if err != nil { 206 t.Fatalf("unable to run test server: %v", err) 207 } 208 defer s.Shutdown() 209 210 c := new(Client) 211 m := new(Msg) 212 213 // Test the Z flag is not echoed 214 m.SetQuestion("example.com.", TypeTXT) 215 m.Zero = true 216 r, _, err := c.Exchange(m, addrstr) 217 if err != nil { 218 t.Fatal("failed to exchange example.com with +zflag", err) 219 } 220 if r.Zero { 221 t.Error("the response should not have Z flag set - even for a query which does") 222 } 223 if r.Rcode != RcodeSuccess { 224 t.Errorf("expected rcode %v, got %v", RcodeSuccess, r.Rcode) 225 } 226} 227 228// Verify that the server responds to a query with unsupported Opcode with a NotImplemented error and that Opcode is unchanged. 229func TestServeNotImplemented(t *testing.T) { 230 HandleFunc("example.com.", AnotherHelloServer) 231 opcode := 15 232 233 s, addrstr, _, err := RunLocalUDPServer(":0") 234 if err != nil { 235 t.Fatalf("unable to run test server: %v", err) 236 } 237 defer s.Shutdown() 238 239 c := new(Client) 240 m := new(Msg) 241 242 // Test that Opcode is like the unchanged from request Opcode and that Rcode is set to NotImplemented 243 m.SetQuestion("example.com.", TypeTXT) 244 m.Opcode = opcode 245 r, _, err := c.Exchange(m, addrstr) 246 if err != nil { 247 t.Fatal("failed to exchange example.com with +zflag", err) 248 } 249 if r.Opcode != opcode { 250 t.Errorf("expected opcode %v, got %v", opcode, r.Opcode) 251 } 252 if r.Rcode != RcodeNotImplemented { 253 t.Errorf("expected rcode %v, got %v", RcodeNotImplemented, r.Rcode) 254 } 255} 256 257func TestServingTLS(t *testing.T) { 258 HandleFunc("miek.nl.", HelloServer) 259 HandleFunc("example.com.", AnotherHelloServer) 260 defer HandleRemove("miek.nl.") 261 defer HandleRemove("example.com.") 262 263 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 264 if err != nil { 265 t.Fatalf("unable to build certificate: %v", err) 266 } 267 268 config := tls.Config{ 269 Certificates: []tls.Certificate{cert}, 270 } 271 272 s, addrstr, _, err := RunLocalTLSServer(":0", &config) 273 if err != nil { 274 t.Fatalf("unable to run test server: %v", err) 275 } 276 defer s.Shutdown() 277 278 c := new(Client) 279 c.Net = "tcp-tls" 280 c.TLSConfig = &tls.Config{ 281 InsecureSkipVerify: true, 282 } 283 284 m := new(Msg) 285 m.SetQuestion("miek.nl.", TypeTXT) 286 r, _, err := c.Exchange(m, addrstr) 287 if err != nil || len(r.Extra) == 0 { 288 t.Fatal("failed to exchange miek.nl", err) 289 } 290 txt := r.Extra[0].(*TXT).Txt[0] 291 if txt != "Hello world" { 292 t.Error("unexpected result for miek.nl", txt, "!= Hello world") 293 } 294 295 m.SetQuestion("example.com.", TypeTXT) 296 r, _, err = c.Exchange(m, addrstr) 297 if err != nil { 298 t.Fatal("failed to exchange example.com", err) 299 } 300 txt = r.Extra[0].(*TXT).Txt[0] 301 if txt != "Hello example" { 302 t.Error("unexpected result for example.com", txt, "!= Hello example") 303 } 304 305 // Test Mixes cased as noticed by Ask. 306 m.SetQuestion("eXaMplE.cOm.", TypeTXT) 307 r, _, err = c.Exchange(m, addrstr) 308 if err != nil { 309 t.Error("failed to exchange eXaMplE.cOm", err) 310 } 311 txt = r.Extra[0].(*TXT).Txt[0] 312 if txt != "Hello example" { 313 t.Error("unexpected result for example.com", txt, "!= Hello example") 314 } 315} 316 317// TestServingTLSConnectionState tests that we only can access 318// tls.ConnectionState under a DNS query handled by a TLS DNS server. 319// This test will sequentially create a TLS, UDP and TCP server, attach a custom 320// handler which will set a testing error if tls.ConnectionState is available 321// when it is not expected, or the other way around. 322func TestServingTLSConnectionState(t *testing.T) { 323 handlerResponse := "Hello example" 324 // tlsHandlerTLS is a HandlerFunc that can be set to expect or not TLS 325 // connection state. 326 tlsHandlerTLS := func(tlsExpected bool) func(ResponseWriter, *Msg) { 327 return func(w ResponseWriter, req *Msg) { 328 m := new(Msg) 329 m.SetReply(req) 330 tlsFound := true 331 if connState := w.(ConnectionStater).ConnectionState(); connState == nil { 332 tlsFound = false 333 } 334 if tlsFound != tlsExpected { 335 t.Errorf("TLS connection state available: %t, expected: %t", tlsFound, tlsExpected) 336 } 337 m.Extra = make([]RR, 1) 338 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{handlerResponse}} 339 w.WriteMsg(m) 340 } 341 } 342 343 // Question used in tests 344 m := new(Msg) 345 m.SetQuestion("tlsstate.example.net.", TypeTXT) 346 347 // TLS DNS server 348 HandleFunc(".", tlsHandlerTLS(true)) 349 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 350 if err != nil { 351 t.Fatalf("unable to build certificate: %v", err) 352 } 353 354 config := tls.Config{ 355 Certificates: []tls.Certificate{cert}, 356 } 357 358 s, addrstr, _, err := RunLocalTLSServer(":0", &config) 359 if err != nil { 360 t.Fatalf("unable to run test server: %v", err) 361 } 362 defer s.Shutdown() 363 364 // TLS DNS query 365 c := &Client{ 366 Net: "tcp-tls", 367 TLSConfig: &tls.Config{ 368 InsecureSkipVerify: true, 369 }, 370 } 371 372 _, _, err = c.Exchange(m, addrstr) 373 if err != nil { 374 t.Error("failed to exchange tlsstate.example.net", err) 375 } 376 377 HandleRemove(".") 378 // UDP DNS Server 379 HandleFunc(".", tlsHandlerTLS(false)) 380 defer HandleRemove(".") 381 s, addrstr, _, err = RunLocalUDPServer(":0") 382 if err != nil { 383 t.Fatalf("unable to run test server: %v", err) 384 } 385 defer s.Shutdown() 386 387 // UDP DNS query 388 c = new(Client) 389 _, _, err = c.Exchange(m, addrstr) 390 if err != nil { 391 t.Error("failed to exchange tlsstate.example.net", err) 392 } 393 394 // TCP DNS Server 395 s, addrstr, _, err = RunLocalTCPServer(":0") 396 if err != nil { 397 t.Fatalf("unable to run test server: %v", err) 398 } 399 defer s.Shutdown() 400 401 // TCP DNS query 402 c = &Client{Net: "tcp"} 403 _, _, err = c.Exchange(m, addrstr) 404 if err != nil { 405 t.Error("failed to exchange tlsstate.example.net", err) 406 } 407} 408 409func TestServingListenAndServe(t *testing.T) { 410 HandleFunc("example.com.", AnotherHelloServer) 411 defer HandleRemove("example.com.") 412 413 waitLock := sync.Mutex{} 414 server := &Server{Addr: ":0", Net: "udp", ReadTimeout: time.Hour, WriteTimeout: time.Hour, NotifyStartedFunc: waitLock.Unlock} 415 waitLock.Lock() 416 417 go func() { 418 server.ListenAndServe() 419 }() 420 waitLock.Lock() 421 422 c, m := new(Client), new(Msg) 423 m.SetQuestion("example.com.", TypeTXT) 424 addr := server.PacketConn.LocalAddr().String() // Get address via the PacketConn that gets set. 425 r, _, err := c.Exchange(m, addr) 426 if err != nil { 427 t.Fatal("failed to exchange example.com", err) 428 } 429 txt := r.Extra[0].(*TXT).Txt[0] 430 if txt != "Hello example" { 431 t.Error("unexpected result for example.com", txt, "!= Hello example") 432 } 433 server.Shutdown() 434} 435 436func TestServingListenAndServeTLS(t *testing.T) { 437 HandleFunc("example.com.", AnotherHelloServer) 438 defer HandleRemove("example.com.") 439 440 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 441 if err != nil { 442 t.Fatalf("unable to build certificate: %v", err) 443 } 444 445 config := &tls.Config{ 446 Certificates: []tls.Certificate{cert}, 447 } 448 449 waitLock := sync.Mutex{} 450 server := &Server{Addr: ":0", Net: "tcp", TLSConfig: config, ReadTimeout: time.Hour, WriteTimeout: time.Hour, NotifyStartedFunc: waitLock.Unlock} 451 waitLock.Lock() 452 453 go func() { 454 server.ListenAndServe() 455 }() 456 waitLock.Lock() 457 458 c, m := new(Client), new(Msg) 459 c.Net = "tcp" 460 m.SetQuestion("example.com.", TypeTXT) 461 addr := server.Listener.Addr().String() // Get address via the Listener that gets set. 462 r, _, err := c.Exchange(m, addr) 463 if err != nil { 464 t.Fatal(err) 465 } 466 txt := r.Extra[0].(*TXT).Txt[0] 467 if txt != "Hello example" { 468 t.Error("unexpected result for example.com", txt, "!= Hello example") 469 } 470 server.Shutdown() 471} 472 473func BenchmarkServe(b *testing.B) { 474 b.StopTimer() 475 HandleFunc("miek.nl.", HelloServer) 476 defer HandleRemove("miek.nl.") 477 a := runtime.GOMAXPROCS(4) 478 479 s, addrstr, _, err := RunLocalUDPServer(":0") 480 if err != nil { 481 b.Fatalf("unable to run test server: %v", err) 482 } 483 defer s.Shutdown() 484 485 c := new(Client) 486 m := new(Msg) 487 m.SetQuestion("miek.nl.", TypeSOA) 488 489 b.StartTimer() 490 for i := 0; i < b.N; i++ { 491 _, _, err := c.Exchange(m, addrstr) 492 if err != nil { 493 b.Fatalf("Exchange failed: %v", err) 494 } 495 } 496 runtime.GOMAXPROCS(a) 497} 498 499func BenchmarkServe6(b *testing.B) { 500 b.StopTimer() 501 HandleFunc("miek.nl.", HelloServer) 502 defer HandleRemove("miek.nl.") 503 a := runtime.GOMAXPROCS(4) 504 s, addrstr, _, err := RunLocalUDPServer("[::1]:0") 505 if err != nil { 506 if strings.Contains(err.Error(), "bind: cannot assign requested address") { 507 b.Skip("missing IPv6 support") 508 } 509 b.Fatalf("unable to run test server: %v", err) 510 } 511 defer s.Shutdown() 512 513 c := new(Client) 514 m := new(Msg) 515 m.SetQuestion("miek.nl.", TypeSOA) 516 517 b.StartTimer() 518 for i := 0; i < b.N; i++ { 519 _, _, err := c.Exchange(m, addrstr) 520 if err != nil { 521 b.Fatalf("Exchange failed: %v", err) 522 } 523 } 524 runtime.GOMAXPROCS(a) 525} 526 527func HelloServerCompress(w ResponseWriter, req *Msg) { 528 m := new(Msg) 529 m.SetReply(req) 530 m.Extra = make([]RR, 1) 531 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 532 m.Compress = true 533 w.WriteMsg(m) 534} 535 536func BenchmarkServeCompress(b *testing.B) { 537 b.StopTimer() 538 HandleFunc("miek.nl.", HelloServerCompress) 539 defer HandleRemove("miek.nl.") 540 a := runtime.GOMAXPROCS(4) 541 s, addrstr, _, err := RunLocalUDPServer(":0") 542 if err != nil { 543 b.Fatalf("unable to run test server: %v", err) 544 } 545 defer s.Shutdown() 546 547 c := new(Client) 548 m := new(Msg) 549 m.SetQuestion("miek.nl.", TypeSOA) 550 b.StartTimer() 551 for i := 0; i < b.N; i++ { 552 _, _, err := c.Exchange(m, addrstr) 553 if err != nil { 554 b.Fatalf("Exchange failed: %v", err) 555 } 556 } 557 runtime.GOMAXPROCS(a) 558} 559 560type maxRec struct { 561 max int 562 sync.RWMutex 563} 564 565var M = new(maxRec) 566 567func HelloServerLargeResponse(resp ResponseWriter, req *Msg) { 568 m := new(Msg) 569 m.SetReply(req) 570 m.Authoritative = true 571 m1 := 0 572 M.RLock() 573 m1 = M.max 574 M.RUnlock() 575 for i := 0; i < m1; i++ { 576 aRec := &A{ 577 Hdr: RR_Header{ 578 Name: req.Question[0].Name, 579 Rrtype: TypeA, 580 Class: ClassINET, 581 Ttl: 0, 582 }, 583 A: net.ParseIP(fmt.Sprintf("127.0.0.%d", i+1)).To4(), 584 } 585 m.Answer = append(m.Answer, aRec) 586 } 587 resp.WriteMsg(m) 588} 589 590func TestServingLargeResponses(t *testing.T) { 591 HandleFunc("example.", HelloServerLargeResponse) 592 defer HandleRemove("example.") 593 594 s, addrstr, _, err := RunLocalUDPServer(":0") 595 if err != nil { 596 t.Fatalf("unable to run test server: %v", err) 597 } 598 defer s.Shutdown() 599 600 // Create request 601 m := new(Msg) 602 m.SetQuestion("web.service.example.", TypeANY) 603 604 c := new(Client) 605 c.Net = "udp" 606 M.Lock() 607 M.max = 2 608 M.Unlock() 609 _, _, err = c.Exchange(m, addrstr) 610 if err != nil { 611 t.Errorf("failed to exchange: %v", err) 612 } 613 // This must fail 614 M.Lock() 615 M.max = 20 616 M.Unlock() 617 _, _, err = c.Exchange(m, addrstr) 618 if err == nil { 619 t.Error("failed to fail exchange, this should generate packet error") 620 } 621 // But this must work again 622 c.UDPSize = 7000 623 _, _, err = c.Exchange(m, addrstr) 624 if err != nil { 625 t.Errorf("failed to exchange: %v", err) 626 } 627} 628 629func TestServingResponse(t *testing.T) { 630 if testing.Short() { 631 t.Skip("skipping test in short mode.") 632 } 633 HandleFunc("miek.nl.", HelloServer) 634 s, addrstr, _, err := RunLocalUDPServer(":0") 635 if err != nil { 636 t.Fatalf("unable to run test server: %v", err) 637 } 638 defer s.Shutdown() 639 640 c := new(Client) 641 m := new(Msg) 642 m.SetQuestion("miek.nl.", TypeTXT) 643 m.Response = false 644 _, _, err = c.Exchange(m, addrstr) 645 if err != nil { 646 t.Fatal("failed to exchange", err) 647 } 648 m.Response = true // this holds up the reply, set short read time out to avoid waiting too long 649 c.ReadTimeout = 100 * time.Millisecond 650 _, _, err = c.Exchange(m, addrstr) 651 if err == nil { 652 t.Fatal("exchanged response message") 653 } 654} 655 656func TestShutdownTCP(t *testing.T) { 657 s, _, fin, err := RunLocalTCPServer(":0") 658 if err != nil { 659 t.Fatalf("unable to run test server: %v", err) 660 } 661 err = s.Shutdown() 662 if err != nil { 663 t.Fatalf("could not shutdown test TCP server, %v", err) 664 } 665 select { 666 case err := <-fin: 667 if err != nil { 668 t.Errorf("error returned from ActivateAndServe, %v", err) 669 } 670 case <-time.After(2 * time.Second): 671 t.Error("could not shutdown test TCP server. Gave up waiting") 672 } 673} 674 675func init() { 676 testShutdownNotify = &sync.Cond{ 677 L: new(sync.Mutex), 678 } 679} 680 681func checkInProgressQueriesAtShutdownServer(t *testing.T, srv *Server, addr string, client *Client) { 682 const requests = 15 // enough to make this interesting? TODO: find a proper value 683 684 var errOnce sync.Once 685 // t.Fail will panic if it's called after the test function has 686 // finished. Burning the sync.Once with a defer will prevent the 687 // handler from calling t.Errorf after we've returned. 688 defer errOnce.Do(func() {}) 689 690 toHandle := int32(requests) 691 HandleFunc("example.com.", func(w ResponseWriter, req *Msg) { 692 defer atomic.AddInt32(&toHandle, -1) 693 694 // Wait until ShutdownContext is called before replying. 695 testShutdownNotify.L.Lock() 696 testShutdownNotify.Wait() 697 testShutdownNotify.L.Unlock() 698 699 m := new(Msg) 700 m.SetReply(req) 701 m.Extra = make([]RR, 1) 702 m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello world"}} 703 704 if err := w.WriteMsg(m); err != nil { 705 errOnce.Do(func() { 706 t.Errorf("ResponseWriter.WriteMsg error: %s", err) 707 }) 708 } 709 }) 710 defer HandleRemove("example.com.") 711 712 client.Timeout = 1 * time.Second 713 714 conns := make([]*Conn, requests) 715 eg := new(errgroup.Group) 716 717 for i := range conns { 718 conn := &conns[i] 719 eg.Go(func() error { 720 var err error 721 *conn, err = client.Dial(addr) 722 return err 723 }) 724 } 725 726 if eg.Wait() != nil { 727 t.Fatalf("client.Dial error: %v", eg.Wait()) 728 } 729 730 m := new(Msg) 731 m.SetQuestion("example.com.", TypeTXT) 732 eg = new(errgroup.Group) 733 734 for _, conn := range conns { 735 conn := conn 736 eg.Go(func() error { 737 conn.SetWriteDeadline(time.Now().Add(client.Timeout)) 738 739 return conn.WriteMsg(m) 740 }) 741 } 742 743 if eg.Wait() != nil { 744 t.Fatalf("conn.WriteMsg error: %v", eg.Wait()) 745 } 746 747 // This sleep is needed to allow time for the requests to 748 // pass from the client through the kernel and back into 749 // the server. Without it, some requests may still be in 750 // the kernel's buffer when ShutdownContext is called. 751 time.Sleep(100 * time.Millisecond) 752 753 eg = new(errgroup.Group) 754 755 for _, conn := range conns { 756 conn := conn 757 eg.Go(func() error { 758 conn.SetReadDeadline(time.Now().Add(client.Timeout)) 759 760 _, err := conn.ReadMsg() 761 return err 762 }) 763 } 764 765 ctx, cancel := context.WithTimeout(context.Background(), client.Timeout) 766 defer cancel() 767 768 if err := srv.ShutdownContext(ctx); err != nil { 769 t.Errorf("could not shutdown test server: %v", err) 770 } 771 772 if left := atomic.LoadInt32(&toHandle); left != 0 { 773 t.Errorf("ShutdownContext returned before %d replies", left) 774 } 775 776 if eg.Wait() != nil { 777 t.Errorf("conn.ReadMsg error: %v", eg.Wait()) 778 } 779 780 srv.lock.RLock() 781 defer srv.lock.RUnlock() 782 if len(srv.conns) != 0 { 783 t.Errorf("TCP connection tracking map not empty after ShutdownContext; map still contains %d connections", len(srv.conns)) 784 } 785} 786 787func TestInProgressQueriesAtShutdownTCP(t *testing.T) { 788 s, addr, _, err := RunLocalTCPServer(":0") 789 if err != nil { 790 t.Fatalf("unable to run test server: %v", err) 791 } 792 793 c := &Client{Net: "tcp"} 794 checkInProgressQueriesAtShutdownServer(t, s, addr, c) 795} 796 797func TestShutdownTLS(t *testing.T) { 798 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 799 if err != nil { 800 t.Fatalf("unable to build certificate: %v", err) 801 } 802 803 config := tls.Config{ 804 Certificates: []tls.Certificate{cert}, 805 } 806 807 s, _, _, err := RunLocalTLSServer(":0", &config) 808 if err != nil { 809 t.Fatalf("unable to run test server: %v", err) 810 } 811 err = s.Shutdown() 812 if err != nil { 813 t.Errorf("could not shutdown test TLS server, %v", err) 814 } 815} 816 817func TestInProgressQueriesAtShutdownTLS(t *testing.T) { 818 cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) 819 if err != nil { 820 t.Fatalf("unable to build certificate: %v", err) 821 } 822 823 config := tls.Config{ 824 Certificates: []tls.Certificate{cert}, 825 } 826 827 s, addr, _, err := RunLocalTLSServer(":0", &config) 828 if err != nil { 829 t.Fatalf("unable to run test server: %v", err) 830 } 831 832 c := &Client{ 833 Net: "tcp-tls", 834 TLSConfig: &tls.Config{ 835 InsecureSkipVerify: true, 836 }, 837 } 838 checkInProgressQueriesAtShutdownServer(t, s, addr, c) 839} 840 841func TestHandlerCloseTCP(t *testing.T) { 842 ln, err := net.Listen("tcp", ":0") 843 if err != nil { 844 panic(err) 845 } 846 addr := ln.Addr().String() 847 848 server := &Server{Addr: addr, Net: "tcp", Listener: ln} 849 850 hname := "testhandlerclosetcp." 851 triggered := make(chan struct{}) 852 HandleFunc(hname, func(w ResponseWriter, r *Msg) { 853 close(triggered) 854 w.Close() 855 }) 856 defer HandleRemove(hname) 857 858 go func() { 859 defer server.Shutdown() 860 c := &Client{Net: "tcp"} 861 m := new(Msg).SetQuestion(hname, 1) 862 tries := 0 863 exchange: 864 _, _, err := c.Exchange(m, addr) 865 if err != nil && err != io.EOF { 866 t.Errorf("exchange failed: %v", err) 867 if tries == 3 { 868 return 869 } 870 time.Sleep(time.Second / 10) 871 tries++ 872 goto exchange 873 } 874 }() 875 if err := server.ActivateAndServe(); err != nil { 876 t.Fatalf("ActivateAndServe failed: %v", err) 877 } 878 select { 879 case <-triggered: 880 default: 881 t.Fatalf("handler never called") 882 } 883} 884 885func TestShutdownUDP(t *testing.T) { 886 s, _, fin, err := RunLocalUDPServer(":0") 887 if err != nil { 888 t.Fatalf("unable to run test server: %v", err) 889 } 890 err = s.Shutdown() 891 if err != nil { 892 t.Errorf("could not shutdown test UDP server, %v", err) 893 } 894 select { 895 case err := <-fin: 896 if err != nil { 897 t.Errorf("error returned from ActivateAndServe, %v", err) 898 } 899 case <-time.After(2 * time.Second): 900 t.Error("could not shutdown test UDP server. Gave up waiting") 901 } 902} 903 904func TestShutdownPacketConn(t *testing.T) { 905 s, _, fin, err := RunLocalPacketConnServer(":0") 906 if err != nil { 907 t.Fatalf("unable to run test server: %v", err) 908 } 909 err = s.Shutdown() 910 if err != nil { 911 t.Errorf("could not shutdown test UDP server, %v", err) 912 } 913 select { 914 case err := <-fin: 915 if err != nil { 916 t.Errorf("error returned from ActivateAndServe, %v", err) 917 } 918 case <-time.After(2 * time.Second): 919 t.Error("could not shutdown test UDP server. Gave up waiting") 920 } 921} 922 923func TestInProgressQueriesAtShutdownUDP(t *testing.T) { 924 s, addr, _, err := RunLocalUDPServer(":0") 925 if err != nil { 926 t.Fatalf("unable to run test server: %v", err) 927 } 928 929 c := &Client{Net: "udp"} 930 checkInProgressQueriesAtShutdownServer(t, s, addr, c) 931} 932 933func TestInProgressQueriesAtShutdownPacketConn(t *testing.T) { 934 s, addr, _, err := RunLocalPacketConnServer(":0") 935 if err != nil { 936 t.Fatalf("unable to run test server: %v", err) 937 } 938 939 c := &Client{Net: "udp"} 940 checkInProgressQueriesAtShutdownServer(t, s, addr, c) 941} 942 943func TestServerStartStopRace(t *testing.T) { 944 var wg sync.WaitGroup 945 for i := 0; i < 10; i++ { 946 wg.Add(1) 947 s, _, _, err := RunLocalUDPServer(":0") 948 if err != nil { 949 t.Fatalf("could not start server: %s", err) 950 } 951 go func() { 952 defer wg.Done() 953 if err := s.Shutdown(); err != nil { 954 t.Errorf("could not stop server: %s", err) 955 } 956 }() 957 } 958 wg.Wait() 959} 960 961func TestServerReuseport(t *testing.T) { 962 if !supportsReusePort { 963 t.Skip("reuseport is not supported") 964 } 965 966 startServer := func(addr string) (*Server, chan error) { 967 wait := make(chan struct{}) 968 srv := &Server{ 969 Net: "udp", 970 Addr: addr, 971 NotifyStartedFunc: func() { close(wait) }, 972 ReusePort: true, 973 } 974 975 fin := make(chan error, 1) 976 go func() { 977 fin <- srv.ListenAndServe() 978 }() 979 980 select { 981 case <-wait: 982 case err := <-fin: 983 t.Fatalf("failed to start server: %v", err) 984 } 985 986 return srv, fin 987 } 988 989 srv1, fin1 := startServer(":0") // :0 is resolved to a random free port by the kernel 990 srv2, fin2 := startServer(srv1.PacketConn.LocalAddr().String()) 991 992 if err := srv1.Shutdown(); err != nil { 993 t.Fatalf("failed to shutdown first server: %v", err) 994 } 995 if err := srv2.Shutdown(); err != nil { 996 t.Fatalf("failed to shutdown second server: %v", err) 997 } 998 999 if err := <-fin1; err != nil { 1000 t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err) 1001 } 1002 if err := <-fin2; err != nil { 1003 t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err) 1004 } 1005} 1006 1007func TestServerRoundtripTsig(t *testing.T) { 1008 secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="} 1009 1010 s, addrstr, _, err := RunLocalUDPServer(":0", func(srv *Server) { 1011 srv.TsigSecret = secret 1012 srv.MsgAcceptFunc = func(dh Header) MsgAcceptAction { 1013 // defaultMsgAcceptFunc does reject UPDATE queries 1014 return MsgAccept 1015 } 1016 }) 1017 if err != nil { 1018 t.Fatalf("unable to run test server: %v", err) 1019 } 1020 defer s.Shutdown() 1021 1022 handlerFired := make(chan struct{}) 1023 HandleFunc("example.com.", func(w ResponseWriter, r *Msg) { 1024 close(handlerFired) 1025 1026 m := new(Msg) 1027 m.SetReply(r) 1028 if r.IsTsig() != nil { 1029 status := w.TsigStatus() 1030 if status == nil { 1031 // *Msg r has an TSIG record and it was validated 1032 m.SetTsig("test.", HmacSHA256, 300, time.Now().Unix()) 1033 } else { 1034 // *Msg r has an TSIG records and it was not validated 1035 t.Errorf("invalid TSIG: %v", status) 1036 } 1037 } else { 1038 t.Error("missing TSIG") 1039 } 1040 if err := w.WriteMsg(m); err != nil { 1041 t.Error("writemsg failed", err) 1042 } 1043 }) 1044 1045 c := new(Client) 1046 m := new(Msg) 1047 m.Opcode = OpcodeUpdate 1048 m.SetQuestion("example.com.", TypeSOA) 1049 m.Ns = []RR{&CNAME{ 1050 Hdr: RR_Header{ 1051 Name: "foo.example.com.", 1052 Rrtype: TypeCNAME, 1053 Class: ClassINET, 1054 Ttl: 300, 1055 }, 1056 Target: "bar.example.com.", 1057 }} 1058 c.TsigSecret = secret 1059 m.SetTsig("test.", HmacSHA256, 300, time.Now().Unix()) 1060 _, _, err = c.Exchange(m, addrstr) 1061 if err != nil { 1062 t.Fatal("failed to exchange", err) 1063 } 1064 select { 1065 case <-handlerFired: 1066 // ok, handler was actually called 1067 default: 1068 t.Error("handler was not called") 1069 } 1070} 1071 1072func TestResponseAfterClose(t *testing.T) { 1073 testError := func(name string, err error) { 1074 t.Helper() 1075 1076 expect := fmt.Sprintf("dns: %s called after Close", name) 1077 if err == nil { 1078 t.Errorf("expected error from %s after Close", name) 1079 } else if err.Error() != expect { 1080 t.Errorf("expected explicit error from %s after Close, expected %q, got %q", name, expect, err) 1081 } 1082 } 1083 1084 rw := &response{ 1085 closed: true, 1086 } 1087 1088 _, err := rw.Write(make([]byte, 2)) 1089 testError("Write", err) 1090 1091 testError("WriteMsg", rw.WriteMsg(new(Msg))) 1092} 1093 1094func TestResponseDoubleClose(t *testing.T) { 1095 rw := &response{ 1096 closed: true, 1097 } 1098 if err, expect := rw.Close(), "dns: connection already closed"; err == nil || err.Error() != expect { 1099 t.Errorf("Close did not return expected: error %q, got: %v", expect, err) 1100 } 1101} 1102 1103type countingConn struct { 1104 net.Conn 1105 writes int 1106} 1107 1108func (c *countingConn) Write(p []byte) (int, error) { 1109 c.writes++ 1110 return len(p), nil 1111} 1112 1113func TestResponseWriteSinglePacket(t *testing.T) { 1114 c := &countingConn{} 1115 rw := &response{ 1116 tcp: c, 1117 } 1118 rw.writer = rw 1119 1120 m := new(Msg) 1121 m.SetQuestion("miek.nl.", TypeTXT) 1122 m.Response = true 1123 err := rw.WriteMsg(m) 1124 1125 if err != nil { 1126 t.Fatalf("failed to write: %v", err) 1127 } 1128 1129 if c.writes != 1 { 1130 t.Fatalf("incorrect number of Write calls") 1131 } 1132} 1133 1134type ExampleFrameLengthWriter struct { 1135 Writer 1136} 1137 1138func (e *ExampleFrameLengthWriter) Write(m []byte) (int, error) { 1139 fmt.Println("writing raw DNS message of length", len(m)) 1140 return e.Writer.Write(m) 1141} 1142 1143func ExampleDecorateWriter() { 1144 // instrument raw DNS message writing 1145 wf := DecorateWriter(func(w Writer) Writer { 1146 return &ExampleFrameLengthWriter{w} 1147 }) 1148 1149 // simple UDP server 1150 pc, err := net.ListenPacket("udp", ":0") 1151 if err != nil { 1152 fmt.Println(err.Error()) 1153 return 1154 } 1155 server := &Server{ 1156 PacketConn: pc, 1157 DecorateWriter: wf, 1158 ReadTimeout: time.Hour, WriteTimeout: time.Hour, 1159 } 1160 1161 waitLock := sync.Mutex{} 1162 waitLock.Lock() 1163 server.NotifyStartedFunc = waitLock.Unlock 1164 defer server.Shutdown() 1165 1166 go func() { 1167 server.ActivateAndServe() 1168 pc.Close() 1169 }() 1170 1171 waitLock.Lock() 1172 1173 HandleFunc("miek.nl.", HelloServer) 1174 1175 c := new(Client) 1176 m := new(Msg) 1177 m.SetQuestion("miek.nl.", TypeTXT) 1178 _, _, err = c.Exchange(m, pc.LocalAddr().String()) 1179 if err != nil { 1180 fmt.Println("failed to exchange", err.Error()) 1181 return 1182 } 1183 // Output: writing raw DNS message of length 56 1184} 1185 1186var ( 1187 // CertPEMBlock is a X509 data used to test TLS servers (used with tls.X509KeyPair) 1188 CertPEMBlock = []byte(`-----BEGIN CERTIFICATE----- 1189MIIDAzCCAeugAwIBAgIRAJFYMkcn+b8dpU15wjf++GgwDQYJKoZIhvcNAQELBQAw 1190EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjAxMDgxMjAzNTNaFw0xNzAxMDcxMjAz 1191NTNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 1192ggEKAoIBAQDXjqO6skvP03k58CNjQggd9G/mt+Wa+xRU+WXiKCCHttawM8x+slq5 1193yfsHCwxlwsGn79HmJqecNqgHb2GWBXAvVVokFDTcC1hUP4+gp2gu9Ny27UHTjlLm 1194O0l/xZ5MN8tfKyYlFw18tXu3fkaPyHj8v/D1RDkuo4ARdFvGSe8TqisbhLk2+9ow 1195xfIGbEM9Fdiw8qByC2+d+FfvzIKz3GfQVwn0VoRom8L6NBIANq1IGrB5JefZB6nv 1196DnfuxkBmY7F1513HKuEJ8KsLWWZWV9OPU4j4I4Rt+WJNlKjbD2srHxyrS2RDsr91 11978nCkNoWVNO3sZq0XkWKecdc921vL4ginAgMBAAGjVDBSMA4GA1UdDwEB/wQEAwIC 1198pDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MBoGA1UdEQQT 1199MBGCCWxvY2FsaG9zdIcEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAGcU3iyLBIVZj 1200aDzSvEDHUd1bnLBl1C58Xu/CyKlPqVU7mLfK0JcgEaYQTSX6fCJVNLbbCrcGLsPJ 1201fbjlBbyeLjTV413fxPVuona62pBFjqdtbli2Qe8FRH2KBdm41JUJGdo+SdsFu7nc 1202BFOcubdw6LLIXvsTvwndKcHWx1rMX709QU1Vn1GAIsbJV/DWI231Jyyb+lxAUx/C 12038vce5uVxiKcGS+g6OjsN3D3TtiEQGSXLh013W6Wsih8td8yMCMZ3w8LQ38br1GUe 1204ahLIgUJ9l6HDguM17R7kGqxNvbElsMUHfTtXXP7UDQUiYXDakg8xDP6n9DCDhJ8Y 1205bSt7OLB7NQ== 1206-----END CERTIFICATE-----`) 1207 1208 // KeyPEMBlock is a X509 data used to test TLS servers (used with tls.X509KeyPair) 1209 KeyPEMBlock = []byte(`-----BEGIN RSA PRIVATE KEY----- 1210MIIEpQIBAAKCAQEA146jurJLz9N5OfAjY0IIHfRv5rflmvsUVPll4iggh7bWsDPM 1211frJaucn7BwsMZcLBp+/R5iannDaoB29hlgVwL1VaJBQ03AtYVD+PoKdoLvTctu1B 1212045S5jtJf8WeTDfLXysmJRcNfLV7t35Gj8h4/L/w9UQ5LqOAEXRbxknvE6orG4S5 1213NvvaMMXyBmxDPRXYsPKgcgtvnfhX78yCs9xn0FcJ9FaEaJvC+jQSADatSBqweSXn 12142Qep7w537sZAZmOxdeddxyrhCfCrC1lmVlfTj1OI+COEbfliTZSo2w9rKx8cq0tk 1215Q7K/dfJwpDaFlTTt7GatF5FinnHXPdtby+IIpwIDAQABAoIBAAJK4RDmPooqTJrC 1216JA41MJLo+5uvjwCT9QZmVKAQHzByUFw1YNJkITTiognUI0CdzqNzmH7jIFs39ZeG 1217proKusO2G6xQjrNcZ4cV2fgyb5g4QHStl0qhs94A+WojduiGm2IaumAgm6Mc5wDv 1218ld6HmknN3Mku/ZCyanVFEIjOVn2WB7ZQLTBs6ZYaebTJG2Xv6p9t2YJW7pPQ9Xce 1219s9ohAWohyM4X/OvfnfnLtQp2YLw/BxwehBsCR5SXM3ibTKpFNtxJC8hIfTuWtxZu 12202ywrmXShYBRB1WgtZt5k04bY/HFncvvcHK3YfI1+w4URKtwdaQgPUQRbVwDwuyBn 1221flfkCJECgYEA/eWt01iEyE/lXkGn6V9lCocUU7lCU6yk5UT8VXVUc5If4KZKPfCk 1222p4zJDOqwn2eM673aWz/mG9mtvAvmnugaGjcaVCyXOp/D/GDmKSoYcvW5B/yjfkLy 1223dK6Yaa5LDRVYlYgyzcdCT5/9Qc626NzFwKCZNI4ncIU8g7ViATRxWJ8CgYEA2Ver 1224vZ0M606sfgC0H3NtwNBxmuJ+lIF5LNp/wDi07lDfxRR1rnZMX5dnxjcpDr/zvm8J 1225WtJJX3xMgqjtHuWKL3yKKony9J5ZPjichSbSbhrzfovgYIRZLxLLDy4MP9L3+CX/ 1226yBXnqMWuSnFX+M5fVGxdDWiYF3V+wmeOv9JvavkCgYEAiXAPDFzaY+R78O3xiu7M 1227r0o3wqqCMPE/wav6O/hrYrQy9VSO08C0IM6g9pEEUwWmzuXSkZqhYWoQFb8Lc/GI 1228T7CMXAxXQLDDUpbRgG79FR3Wr3AewHZU8LyiXHKwxcBMV4WGmsXGK3wbh8fyU1NO 12296NsGk+BvkQVOoK1LBAPzZ1kCgYEAsBSmD8U33T9s4dxiEYTrqyV0lH3g/SFz8ZHH 1230pAyNEPI2iC1ONhyjPWKlcWHpAokiyOqeUpVBWnmSZtzC1qAydsxYB6ShT+sl9BHb 1231RMix/QAauzBJhQhUVJ3OIys0Q1UBDmqCsjCE8SfOT4NKOUnA093C+YT+iyrmmktZ 1232zDCJkckCgYEAndqM5KXGk5xYo+MAA1paZcbTUXwaWwjLU+XSRSSoyBEi5xMtfvUb 12337+a1OMhLwWbuz+pl64wFKrbSUyimMOYQpjVE/1vk/kb99pxbgol27hdKyTH1d+ov 1234kFsxKCqxAnBVGEWAvVZAiiTOxleQFjz5RnL0BQp9Lg2cQe+dvuUmIAA= 1235-----END RSA PRIVATE KEY-----`) 1236) 1237