1// DNS server implementation.
2
3package dns
4
5import (
6	"bytes"
7	"crypto/tls"
8	"encoding/binary"
9	"io"
10	"net"
11	"sync"
12	"time"
13)
14
// maxTCPQueries bounds the number of queries served on one TCP connection:
// serve closes the socket once its per-connection query counter exceeds
// this value.
const maxTCPQueries = 128
17
// Handler is implemented by any value that implements ServeDNS.
// ServeDNS receives the request message r and a ResponseWriter w through
// which a reply can be sent.
type Handler interface {
	ServeDNS(w ResponseWriter, r *Msg)
}
22
// A ResponseWriter interface is used by a DNS handler to
// construct a DNS response.
type ResponseWriter interface {
	// LocalAddr returns the net.Addr of the server.
	LocalAddr() net.Addr
	// RemoteAddr returns the net.Addr of the client that sent the current request.
	RemoteAddr() net.Addr
	// WriteMsg writes a reply back to the client.
	WriteMsg(*Msg) error
	// Write writes a raw buffer back to the client.
	Write([]byte) (int, error)
	// Close closes the connection.
	Close() error
	// TsigStatus returns the status of the Tsig.
	TsigStatus() error
	// TsigTimersOnly sets the tsig timers only boolean.
	TsigTimersOnly(bool)
	// Hijack lets the caller take over the connection.
	// After a call to Hijack(), the DNS package will not do anything with the connection.
	Hijack()
}
44
// response is the ResponseWriter that Server.serve hands to handlers.
// Exactly one of udp or tcp is expected to be set, depending on the
// transport the request arrived on.
type response struct {
	hijacked       bool // connection has been hijacked by handler
	tsigStatus     error
	tsigTimersOnly bool
	tsigRequestMAC string
	tsigSecret     map[string]string // the tsig secrets
	udp            *net.UDPConn      // i/o connection if UDP was used
	tcp            net.Conn          // i/o connection if TCP was used
	udpSession     *SessionUDP       // oob data to get egress interface right
	remoteAddr     net.Addr          // address of the client
	writer         Writer            // writer to output the raw DNS bits
}
57
// ServeMux is a DNS request multiplexer. It matches the
// zone name of each incoming request against a list of
// registered patterns and calls the handler for the pattern
// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
// that queries for the DS record are redirected to the parent zone (if that
// is also registered), otherwise the child gets the query.
// ServeMux is also safe for concurrent access from multiple goroutines.
type ServeMux struct {
	z map[string]Handler // registered handlers, keyed by fully qualified pattern
	m *sync.RWMutex      // guards z
}
69
70// NewServeMux allocates and returns a new ServeMux.
71func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
72
// DefaultServeMux is the default ServeMux. It is the target of the
// package-level Handle, HandleFunc and HandleRemove functions, and the
// server falls back to it when Server.Handler is nil.
var DefaultServeMux = NewServeMux()
75
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
81
// ServeDNS calls f(w, r); this method is what makes HandlerFunc
// satisfy the Handler interface.
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
	f(w, r)
}
86
87// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
88func HandleFailed(w ResponseWriter, r *Msg) {
89	m := new(Msg)
90	m.SetRcode(r, RcodeServerFailure)
91	// does not matter if this write fails
92	w.WriteMsg(m)
93}
94
95func failedHandler() Handler { return HandlerFunc(HandleFailed) }
96
97// ListenAndServe Starts a server on address and network specified Invoke handler
98// for incoming queries.
99func ListenAndServe(addr string, network string, handler Handler) error {
100	server := &Server{Addr: addr, Net: network, Handler: handler}
101	return server.ListenAndServe()
102}
103
104// ListenAndServeTLS acts like http.ListenAndServeTLS, more information in
105// http://golang.org/pkg/net/http/#ListenAndServeTLS
106func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
107	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
108	if err != nil {
109		return err
110	}
111
112	config := tls.Config{
113		Certificates: []tls.Certificate{cert},
114	}
115
116	server := &Server{
117		Addr:      addr,
118		Net:       "tcp-tls",
119		TLSConfig: &config,
120		Handler:   handler,
121	}
122
123	return server.ListenAndServe()
124}
125
126// ActivateAndServe activates a server with a listener from systemd,
127// l and p should not both be non-nil.
128// If both l and p are not nil only p will be used.
129// Invoke handler for incoming queries.
130func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
131	server := &Server{Listener: l, PacketConn: p, Handler: handler}
132	return server.ActivateAndServe()
133}
134
135func (mux *ServeMux) match(q string, t uint16) Handler {
136	mux.m.RLock()
137	defer mux.m.RUnlock()
138	var handler Handler
139	b := make([]byte, len(q)) // worst case, one label of length q
140	off := 0
141	end := false
142	for {
143		l := len(q[off:])
144		for i := 0; i < l; i++ {
145			b[i] = q[off+i]
146			if b[i] >= 'A' && b[i] <= 'Z' {
147				b[i] |= ('a' - 'A')
148			}
149		}
150		if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key
151			if t != TypeDS {
152				return h
153			}
154			// Continue for DS to see if we have a parent too, if so delegeate to the parent
155			handler = h
156		}
157		off, end = NextLabel(q, off)
158		if end {
159			break
160		}
161	}
162	// Wildcard match, if we have found nothing try the root zone as a last resort.
163	if h, ok := mux.z["."]; ok {
164		return h
165	}
166	return handler
167}
168
169// Handle adds a handler to the ServeMux for pattern.
170func (mux *ServeMux) Handle(pattern string, handler Handler) {
171	if pattern == "" {
172		panic("dns: invalid pattern " + pattern)
173	}
174	mux.m.Lock()
175	mux.z[Fqdn(pattern)] = handler
176	mux.m.Unlock()
177}
178
179// HandleFunc adds a handler function to the ServeMux for pattern.
180func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
181	mux.Handle(pattern, HandlerFunc(handler))
182}
183
184// HandleRemove deregistrars the handler specific for pattern from the ServeMux.
185func (mux *ServeMux) HandleRemove(pattern string) {
186	if pattern == "" {
187		panic("dns: invalid pattern " + pattern)
188	}
189	mux.m.Lock()
190	delete(mux.z, Fqdn(pattern))
191	mux.m.Unlock()
192}
193
194// ServeDNS dispatches the request to the handler whose
195// pattern most closely matches the request message. If DefaultServeMux
196// is used the correct thing for DS queries is done: a possible parent
197// is sought.
198// If no handler is found a standard SERVFAIL message is returned
199// If the request message does not have exactly one question in the
200// question section a SERVFAIL is returned, unlesss Unsafe is true.
201func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
202	var h Handler
203	if len(request.Question) < 1 { // allow more than one question
204		h = failedHandler()
205	} else {
206		if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
207			h = failedHandler()
208		}
209	}
210	h.ServeDNS(w, request)
211}
212
213// Handle registers the handler with the given pattern
214// in the DefaultServeMux. The documentation for
215// ServeMux explains how patterns are matched.
216func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
217
218// HandleRemove deregisters the handle with the given pattern
219// in the DefaultServeMux.
220func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
221
222// HandleFunc registers the handler function with the given pattern
223// in the DefaultServeMux.
224func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
225	DefaultServeMux.HandleFunc(pattern, handler)
226}
227
// Writer writes raw DNS messages; each call to Write should send an entire message.
// It is the type wrapped by a DecorateWriter hook.
type Writer interface {
	io.Writer
}
232
// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
// It is the type wrapped by a DecorateReader hook.
type Reader interface {
	// ReadTCP reads a raw message from a TCP connection. Implementations may alter
	// connection properties, for example the read-deadline.
	ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
	// ReadUDP reads a raw message from a UDP connection. Implementations may alter
	// connection properties, for example the read-deadline.
	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}
