1// DNS server implementation. 2 3package dns 4 5import ( 6 "bytes" 7 "crypto/tls" 8 "encoding/binary" 9 "io" 10 "net" 11 "sync" 12 "time" 13) 14 15// Maximum number of TCP queries before we close the socket. 16const maxTCPQueries = 128 17 18// Handler is implemented by any value that implements ServeDNS. 19type Handler interface { 20 ServeDNS(w ResponseWriter, r *Msg) 21} 22 23// A ResponseWriter interface is used by an DNS handler to 24// construct an DNS response. 25type ResponseWriter interface { 26 // LocalAddr returns the net.Addr of the server 27 LocalAddr() net.Addr 28 // RemoteAddr returns the net.Addr of the client that sent the current request. 29 RemoteAddr() net.Addr 30 // WriteMsg writes a reply back to the client. 31 WriteMsg(*Msg) error 32 // Write writes a raw buffer back to the client. 33 Write([]byte) (int, error) 34 // Close closes the connection. 35 Close() error 36 // TsigStatus returns the status of the Tsig. 37 TsigStatus() error 38 // TsigTimersOnly sets the tsig timers only boolean. 39 TsigTimersOnly(bool) 40 // Hijack lets the caller take over the connection. 41 // After a call to Hijack(), the DNS package will not do anything with the connection. 42 Hijack() 43} 44 45type response struct { 46 hijacked bool // connection has been hijacked by handler 47 tsigStatus error 48 tsigTimersOnly bool 49 tsigRequestMAC string 50 tsigSecret map[string]string // the tsig secrets 51 udp *net.UDPConn // i/o connection if UDP was used 52 tcp net.Conn // i/o connection if TCP was used 53 udpSession *SessionUDP // oob data to get egress interface right 54 remoteAddr net.Addr // address of the client 55 writer Writer // writer to output the raw DNS bits 56} 57 58// ServeMux is an DNS request multiplexer. It matches the 59// zone name of each incoming request against a list of 60// registered patterns add calls the handler for the pattern 61// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning 62// that queries for the DS record are redirected to the parent zone (if that 63// is also registered), otherwise the child gets the query. 64// ServeMux is also safe for concurrent access from multiple goroutines. 65type ServeMux struct { 66 z map[string]Handler 67 m *sync.RWMutex 68} 69 70// NewServeMux allocates and returns a new ServeMux. 71func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} } 72 73// DefaultServeMux is the default ServeMux used by Serve. 74var DefaultServeMux = NewServeMux() 75 76// The HandlerFunc type is an adapter to allow the use of 77// ordinary functions as DNS handlers. If f is a function 78// with the appropriate signature, HandlerFunc(f) is a 79// Handler object that calls f. 80type HandlerFunc func(ResponseWriter, *Msg) 81 82// ServeDNS calls f(w, r). 83func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) { 84 f(w, r) 85} 86 87// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets. 88func HandleFailed(w ResponseWriter, r *Msg) { 89 m := new(Msg) 90 m.SetRcode(r, RcodeServerFailure) 91 // does not matter if this write fails 92 w.WriteMsg(m) 93} 94 95func failedHandler() Handler { return HandlerFunc(HandleFailed) } 96 97// ListenAndServe Starts a server on address and network specified Invoke handler 98// for incoming queries. 99func ListenAndServe(addr string, network string, handler Handler) error { 100 server := &Server{Addr: addr, Net: network, Handler: handler} 101 return server.ListenAndServe() 102} 103 104// ListenAndServeTLS acts like http.ListenAndServeTLS, more information in 105// http://golang.org/pkg/net/http/#ListenAndServeTLS 106func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error { 107 cert, err := tls.LoadX509KeyPair(certFile, keyFile) 108 if err != nil { 109 return err 110 } 111 112 config := tls.Config{ 113 Certificates: []tls.Certificate{cert}, 114 } 115 116 server := &Server{ 117 Addr: addr, 118 Net: "tcp-tls", 119 TLSConfig: &config, 120 Handler: handler, 121 } 122 123 return server.ListenAndServe() 124} 125 126// ActivateAndServe activates a server with a listener from systemd, 127// l and p should not both be non-nil. 128// If both l and p are not nil only p will be used. 129// Invoke handler for incoming queries. 130func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error { 131 server := &Server{Listener: l, PacketConn: p, Handler: handler} 132 return server.ActivateAndServe() 133} 134 135func (mux *ServeMux) match(q string, t uint16) Handler { 136 mux.m.RLock() 137 defer mux.m.RUnlock() 138 var handler Handler 139 b := make([]byte, len(q)) // worst case, one label of length q 140 off := 0 141 end := false 142 for { 143 l := len(q[off:]) 144 for i := 0; i < l; i++ { 145 b[i] = q[off+i] 146 if b[i] >= 'A' && b[i] <= 'Z' { 147 b[i] |= ('a' - 'A') 148 } 149 } 150 if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key 151 if t != TypeDS { 152 return h 153 } 154 // Continue for DS to see if we have a parent too, if so delegeate to the parent 155 handler = h 156 } 157 off, end = NextLabel(q, off) 158 if end { 159 break 160 } 161 } 162 // Wildcard match, if we have found nothing try the root zone as a last resort. 163 if h, ok := mux.z["."]; ok { 164 return h 165 } 166 return handler 167} 168 169// Handle adds a handler to the ServeMux for pattern. 170func (mux *ServeMux) Handle(pattern string, handler Handler) { 171 if pattern == "" { 172 panic("dns: invalid pattern " + pattern) 173 } 174 mux.m.Lock() 175 mux.z[Fqdn(pattern)] = handler 176 mux.m.Unlock() 177} 178 179// HandleFunc adds a handler function to the ServeMux for pattern. 180func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { 181 mux.Handle(pattern, HandlerFunc(handler)) 182} 183 184// HandleRemove deregistrars the handler specific for pattern from the ServeMux. 185func (mux *ServeMux) HandleRemove(pattern string) { 186 if pattern == "" { 187 panic("dns: invalid pattern " + pattern) 188 } 189 mux.m.Lock() 190 delete(mux.z, Fqdn(pattern)) 191 mux.m.Unlock() 192} 193 194// ServeDNS dispatches the request to the handler whose 195// pattern most closely matches the request message. If DefaultServeMux 196// is used the correct thing for DS queries is done: a possible parent 197// is sought. 198// If no handler is found a standard SERVFAIL message is returned 199// If the request message does not have exactly one question in the 200// question section a SERVFAIL is returned, unlesss Unsafe is true. 201func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) { 202 var h Handler 203 if len(request.Question) < 1 { // allow more than one question 204 h = failedHandler() 205 } else { 206 if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil { 207 h = failedHandler() 208 } 209 } 210 h.ServeDNS(w, request) 211} 212 213// Handle registers the handler with the given pattern 214// in the DefaultServeMux. The documentation for 215// ServeMux explains how patterns are matched. 216func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } 217 218// HandleRemove deregisters the handle with the given pattern 219// in the DefaultServeMux. 220func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) } 221 222// HandleFunc registers the handler function with the given pattern 223// in the DefaultServeMux. 224func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { 225 DefaultServeMux.HandleFunc(pattern, handler) 226} 227 228// Writer writes raw DNS messages; each call to Write should send an entire message. 229type Writer interface { 230 io.Writer 231} 232 233// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message. 234type Reader interface { 235 // ReadTCP reads a raw message from a TCP connection. Implementations may alter 236 // connection properties, for example the read-deadline. 237 ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) 238 // ReadUDP reads a raw message from a UDP connection. Implementations may alter 239 // connection properties, for example the read-deadline. 240 ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) 241} 242 243// defaultReader is an adapter for the Server struct that implements the Reader interface 244// using the readTCP and readUDP func of the embedded Server. 245type defaultReader struct { 246 *Server 247} 248 249func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { 250 return dr.readTCP(conn, timeout) 251} 252 253func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { 254 return dr.readUDP(conn, timeout) 255} 256 257// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader. 258// Implementations should never return a nil Reader. 259type DecorateReader func(Reader) Reader 260 261// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer. 262// Implementations should never return a nil Writer. 263type DecorateWriter func(Writer) Writer 264 265// A Server defines parameters for running an DNS server. 266type Server struct { 267 // Address to listen on, ":dns" if empty. 268 Addr string 269 // if "tcp" or "tcp-tls" (DNS over TLS) it will invoke a TCP listener, otherwise an UDP one 270 Net string 271 // TCP Listener to use, this is to aid in systemd's socket activation. 272 Listener net.Listener 273 // TLS connection configuration 274 TLSConfig *tls.Config 275 // UDP "Listener" to use, this is to aid in systemd's socket activation. 276 PacketConn net.PacketConn 277 // Handler to invoke, dns.DefaultServeMux if nil. 278 Handler Handler 279 // Default buffer size to use to read incoming UDP messages. If not set 280 // it defaults to MinMsgSize (512 B). 281 UDPSize int 282 // The net.Conn.SetReadTimeout value for new connections, defaults to 2 * time.Second. 283 ReadTimeout time.Duration 284 // The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second. 285 WriteTimeout time.Duration 286 // TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966). 287 IdleTimeout func() time.Duration 288 // Secret(s) for Tsig map[<zonename>]<base64 secret>. 289 TsigSecret map[string]string 290 // Unsafe instructs the server to disregard any sanity checks and directly hand the message to 291 // the handler. It will specifically not check if the query has the QR bit not set. 292 Unsafe bool 293 // If NotifyStartedFunc is set it is called once the server has started listening. 294 NotifyStartedFunc func() 295 // DecorateReader is optional, allows customization of the process that reads raw DNS messages. 296 DecorateReader DecorateReader 297 // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. 298 DecorateWriter DecorateWriter 299 300 // Graceful shutdown handling 301 302 inFlight sync.WaitGroup 303 304 lock sync.RWMutex 305 started bool 306} 307 308// ListenAndServe starts a nameserver on the configured address in *Server. 309func (srv *Server) ListenAndServe() error { 310 srv.lock.Lock() 311 defer srv.lock.Unlock() 312 if srv.started { 313 return &Error{err: "server already started"} 314 } 315 addr := srv.Addr 316 if addr == "" { 317 addr = ":domain" 318 } 319 if srv.UDPSize == 0 { 320 srv.UDPSize = MinMsgSize 321 } 322 switch srv.Net { 323 case "tcp", "tcp4", "tcp6": 324 a, err := net.ResolveTCPAddr(srv.Net, addr) 325 if err != nil { 326 return err 327 } 328 l, err := net.ListenTCP(srv.Net, a) 329 if err != nil { 330 return err 331 } 332 srv.Listener = l 333 srv.started = true 334 srv.lock.Unlock() 335 err = srv.serveTCP(l) 336 srv.lock.Lock() // to satisfy the defer at the top 337 return err 338 case "tcp-tls", "tcp4-tls", "tcp6-tls": 339 network := "tcp" 340 if srv.Net == "tcp4-tls" { 341 network = "tcp4" 342 } else if srv.Net == "tcp6" { 343 network = "tcp6" 344 } 345 346 l, err := tls.Listen(network, addr, srv.TLSConfig) 347 if err != nil { 348 return err 349 } 350 srv.Listener = l 351 srv.started = true 352 srv.lock.Unlock() 353 err = srv.serveTCP(l) 354 srv.lock.Lock() // to satisfy the defer at the top 355 return err 356 case "udp", "udp4", "udp6": 357 a, err := net.ResolveUDPAddr(srv.Net, addr) 358 if err != nil { 359 return err 360 } 361 l, err := net.ListenUDP(srv.Net, a) 362 if err != nil { 363 return err 364 } 365 if e := setUDPSocketOptions(l); e != nil { 366 return e 367 } 368 srv.PacketConn = l 369 srv.started = true 370 srv.lock.Unlock() 371 err = srv.serveUDP(l) 372 srv.lock.Lock() // to satisfy the defer at the top 373 return err 374 } 375 return &Error{err: "bad network"} 376} 377 378// ActivateAndServe starts a nameserver with the PacketConn or Listener 379// configured in *Server. Its main use is to start a server from systemd. 380func (srv *Server) ActivateAndServe() error { 381 srv.lock.Lock() 382 defer srv.lock.Unlock() 383 if srv.started { 384 return &Error{err: "server already started"} 385 } 386 pConn := srv.PacketConn 387 l := srv.Listener 388 if pConn != nil { 389 if srv.UDPSize == 0 { 390 srv.UDPSize = MinMsgSize 391 } 392 // Check PacketConn interface's type is valid and value 393 // is not nil 394 if t, ok := pConn.(*net.UDPConn); ok && t != nil { 395 if e := setUDPSocketOptions(t); e != nil { 396 return e 397 } 398 srv.started = true 399 srv.lock.Unlock() 400 e := srv.serveUDP(t) 401 srv.lock.Lock() // to satisfy the defer at the top 402 return e 403 } 404 } 405 if l != nil { 406 srv.started = true 407 srv.lock.Unlock() 408 e := srv.serveTCP(l) 409 srv.lock.Lock() // to satisfy the defer at the top 410 return e 411 } 412 return &Error{err: "bad listeners"} 413} 414 415// Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and 416// ActivateAndServe will return. All in progress queries are completed before the server 417// is taken down. If the Shutdown is taking longer than the reading timeout an error 418// is returned. 419func (srv *Server) Shutdown() error { 420 srv.lock.Lock() 421 if !srv.started { 422 srv.lock.Unlock() 423 return &Error{err: "server not started"} 424 } 425 srv.started = false 426 srv.lock.Unlock() 427 428 if srv.PacketConn != nil { 429 srv.PacketConn.Close() 430 } 431 if srv.Listener != nil { 432 srv.Listener.Close() 433 } 434 435 fin := make(chan bool) 436 go func() { 437 srv.inFlight.Wait() 438 fin <- true 439 }() 440 441 select { 442 case <-time.After(srv.getReadTimeout()): 443 return &Error{err: "server shutdown is pending"} 444 case <-fin: 445 return nil 446 } 447} 448 449// getReadTimeout is a helper func to use system timeout if server did not intend to change it. 450func (srv *Server) getReadTimeout() time.Duration { 451 rtimeout := dnsTimeout 452 if srv.ReadTimeout != 0 { 453 rtimeout = srv.ReadTimeout 454 } 455 return rtimeout 456} 457 458// serveTCP starts a TCP listener for the server. 459// Each request is handled in a separate goroutine. 460func (srv *Server) serveTCP(l net.Listener) error { 461 defer l.Close() 462 463 if srv.NotifyStartedFunc != nil { 464 srv.NotifyStartedFunc() 465 } 466 467 reader := Reader(&defaultReader{srv}) 468 if srv.DecorateReader != nil { 469 reader = srv.DecorateReader(reader) 470 } 471 472 handler := srv.Handler 473 if handler == nil { 474 handler = DefaultServeMux 475 } 476 rtimeout := srv.getReadTimeout() 477 // deadline is not used here 478 for { 479 rw, err := l.Accept() 480 if err != nil { 481 if neterr, ok := err.(net.Error); ok && neterr.Temporary() { 482 continue 483 } 484 return err 485 } 486 m, err := reader.ReadTCP(rw, rtimeout) 487 srv.lock.RLock() 488 if !srv.started { 489 srv.lock.RUnlock() 490 return nil 491 } 492 srv.lock.RUnlock() 493 if err != nil { 494 continue 495 } 496 srv.inFlight.Add(1) 497 go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) 498 } 499} 500 501// serveUDP starts a UDP listener for the server. 502// Each request is handled in a separate goroutine. 503func (srv *Server) serveUDP(l *net.UDPConn) error { 504 defer l.Close() 505 506 if srv.NotifyStartedFunc != nil { 507 srv.NotifyStartedFunc() 508 } 509 510 reader := Reader(&defaultReader{srv}) 511 if srv.DecorateReader != nil { 512 reader = srv.DecorateReader(reader) 513 } 514 515 handler := srv.Handler 516 if handler == nil { 517 handler = DefaultServeMux 518 } 519 rtimeout := srv.getReadTimeout() 520 // deadline is not used here 521 for { 522 m, s, err := reader.ReadUDP(l, rtimeout) 523 srv.lock.RLock() 524 if !srv.started { 525 srv.lock.RUnlock() 526 return nil 527 } 528 srv.lock.RUnlock() 529 if err != nil { 530 continue 531 } 532 srv.inFlight.Add(1) 533 go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) 534 } 535} 536 537// Serve a new connection. 538func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { 539 defer srv.inFlight.Done() 540 541 w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} 542 if srv.DecorateWriter != nil { 543 w.writer = srv.DecorateWriter(w) 544 } else { 545 w.writer = w 546 } 547 548 q := 0 // counter for the amount of TCP queries we get 549 550 reader := Reader(&defaultReader{srv}) 551 if srv.DecorateReader != nil { 552 reader = srv.DecorateReader(reader) 553 } 554Redo: 555 req := new(Msg) 556 err := req.Unpack(m) 557 if err != nil { // Send a FormatError back 558 x := new(Msg) 559 x.SetRcodeFormatError(req) 560 w.WriteMsg(x) 561 goto Exit 562 } 563 if !srv.Unsafe && req.Response { 564 goto Exit 565 } 566 567 w.tsigStatus = nil 568 if w.tsigSecret != nil { 569 if t := req.IsTsig(); t != nil { 570 secret := t.Hdr.Name 571 if _, ok := w.tsigSecret[secret]; !ok { 572 w.tsigStatus = ErrKeyAlg 573 } 574 w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false) 575 w.tsigTimersOnly = false 576 w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC 577 } 578 } 579 h.ServeDNS(w, req) // Writes back to the client 580 581Exit: 582 if w.tcp == nil { 583 return 584 } 585 // TODO(miek): make this number configurable? 586 if q > maxTCPQueries { // close socket after this many queries 587 w.Close() 588 return 589 } 590 591 if w.hijacked { 592 return // client calls Close() 593 } 594 if u != nil { // UDP, "close" and return 595 w.Close() 596 return 597 } 598 idleTimeout := tcpIdleTimeout 599 if srv.IdleTimeout != nil { 600 idleTimeout = srv.IdleTimeout() 601 } 602 m, err = reader.ReadTCP(w.tcp, idleTimeout) 603 if err == nil { 604 q++ 605 goto Redo 606 } 607 w.Close() 608 return 609} 610 611func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { 612 conn.SetReadDeadline(time.Now().Add(timeout)) 613 l := make([]byte, 2) 614 n, err := conn.Read(l) 615 if err != nil || n != 2 { 616 if err != nil { 617 return nil, err 618 } 619 return nil, ErrShortRead 620 } 621 length := binary.BigEndian.Uint16(l) 622 if length == 0 { 623 return nil, ErrShortRead 624 } 625 m := make([]byte, int(length)) 626 n, err = conn.Read(m[:int(length)]) 627 if err != nil || n == 0 { 628 if err != nil { 629 return nil, err 630 } 631 return nil, ErrShortRead 632 } 633 i := n 634 for i < int(length) { 635 j, err := conn.Read(m[i:int(length)]) 636 if err != nil { 637 return nil, err 638 } 639 i += j 640 } 641 n = i 642 m = m[:n] 643 return m, nil 644} 645 646func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { 647 conn.SetReadDeadline(time.Now().Add(timeout)) 648 m := make([]byte, srv.UDPSize) 649 n, s, err := ReadFromSessionUDP(conn, m) 650 if err != nil || n == 0 { 651 if err != nil { 652 return nil, nil, err 653 } 654 return nil, nil, ErrShortRead 655 } 656 m = m[:n] 657 return m, s, nil 658} 659 660// WriteMsg implements the ResponseWriter.WriteMsg method. 661func (w *response) WriteMsg(m *Msg) (err error) { 662 var data []byte 663 if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) 664 if t := m.IsTsig(); t != nil { 665 data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly) 666 if err != nil { 667 return err 668 } 669 _, err = w.writer.Write(data) 670 return err 671 } 672 } 673 data, err = m.Pack() 674 if err != nil { 675 return err 676 } 677 _, err = w.writer.Write(data) 678 return err 679} 680 681// Write implements the ResponseWriter.Write method. 682func (w *response) Write(m []byte) (int, error) { 683 switch { 684 case w.udp != nil: 685 n, err := WriteToSessionUDP(w.udp, m, w.udpSession) 686 return n, err 687 case w.tcp != nil: 688 lm := len(m) 689 if lm < 2 { 690 return 0, io.ErrShortBuffer 691 } 692 if lm > MaxMsgSize { 693 return 0, &Error{err: "message too large"} 694 } 695 l := make([]byte, 2, 2+lm) 696 binary.BigEndian.PutUint16(l, uint16(lm)) 697 m = append(l, m...) 698 699 n, err := io.Copy(w.tcp, bytes.NewReader(m)) 700 return int(n), err 701 } 702 panic("not reached") 703} 704 705// LocalAddr implements the ResponseWriter.LocalAddr method. 706func (w *response) LocalAddr() net.Addr { 707 if w.tcp != nil { 708 return w.tcp.LocalAddr() 709 } 710 return w.udp.LocalAddr() 711} 712 713// RemoteAddr implements the ResponseWriter.RemoteAddr method. 714func (w *response) RemoteAddr() net.Addr { return w.remoteAddr } 715 716// TsigStatus implements the ResponseWriter.TsigStatus method. 717func (w *response) TsigStatus() error { return w.tsigStatus } 718 719// TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method. 720func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b } 721 722// Hijack implements the ResponseWriter.Hijack method. 723func (w *response) Hijack() { w.hijacked = true } 724 725// Close implements the ResponseWriter.Close method 726func (w *response) Close() error { 727 // Can't close the udp conn, as that is actually the listener. 728 if w.tcp != nil { 729 e := w.tcp.Close() 730 w.tcp = nil 731 return e 732 } 733 return nil 734} 735