1// DNS server implementation. 2 3package dns 4 5import ( 6 "context" 7 "crypto/tls" 8 "encoding/binary" 9 "errors" 10 "io" 11 "net" 12 "strings" 13 "sync" 14 "time" 15) 16 17// Default maximum number of TCP queries before we close the socket. 18const maxTCPQueries = 128 19 20// aLongTimeAgo is a non-zero time, far in the past, used for 21// immediate cancelation of network operations. 22var aLongTimeAgo = time.Unix(1, 0) 23 24// Handler is implemented by any value that implements ServeDNS. 25type Handler interface { 26 ServeDNS(w ResponseWriter, r *Msg) 27} 28 29// The HandlerFunc type is an adapter to allow the use of 30// ordinary functions as DNS handlers. If f is a function 31// with the appropriate signature, HandlerFunc(f) is a 32// Handler object that calls f. 33type HandlerFunc func(ResponseWriter, *Msg) 34 35// ServeDNS calls f(w, r). 36func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) { 37 f(w, r) 38} 39 40// A ResponseWriter interface is used by an DNS handler to 41// construct an DNS response. 42type ResponseWriter interface { 43 // LocalAddr returns the net.Addr of the server 44 LocalAddr() net.Addr 45 // RemoteAddr returns the net.Addr of the client that sent the current request. 46 RemoteAddr() net.Addr 47 // WriteMsg writes a reply back to the client. 48 WriteMsg(*Msg) error 49 // Write writes a raw buffer back to the client. 50 Write([]byte) (int, error) 51 // Close closes the connection. 52 Close() error 53 // TsigStatus returns the status of the Tsig. 54 TsigStatus() error 55 // TsigTimersOnly sets the tsig timers only boolean. 56 TsigTimersOnly(bool) 57 // Hijack lets the caller take over the connection. 58 // After a call to Hijack(), the DNS package will not do anything with the connection. 59 Hijack() 60} 61 62// A ConnectionStater interface is used by a DNS Handler to access TLS connection state 63// when available. 64type ConnectionStater interface { 65 ConnectionState() *tls.ConnectionState 66} 67 68type response struct { 69 closed bool // connection has been closed 70 hijacked bool // connection has been hijacked by handler 71 tsigTimersOnly bool 72 tsigStatus error 73 tsigRequestMAC string 74 tsigSecret map[string]string // the tsig secrets 75 udp *net.UDPConn // i/o connection if UDP was used 76 tcp net.Conn // i/o connection if TCP was used 77 udpSession *SessionUDP // oob data to get egress interface right 78 writer Writer // writer to output the raw DNS bits 79} 80 81// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets. 82func HandleFailed(w ResponseWriter, r *Msg) { 83 m := new(Msg) 84 m.SetRcode(r, RcodeServerFailure) 85 // does not matter if this write fails 86 w.WriteMsg(m) 87} 88 89// ListenAndServe Starts a server on address and network specified Invoke handler 90// for incoming queries. 91func ListenAndServe(addr string, network string, handler Handler) error { 92 server := &Server{Addr: addr, Net: network, Handler: handler} 93 return server.ListenAndServe() 94} 95 96// ListenAndServeTLS acts like http.ListenAndServeTLS, more information in 97// http://golang.org/pkg/net/http/#ListenAndServeTLS 98func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { 99 cert, err := tls.LoadX509KeyPair(certFile, keyFile) 100 if err != nil { 101 return err 102 } 103 104 config := tls.Config{ 105 Certificates: []tls.Certificate{cert}, 106 } 107 108 server := &Server{ 109 Addr: addr, 110 Net: "tcp-tls", 111 TLSConfig: &config, 112 Handler: handler, 113 } 114 115 return server.ListenAndServe() 116} 117 118// ActivateAndServe activates a server with a listener from systemd, 119// l and p should not both be non-nil. 120// If both l and p are not nil only p will be used. 121// Invoke handler for incoming queries. 122func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error { 123 server := &Server{Listener: l, PacketConn: p, Handler: handler} 124 return server.ActivateAndServe() 125} 126 127// Writer writes raw DNS messages; each call to Write should send an entire message. 128type Writer interface { 129 io.Writer 130} 131 132// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message. 133type Reader interface { 134 // ReadTCP reads a raw message from a TCP connection. Implementations may alter 135 // connection properties, for example the read-deadline. 136 ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) 137 // ReadUDP reads a raw message from a UDP connection. Implementations may alter 138 // connection properties, for example the read-deadline. 139 ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) 140} 141 142// defaultReader is an adapter for the Server struct that implements the Reader interface 143// using the readTCP and readUDP func of the embedded Server. 144type defaultReader struct { 145 *Server 146} 147 148func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { 149 return dr.readTCP(conn, timeout) 150} 151 152func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { 153 return dr.readUDP(conn, timeout) 154} 155 156// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader. 157// Implementations should never return a nil Reader. 158type DecorateReader func(Reader) Reader 159 160// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer. 161// Implementations should never return a nil Writer. 162type DecorateWriter func(Writer) Writer 163 164// A Server defines parameters for running an DNS server. 165type Server struct { 166 // Address to listen on, ":dns" if empty. 167 Addr string 168 // if "tcp" or "tcp-tls" (DNS over TLS) it will invoke a TCP listener, otherwise an UDP one 169 Net string 170 // TCP Listener to use, this is to aid in systemd's socket activation. 171 Listener net.Listener 172 // TLS connection configuration 173 TLSConfig *tls.Config 174 // UDP "Listener" to use, this is to aid in systemd's socket activation. 175 PacketConn net.PacketConn 176 // Handler to invoke, dns.DefaultServeMux if nil. 177 Handler Handler 178 // Default buffer size to use to read incoming UDP messages. If not set 179 // it defaults to MinMsgSize (512 B). 180 UDPSize int 181 // The net.Conn.SetReadTimeout value for new connections, defaults to 2 * time.Second. 182 ReadTimeout time.Duration 183 // The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second. 184 WriteTimeout time.Duration 185 // TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966). 186 IdleTimeout func() time.Duration 187 // Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2). 188 TsigSecret map[string]string 189 // If NotifyStartedFunc is set it is called once the server has started listening. 190 NotifyStartedFunc func() 191 // DecorateReader is optional, allows customization of the process that reads raw DNS messages. 192 DecorateReader DecorateReader 193 // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. 194 DecorateWriter DecorateWriter 195 // Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1). 196 MaxTCPQueries int 197 // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address. 198 // It is only supported on go1.11+ and when using ListenAndServe. 199 ReusePort bool 200 // AcceptMsgFunc will check the incoming message and will reject it early in the process. 201 // By default DefaultMsgAcceptFunc will be used. 202 MsgAcceptFunc MsgAcceptFunc 203 204 // Shutdown handling 205 lock sync.RWMutex 206 started bool 207 shutdown chan struct{} 208 conns map[net.Conn]struct{} 209 210 // A pool for UDP message buffers. 211 udpPool sync.Pool 212} 213 214func (srv *Server) isStarted() bool { 215 srv.lock.RLock() 216 started := srv.started 217 srv.lock.RUnlock() 218 return started 219} 220 221func makeUDPBuffer(size int) func() interface{} { 222 return func() interface{} { 223 return make([]byte, size) 224 } 225} 226 227func (srv *Server) init() { 228 srv.shutdown = make(chan struct{}) 229 srv.conns = make(map[net.Conn]struct{}) 230 231 if srv.UDPSize == 0 { 232 srv.UDPSize = MinMsgSize 233 } 234 if srv.MsgAcceptFunc == nil { 235 srv.MsgAcceptFunc = DefaultMsgAcceptFunc 236 } 237 if srv.Handler == nil { 238 srv.Handler = DefaultServeMux 239 } 240 241 srv.udpPool.New = makeUDPBuffer(srv.UDPSize) 242} 243 244func unlockOnce(l sync.Locker) func() { 245 var once sync.Once 246 return func() { once.Do(l.Unlock) } 247} 248 249// ListenAndServe starts a nameserver on the configured address in *Server. 250func (srv *Server) ListenAndServe() error { 251 unlock := unlockOnce(&srv.lock) 252 srv.lock.Lock() 253 defer unlock() 254 255 if srv.started { 256 return &Error{err: "server already started"} 257 } 258 259 addr := srv.Addr 260 if addr == "" { 261 addr = ":domain" 262 } 263 264 srv.init() 265 266 switch srv.Net { 267 case "tcp", "tcp4", "tcp6": 268 l, err := listenTCP(srv.Net, addr, srv.ReusePort) 269 if err != nil { 270 return err 271 } 272 srv.Listener = l 273 srv.started = true 274 unlock() 275 return srv.serveTCP(l) 276 case "tcp-tls", "tcp4-tls", "tcp6-tls": 277 if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) { 278 return errors.New("dns: neither Certificates nor GetCertificate set in Config") 279 } 280 network := strings.TrimSuffix(srv.Net, "-tls") 281 l, err := listenTCP(network, addr, srv.ReusePort) 282 if err != nil { 283 return err 284 } 285 l = tls.NewListener(l, srv.TLSConfig) 286 srv.Listener = l 287 srv.started = true 288 unlock() 289 return srv.serveTCP(l) 290 case "udp", "udp4", "udp6": 291 l, err := listenUDP(srv.Net, addr, srv.ReusePort) 292 if err != nil { 293 return err 294 } 295 u := l.(*net.UDPConn) 296 if e := setUDPSocketOptions(u); e != nil { 297 return e 298 } 299 srv.PacketConn = l 300 srv.started = true 301 unlock() 302 return srv.serveUDP(u) 303 } 304 return &Error{err: "bad network"} 305} 306 307// ActivateAndServe starts a nameserver with the PacketConn or Listener 308// configured in *Server. Its main use is to start a server from systemd. 309func (srv *Server) ActivateAndServe() error { 310 unlock := unlockOnce(&srv.lock) 311 srv.lock.Lock() 312 defer unlock() 313 314 if srv.started { 315 return &Error{err: "server already started"} 316 } 317 318 srv.init() 319 320 pConn := srv.PacketConn 321 l := srv.Listener 322 if pConn != nil { 323 // Check PacketConn interface's type is valid and value 324 // is not nil 325 if t, ok := pConn.(*net.UDPConn); ok && t != nil { 326 if e := setUDPSocketOptions(t); e != nil { 327 return e 328 } 329 srv.started = true 330 unlock() 331 return srv.serveUDP(t) 332 } 333 } 334 if l != nil { 335 srv.started = true 336 unlock() 337 return srv.serveTCP(l) 338 } 339 return &Error{err: "bad listeners"} 340} 341 342// Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and 343// ActivateAndServe will return. 344func (srv *Server) Shutdown() error { 345 return srv.ShutdownContext(context.Background()) 346} 347 348// ShutdownContext shuts down a server. After a call to ShutdownContext, 349// ListenAndServe and ActivateAndServe will return. 350// 351// A context.Context may be passed to limit how long to wait for connections 352// to terminate. 353func (srv *Server) ShutdownContext(ctx context.Context) error { 354 srv.lock.Lock() 355 if !srv.started { 356 srv.lock.Unlock() 357 return &Error{err: "server not started"} 358 } 359 360 srv.started = false 361 362 if srv.PacketConn != nil { 363 srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads 364 } 365 366 if srv.Listener != nil { 367 srv.Listener.Close() 368 } 369 370 for rw := range srv.conns { 371 rw.SetReadDeadline(aLongTimeAgo) // Unblock reads 372 } 373 374 srv.lock.Unlock() 375 376 if testShutdownNotify != nil { 377 testShutdownNotify.Broadcast() 378 } 379 380 var ctxErr error 381 select { 382 case <-srv.shutdown: 383 case <-ctx.Done(): 384 ctxErr = ctx.Err() 385 } 386 387 if srv.PacketConn != nil { 388 srv.PacketConn.Close() 389 } 390 391 return ctxErr 392} 393 394var testShutdownNotify *sync.Cond 395 396// getReadTimeout is a helper func to use system timeout if server did not intend to change it. 397func (srv *Server) getReadTimeout() time.Duration { 398 if srv.ReadTimeout != 0 { 399 return srv.ReadTimeout 400 } 401 return dnsTimeout 402} 403 404// serveTCP starts a TCP listener for the server. 405func (srv *Server) serveTCP(l net.Listener) error { 406 defer l.Close() 407 408 if srv.NotifyStartedFunc != nil { 409 srv.NotifyStartedFunc() 410 } 411 412 var wg sync.WaitGroup 413 defer func() { 414 wg.Wait() 415 close(srv.shutdown) 416 }() 417 418 for srv.isStarted() { 419 rw, err := l.Accept() 420 if err != nil { 421 if !srv.isStarted() { 422 return nil 423 } 424 if neterr, ok := err.(net.Error); ok && neterr.Temporary() { 425 continue 426 } 427 return err 428 } 429 srv.lock.Lock() 430 // Track the connection to allow unblocking reads on shutdown. 431 srv.conns[rw] = struct{}{} 432 srv.lock.Unlock() 433 wg.Add(1) 434 go srv.serveTCPConn(&wg, rw) 435 } 436 437 return nil 438} 439 440// serveUDP starts a UDP listener for the server. 441func (srv *Server) serveUDP(l *net.UDPConn) error { 442 defer l.Close() 443 444 if srv.NotifyStartedFunc != nil { 445 srv.NotifyStartedFunc() 446 } 447 448 reader := Reader(defaultReader{srv}) 449 if srv.DecorateReader != nil { 450 reader = srv.DecorateReader(reader) 451 } 452 453 var wg sync.WaitGroup 454 defer func() { 455 wg.Wait() 456 close(srv.shutdown) 457 }() 458 459 rtimeout := srv.getReadTimeout() 460 // deadline is not used here 461 for srv.isStarted() { 462 m, s, err := reader.ReadUDP(l, rtimeout) 463 if err != nil { 464 if !srv.isStarted() { 465 return nil 466 } 467 if netErr, ok := err.(net.Error); ok && netErr.Temporary() { 468 continue 469 } 470 return err 471 } 472 if len(m) < headerSize { 473 if cap(m) == srv.UDPSize { 474 srv.udpPool.Put(m[:srv.UDPSize]) 475 } 476 continue 477 } 478 wg.Add(1) 479 go srv.serveUDPPacket(&wg, m, l, s) 480 } 481 482 return nil 483} 484 485// Serve a new TCP connection. 486func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) { 487 w := &response{tsigSecret: srv.TsigSecret, tcp: rw} 488 if srv.DecorateWriter != nil { 489 w.writer = srv.DecorateWriter(w) 490 } else { 491 w.writer = w 492 } 493 494 reader := Reader(defaultReader{srv}) 495 if srv.DecorateReader != nil { 496 reader = srv.DecorateReader(reader) 497 } 498 499 idleTimeout := tcpIdleTimeout 500 if srv.IdleTimeout != nil { 501 idleTimeout = srv.IdleTimeout() 502 } 503 504 timeout := srv.getReadTimeout() 505 506 limit := srv.MaxTCPQueries 507 if limit == 0 { 508 limit = maxTCPQueries 509 } 510 511 for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ { 512 m, err := reader.ReadTCP(w.tcp, timeout) 513 if err != nil { 514 // TODO(tmthrgd): handle error 515 break 516 } 517 srv.serveDNS(m, w) 518 if w.closed { 519 break // Close() was called 520 } 521 if w.hijacked { 522 break // client will call Close() themselves 523 } 524 // The first read uses the read timeout, the rest use the 525 // idle timeout. 526 timeout = idleTimeout 527 } 528 529 if !w.hijacked { 530 w.Close() 531 } 532 533 srv.lock.Lock() 534 delete(srv.conns, w.tcp) 535 srv.lock.Unlock() 536 537 wg.Done() 538} 539 540// Serve a new UDP request. 541func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) { 542 w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s} 543 if srv.DecorateWriter != nil { 544 w.writer = srv.DecorateWriter(w) 545 } else { 546 w.writer = w 547 } 548 549 srv.serveDNS(m, w) 550 wg.Done() 551} 552 553func (srv *Server) serveDNS(m []byte, w *response) { 554 dh, off, err := unpackMsgHdr(m, 0) 555 if err != nil { 556 // Let client hang, they are sending crap; any reply can be used to amplify. 557 return 558 } 559 560 req := new(Msg) 561 req.setHdr(dh) 562 563 switch action := srv.MsgAcceptFunc(dh); action { 564 case MsgAccept: 565 if req.unpack(dh, m, off) == nil { 566 break 567 } 568 569 fallthrough 570 case MsgReject, MsgRejectNotImplemented: 571 opcode := req.Opcode 572 req.SetRcodeFormatError(req) 573 req.Zero = false 574 if action == MsgRejectNotImplemented { 575 req.Opcode = opcode 576 req.Rcode = RcodeNotImplemented 577 } 578 579 // Are we allowed to delete any OPT records here? 580 req.Ns, req.Answer, req.Extra = nil, nil, nil 581 582 w.WriteMsg(req) 583 fallthrough 584 case MsgIgnore: 585 if w.udp != nil && cap(m) == srv.UDPSize { 586 srv.udpPool.Put(m[:srv.UDPSize]) 587 } 588 589 return 590 } 591 592 w.tsigStatus = nil 593 if w.tsigSecret != nil { 594 if t := req.IsTsig(); t != nil { 595 if secret, ok := w.tsigSecret[t.Hdr.Name]; ok { 596 w.tsigStatus = TsigVerify(m, secret, "", false) 597 } else { 598 w.tsigStatus = ErrSecret 599 } 600 w.tsigTimersOnly = false 601 w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC 602 } 603 } 604 605 if w.udp != nil && cap(m) == srv.UDPSize { 606 srv.udpPool.Put(m[:srv.UDPSize]) 607 } 608 609 srv.Handler.ServeDNS(w, req) // Writes back to the client 610} 611 612func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { 613 // If we race with ShutdownContext, the read deadline may 614 // have been set in the distant past to unblock the read 615 // below. We must not override it, otherwise we may block 616 // ShutdownContext. 617 srv.lock.RLock() 618 if srv.started { 619 conn.SetReadDeadline(time.Now().Add(timeout)) 620 } 621 srv.lock.RUnlock() 622 623 var length uint16 624 if err := binary.Read(conn, binary.BigEndian, &length); err != nil { 625 return nil, err 626 } 627 628 m := make([]byte, length) 629 if _, err := io.ReadFull(conn, m); err != nil { 630 return nil, err 631 } 632 633 return m, nil 634} 635 636func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { 637 srv.lock.RLock() 638 if srv.started { 639 // See the comment in readTCP above. 640 conn.SetReadDeadline(time.Now().Add(timeout)) 641 } 642 srv.lock.RUnlock() 643 644 m := srv.udpPool.Get().([]byte) 645 n, s, err := ReadFromSessionUDP(conn, m) 646 if err != nil { 647 srv.udpPool.Put(m) 648 return nil, nil, err 649 } 650 m = m[:n] 651 return m, s, nil 652} 653 654// WriteMsg implements the ResponseWriter.WriteMsg method. 655func (w *response) WriteMsg(m *Msg) (err error) { 656 if w.closed { 657 return &Error{err: "WriteMsg called after Close"} 658 } 659 660 var data []byte 661 if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) 662 if t := m.IsTsig(); t != nil { 663 data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) 664 if err != nil { 665 return err 666 } 667 _, err = w.writer.Write(data) 668 return err 669 } 670 } 671 data, err = m.Pack() 672 if err != nil { 673 return err 674 } 675 _, err = w.writer.Write(data) 676 return err 677} 678 679// Write implements the ResponseWriter.Write method. 680func (w *response) Write(m []byte) (int, error) { 681 if w.closed { 682 return 0, &Error{err: "Write called after Close"} 683 } 684 685 switch { 686 case w.udp != nil: 687 return WriteToSessionUDP(w.udp, m, w.udpSession) 688 case w.tcp != nil: 689 if len(m) > MaxMsgSize { 690 return 0, &Error{err: "message too large"} 691 } 692 693 l := make([]byte, 2) 694 binary.BigEndian.PutUint16(l, uint16(len(m))) 695 696 n, err := (&net.Buffers{l, m}).WriteTo(w.tcp) 697 return int(n), err 698 default: 699 panic("dns: internal error: udp and tcp both nil") 700 } 701} 702 703// LocalAddr implements the ResponseWriter.LocalAddr method. 704func (w *response) LocalAddr() net.Addr { 705 switch { 706 case w.udp != nil: 707 return w.udp.LocalAddr() 708 case w.tcp != nil: 709 return w.tcp.LocalAddr() 710 default: 711 panic("dns: internal error: udp and tcp both nil") 712 } 713} 714 715// RemoteAddr implements the ResponseWriter.RemoteAddr method. 716func (w *response) RemoteAddr() net.Addr { 717 switch { 718 case w.udpSession != nil: 719 return w.udpSession.RemoteAddr() 720 case w.tcp != nil: 721 return w.tcp.RemoteAddr() 722 default: 723 panic("dns: internal error: udpSession and tcp both nil") 724 } 725} 726 727// TsigStatus implements the ResponseWriter.TsigStatus method. 728func (w *response) TsigStatus() error { return w.tsigStatus } 729 730// TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method. 731func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b } 732 733// Hijack implements the ResponseWriter.Hijack method. 734func (w *response) Hijack() { w.hijacked = true } 735 736// Close implements the ResponseWriter.Close method 737func (w *response) Close() error { 738 if w.closed { 739 return &Error{err: "connection already closed"} 740 } 741 w.closed = true 742 743 switch { 744 case w.udp != nil: 745 // Can't close the udp conn, as that is actually the listener. 746 return nil 747 case w.tcp != nil: 748 return w.tcp.Close() 749 default: 750 panic("dns: internal error: udp and tcp both nil") 751 } 752} 753 754// ConnectionState() implements the ConnectionStater.ConnectionState() interface. 755func (w *response) ConnectionState() *tls.ConnectionState { 756 type tlsConnectionStater interface { 757 ConnectionState() tls.ConnectionState 758 } 759 if v, ok := w.tcp.(tlsConnectionStater); ok { 760 t := v.ConnectionState() 761 return &t 762 } 763 return nil 764} 765