242
// defaultReader is an adapter for the Server struct that implements the Reader interface
// using the readTCP and readUDP func of the embedded Server. It is the Reader
// used when no DecorateReader is configured, and the one passed to the
// decorator otherwise.
type defaultReader struct {
	*Server
}
248
249func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
250	return dr.readTCP(conn, timeout)
251}
252
253func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
254	return dr.readUDP(conn, timeout)
255}
256
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
// The server passes its default Reader in and uses whatever is returned.
// Implementations should never return a nil Reader.
type DecorateReader func(Reader) Reader
260
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
// The server passes its default Writer in and uses whatever is returned.
// Implementations should never return a nil Writer.
type DecorateWriter func(Writer) Writer
264
// A Server defines parameters for running a DNS server.
type Server struct {
	// Address to listen on, ":domain" if empty.
	Addr string
	// Network to listen on. "tcp", "tcp4" and "tcp6" start a TCP listener,
	// "tcp-tls", "tcp4-tls" and "tcp6-tls" a TLS (DNS over TLS) listener and
	// "udp", "udp4" and "udp6" a UDP one. Any other value makes
	// ListenAndServe fail with a "bad network" error.
	Net string
	// TCP Listener to use, this is to aid in systemd's socket activation.
	Listener net.Listener
	// TLS connection configuration
	TLSConfig *tls.Config
	// UDP "Listener" to use, this is to aid in systemd's socket activation.
	PacketConn net.PacketConn
	// Handler to invoke, dns.DefaultServeMux if nil.
	Handler Handler
	// Default buffer size to use to read incoming UDP messages. If not set
	// it defaults to MinMsgSize (512 B).
	UDPSize int
	// Read timeout applied (as a read deadline) to new connections; when
	// zero the package default dnsTimeout is used (documented as
	// 2 * time.Second).
	ReadTimeout time.Duration
	// The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second.
	// NOTE(review): not referenced in the visible part of this file — confirm where it is applied.
	WriteTimeout time.Duration
	// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
	IdleTimeout func() time.Duration
	// Secret(s) for Tsig map[<zonename>]<base64 secret>.
	TsigSecret map[string]string
	// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
	// the handler. It will specifically not check if the query has the QR bit not set.
	Unsafe bool
	// If NotifyStartedFunc is set it is called once the server has started listening.
	NotifyStartedFunc func()
	// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
	DecorateReader DecorateReader
	// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
	DecorateWriter DecorateWriter

	// Graceful shutdown handling

	// inFlight counts the serve goroutines that are still running;
	// Shutdown waits on it.
	inFlight sync.WaitGroup

	// lock guards started.
	lock    sync.RWMutex
	started bool
}
307
308// ListenAndServe starts a nameserver on the configured address in *Server.
309func (srv *Server) ListenAndServe() error {
310	srv.lock.Lock()
311	defer srv.lock.Unlock()
312	if srv.started {
313		return &Error{err: "server already started"}
314	}
315	addr := srv.Addr
316	if addr == "" {
317		addr = ":domain"
318	}
319	if srv.UDPSize == 0 {
320		srv.UDPSize = MinMsgSize
321	}
322	switch srv.Net {
323	case "tcp", "tcp4", "tcp6":
324		a, err := net.ResolveTCPAddr(srv.Net, addr)
325		if err != nil {
326			return err
327		}
328		l, err := net.ListenTCP(srv.Net, a)
329		if err != nil {
330			return err
331		}
332		srv.Listener = l
333		srv.started = true
334		srv.lock.Unlock()
335		err = srv.serveTCP(l)
336		srv.lock.Lock() // to satisfy the defer at the top
337		return err
338	case "tcp-tls", "tcp4-tls", "tcp6-tls":
339		network := "tcp"
340		if srv.Net == "tcp4-tls" {
341			network = "tcp4"
342		} else if srv.Net == "tcp6" {
343			network = "tcp6"
344		}
345
346		l, err := tls.Listen(network, addr, srv.TLSConfig)
347		if err != nil {
348			return err
349		}
350		srv.Listener = l
351		srv.started = true
352		srv.lock.Unlock()
353		err = srv.serveTCP(l)
354		srv.lock.Lock() // to satisfy the defer at the top
355		return err
356	case "udp", "udp4", "udp6":
357		a, err := net.ResolveUDPAddr(srv.Net, addr)
358		if err != nil {
359			return err
360		}
361		l, err := net.ListenUDP(srv.Net, a)
362		if err != nil {
363			return err
364		}
365		if e := setUDPSocketOptions(l); e != nil {
366			return e
367		}
368		srv.PacketConn = l
369		srv.started = true
370		srv.lock.Unlock()
371		err = srv.serveUDP(l)
372		srv.lock.Lock() // to satisfy the defer at the top
373		return err
374	}
375	return &Error{err: "bad network"}
376}
377
378// ActivateAndServe starts a nameserver with the PacketConn or Listener
379// configured in *Server. Its main use is to start a server from systemd.
380func (srv *Server) ActivateAndServe() error {
381	srv.lock.Lock()
382	defer srv.lock.Unlock()
383	if srv.started {
384		return &Error{err: "server already started"}
385	}
386	pConn := srv.PacketConn
387	l := srv.Listener
388	if pConn != nil {
389		if srv.UDPSize == 0 {
390			srv.UDPSize = MinMsgSize
391		}
392		// Check PacketConn interface's type is valid and value
393		// is not nil
394		if t, ok := pConn.(*net.UDPConn); ok && t != nil {
395			if e := setUDPSocketOptions(t); e != nil {
396				return e
397			}
398			srv.started = true
399			srv.lock.Unlock()
400			e := srv.serveUDP(t)
401			srv.lock.Lock() // to satisfy the defer at the top
402			return e
403		}
404	}
405	if l != nil {
406		srv.started = true
407		srv.lock.Unlock()
408		e := srv.serveTCP(l)
409		srv.lock.Lock() // to satisfy the defer at the top
410		return e
411	}
412	return &Error{err: "bad listeners"}
413}
414
415// Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and
416// ActivateAndServe will return. All in progress queries are completed before the server
417// is taken down. If the Shutdown is taking longer than the reading timeout an error
418// is returned.
419func (srv *Server) Shutdown() error {
420	srv.lock.Lock()
421	if !srv.started {
422		srv.lock.Unlock()
423		return &Error{err: "server not started"}
424	}
425	srv.started = false
426	srv.lock.Unlock()
427
428	if srv.PacketConn != nil {
429		srv.PacketConn.Close()
430	}
431	if srv.Listener != nil {
432		srv.Listener.Close()
433	}
434
435	fin := make(chan bool)
436	go func() {
437		srv.inFlight.Wait()
438		fin <- true
439	}()
440
441	select {
442	case <-time.After(srv.getReadTimeout()):
443		return &Error{err: "server shutdown is pending"}
444	case <-fin:
445		return nil
446	}
447}
448
449// getReadTimeout is a helper func to use system timeout if server did not intend to change it.
450func (srv *Server) getReadTimeout() time.Duration {
451	rtimeout := dnsTimeout
452	if srv.ReadTimeout != 0 {
453		rtimeout = srv.ReadTimeout
454	}
455	return rtimeout
456}
457
458// serveTCP starts a TCP listener for the server.
459// Each request is handled in a separate goroutine.
460func (srv *Server) serveTCP(l net.Listener) error {
461	defer l.Close()
462
463	if srv.NotifyStartedFunc != nil {
464		srv.NotifyStartedFunc()
465	}
466
467	reader := Reader(&defaultReader{srv})
468	if srv.DecorateReader != nil {
469		reader = srv.DecorateReader(reader)
470	}
471
472	handler := srv.Handler
473	if handler == nil {
474		handler = DefaultServeMux
475	}
476	rtimeout := srv.getReadTimeout()
477	// deadline is not used here
478	for {
479		rw, err := l.Accept()
480		if err != nil {
481			if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
482				continue
483			}
484			return err
485		}
486		m, err := reader.ReadTCP(rw, rtimeout)
487		srv.lock.RLock()
488		if !srv.started {
489			srv.lock.RUnlock()
490			return nil
491		}
492		srv.lock.RUnlock()
493		if err != nil {
494			continue
495		}
496		srv.inFlight.Add(1)
497		go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
498	}
499}
500
501// serveUDP starts a UDP listener for the server.
502// Each request is handled in a separate goroutine.
503func (srv *Server) serveUDP(l *net.UDPConn) error {
504	defer l.Close()
505
506	if srv.NotifyStartedFunc != nil {
507		srv.NotifyStartedFunc()
508	}
509
510	reader := Reader(&defaultReader{srv})
511	if srv.DecorateReader != nil {
512		reader = srv.DecorateReader(reader)
513	}
514
515	handler := srv.Handler
516	if handler == nil {
517		handler = DefaultServeMux
518	}
519	rtimeout := srv.getReadTimeout()
520	// deadline is not used here
521	for {
522		m, s, err := reader.ReadUDP(l, rtimeout)
523		srv.lock.RLock()
524		if !srv.started {
525			srv.lock.RUnlock()
526			return nil
527		}
528		srv.lock.RUnlock()
529		if err != nil {
530			continue
531		}
532		srv.inFlight.Add(1)
533		go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
534	}
535}
536
537// Serve a new connection.
538func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
539	defer srv.inFlight.Done()
540
541	w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
542	if srv.DecorateWriter != nil {
543		w.writer = srv.DecorateWriter(w)
544	} else {
545		w.writer = w
546	}
547
548	q := 0 // counter for the amount of TCP queries we get
549
550	reader := Reader(&defaultReader{srv})
551	if srv.DecorateReader != nil {
552		reader = srv.DecorateReader(reader)
553	}
554Redo:
555	req := new(Msg)
556	err := req.Unpack(m)
557	if err != nil { // Send a FormatError back
558		x := new(Msg)
559		x.SetRcodeFormatError(req)
560		w.WriteMsg(x)
561		goto Exit
562	}
563	if !srv.Unsafe && req.Response {
564		goto Exit
565	}
566
567	w.tsigStatus = nil
568	if w.tsigSecret != nil {
569		if t := req.IsTsig(); t != nil {
570			secret := t.Hdr.Name
571			if _, ok := w.tsigSecret[secret]; !ok {
572				w.tsigStatus = ErrKeyAlg
573			}
574			w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
575			w.tsigTimersOnly = false
576			w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
577		}
578	}
579	h.ServeDNS(w, req) // Writes back to the client
580
581Exit:
582	if w.tcp == nil {
583		return
584	}
585	// TODO(miek): make this number configurable?
586	if q > maxTCPQueries { // close socket after this many queries
587		w.Close()
588		return
589	}
590
591	if w.hijacked {
592		return // client calls Close()
593	}
594	if u != nil { // UDP, "close" and return
595		w.Close()
596		return
597	}
598	idleTimeout := tcpIdleTimeout
599	if srv.IdleTimeout != nil {
600		idleTimeout = srv.IdleTimeout()
601	}
602	m, err = reader.ReadTCP(w.tcp, idleTimeout)
603	if err == nil {
604		q++
605		goto Redo
606	}
607	w.Close()
608	return
609}
610
611func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
612	conn.SetReadDeadline(time.Now().Add(timeout))
613	l := make([]byte, 2)
614	n, err := conn.Read(l)
615	if err != nil || n != 2 {
616		if err != nil {
617			return nil, err
618		}
619		return nil, ErrShortRead
620	}
621	length := binary.BigEndian.Uint16(l)
622	if length == 0 {
623		return nil, ErrShortRead
624	}
625	m := make([]byte, int(length))
626	n, err = conn.Read(m[:int(length)])
627	if err != nil || n == 0 {
628		if err != nil {
629			return nil, err
630		}
631		return nil, ErrShortRead
632	}
633	i := n
634	for i < int(length) {
635		j, err := conn.Read(m[i:int(length)])
636		if err != nil {
637			return nil, err
638		}
639		i += j
640	}
641	n = i
642	m = m[:n]
643	return m, nil
644}
645
646func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
647	conn.SetReadDeadline(time.Now().Add(timeout))
648	m := make([]byte, srv.UDPSize)
649	n, s, err := ReadFromSessionUDP(conn, m)
650	if err != nil || n == 0 {
651		if err != nil {
652			return nil, nil, err
653		}
654		return nil, nil, ErrShortRead
655	}
656	m = m[:n]
657	return m, s, nil
658}
659
660// WriteMsg implements the ResponseWriter.WriteMsg method.
661func (w *response) WriteMsg(m *Msg) (err error) {
662	var data []byte
663	if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
664		if t := m.IsTsig(); t != nil {
665			data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
666			if err != nil {
667				return err
668			}
669			_, err = w.writer.Write(data)
670			return err
671		}
672	}
673	data, err = m.Pack()
674	if err != nil {
675		return err
676	}
677	_, err = w.writer.Write(data)
678	return err
679}
680
681// Write implements the ResponseWriter.Write method.
682func (w *response) Write(m []byte) (int, error) {
683	switch {
684	case w.udp != nil:
685		n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
686		return n, err
687	case w.tcp != nil:
688		lm := len(m)
689		if lm < 2 {
690			return 0, io.ErrShortBuffer
691		}
692		if lm > MaxMsgSize {
693			return 0, &Error{err: "message too large"}
694		}
695		l := make([]byte, 2, 2+lm)
696		binary.BigEndian.PutUint16(l, uint16(lm))
697		m = append(l, m...)
698
699		n, err := io.Copy(w.tcp, bytes.NewReader(m))
700		return int(n), err
701	}
702	panic("not reached")
703}
704
705// LocalAddr implements the ResponseWriter.LocalAddr method.
706func (w *response) LocalAddr() net.Addr {
707	if w.tcp != nil {
708		return w.tcp.LocalAddr()
709	}
710	return w.udp.LocalAddr()
711}
712
713// RemoteAddr implements the ResponseWriter.RemoteAddr method.
714func (w *response) RemoteAddr() net.Addr { return w.remoteAddr }
715
716// TsigStatus implements the ResponseWriter.TsigStatus method.
717func (w *response) TsigStatus() error { return w.tsigStatus }
718
719// TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method.
720func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
721
722// Hijack implements the ResponseWriter.Hijack method.
723func (w *response) Hijack() { w.hijacked = true }
724
725// Close implements the ResponseWriter.Close method
726func (w *response) Close() error {
727	// Can't close the udp conn, as that is actually the listener.
728	if w.tcp != nil {
729		e := w.tcp.Close()
730		w.tcp = nil
731		return e
732	}
733	return nil
734}
735