1// DNS server implementation.
2
3package dns
4
5import (
6	"context"
7	"crypto/tls"
8	"encoding/binary"
9	"errors"
10	"io"
11	"net"
12	"strings"
13	"sync"
14	"time"
15)
16
// maxTCPQueries is the number of queries served on one TCP connection
// before the server closes it; it is used when Server.MaxTCPQueries is zero.
const maxTCPQueries = 128

// aLongTimeAgo is a non-zero instant far in the past. Installing it as a
// read deadline makes pending reads on a connection fail straight away,
// which is how blocked reads are interrupted during shutdown.
var aLongTimeAgo = time.Unix(1, 0)
23
24// Handler is implemented by any value that implements ServeDNS.
25type Handler interface {
26	ServeDNS(w ResponseWriter, r *Msg)
27}
28
29// The HandlerFunc type is an adapter to allow the use of
30// ordinary functions as DNS handlers.  If f is a function
31// with the appropriate signature, HandlerFunc(f) is a
32// Handler object that calls f.
33type HandlerFunc func(ResponseWriter, *Msg)
34
35// ServeDNS calls f(w, r).
36func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
37	f(w, r)
38}
39
40// A ResponseWriter interface is used by an DNS handler to
41// construct an DNS response.
42type ResponseWriter interface {
43	// LocalAddr returns the net.Addr of the server
44	LocalAddr() net.Addr
45	// RemoteAddr returns the net.Addr of the client that sent the current request.
46	RemoteAddr() net.Addr
47	// WriteMsg writes a reply back to the client.
48	WriteMsg(*Msg) error
49	// Write writes a raw buffer back to the client.
50	Write([]byte) (int, error)
51	// Close closes the connection.
52	Close() error
53	// TsigStatus returns the status of the Tsig.
54	TsigStatus() error
55	// TsigTimersOnly sets the tsig timers only boolean.
56	TsigTimersOnly(bool)
57	// Hijack lets the caller take over the connection.
58	// After a call to Hijack(), the DNS package will not do anything with the connection.
59	Hijack()
60}
61
// ConnectionStater is an optional interface a ResponseWriter may implement
// so that handlers can inspect the TLS state of the connection. The
// server's own ResponseWriter returns nil when the connection is not TLS.
type ConnectionStater interface {
	ConnectionState() *tls.ConnectionState
}
67
// response is the server's ResponseWriter implementation. serveTCPConn
// creates one per TCP connection (tcp set); serveUDPPacket creates one per
// UDP datagram (udp and udpSession set). The methods panic if neither udp
// nor tcp is set.
//
// TSIG state: serveDNS resets tsigStatus for each request and, when the
// request carries a TSIG and secrets are configured, stores the
// verification result in tsigStatus, the request MAC in tsigRequestMAC and
// resets tsigTimersOnly to false. WriteMsg passes tsigRequestMAC and
// tsigTimersOnly to TsigGenerate and stores the MAC it returns back into
// tsigRequestMAC.
type response struct {
	closed         bool // connection has been closed
	hijacked       bool // connection has been hijacked by handler
	tsigTimersOnly bool
	tsigStatus     error
	tsigRequestMAC string
	tsigSecret     map[string]string // the tsig secrets
	udp            *net.UDPConn      // i/o connection if UDP was used
	tcp            net.Conn          // i/o connection if TCP was used
	udpSession     *SessionUDP       // oob data to get egress interface right
	writer         Writer            // writer to output the raw DNS bits
}
80
81// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
82func HandleFailed(w ResponseWriter, r *Msg) {
83	m := new(Msg)
84	m.SetRcode(r, RcodeServerFailure)
85	// does not matter if this write fails
86	w.WriteMsg(m)
87}
88
89// ListenAndServe Starts a server on address and network specified Invoke handler
90// for incoming queries.
91func ListenAndServe(addr string, network string, handler Handler) error {
92	server := &Server{Addr: addr, Net: network, Handler: handler}
93	return server.ListenAndServe()
94}
95
96// ListenAndServeTLS acts like http.ListenAndServeTLS, more information in
97// http://golang.org/pkg/net/http/#ListenAndServeTLS
98func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
99	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
100	if err != nil {
101		return err
102	}
103
104	config := tls.Config{
105		Certificates: []tls.Certificate{cert},
106	}
107
108	server := &Server{
109		Addr:      addr,
110		Net:       "tcp-tls",
111		TLSConfig: &config,
112		Handler:   handler,
113	}
114
115	return server.ListenAndServe()
116}
117
118// ActivateAndServe activates a server with a listener from systemd,
119// l and p should not both be non-nil.
120// If both l and p are not nil only p will be used.
121// Invoke handler for incoming queries.
122func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
123	server := &Server{Listener: l, PacketConn: p, Handler: handler}
124	return server.ActivateAndServe()
125}
126
127// Writer writes raw DNS messages; each call to Write should send an entire message.
128type Writer interface {
129	io.Writer
130}
131
132// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
133type Reader interface {
134	// ReadTCP reads a raw message from a TCP connection. Implementations may alter
135	// connection properties, for example the read-deadline.
136	ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
137	// ReadUDP reads a raw message from a UDP connection. Implementations may alter
138	// connection properties, for example the read-deadline.
139	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
140}
141
142// defaultReader is an adapter for the Server struct that implements the Reader interface
143// using the readTCP and readUDP func of the embedded Server.
144type defaultReader struct {
145	*Server
146}
147
148func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
149	return dr.readTCP(conn, timeout)
150}
151
152func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
153	return dr.readUDP(conn, timeout)
154}
155
156// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
157// Implementations should never return a nil Reader.
158type DecorateReader func(Reader) Reader
159
160// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
161// Implementations should never return a nil Writer.
162type DecorateWriter func(Writer) Writer
163
164// A Server defines parameters for running an DNS server.
165type Server struct {
166	// Address to listen on, ":dns" if empty.
167	Addr string
168	// if "tcp" or "tcp-tls" (DNS over TLS) it will invoke a TCP listener, otherwise an UDP one
169	Net string
170	// TCP Listener to use, this is to aid in systemd's socket activation.
171	Listener net.Listener
172	// TLS connection configuration
173	TLSConfig *tls.Config
174	// UDP "Listener" to use, this is to aid in systemd's socket activation.
175	PacketConn net.PacketConn
176	// Handler to invoke, dns.DefaultServeMux if nil.
177	Handler Handler
178	// Default buffer size to use to read incoming UDP messages. If not set
179	// it defaults to MinMsgSize (512 B).
180	UDPSize int
181	// The net.Conn.SetReadTimeout value for new connections, defaults to 2 * time.Second.
182	ReadTimeout time.Duration
183	// The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second.
184	WriteTimeout time.Duration
185	// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
186	IdleTimeout func() time.Duration
187	// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
188	TsigSecret map[string]string
189	// If NotifyStartedFunc is set it is called once the server has started listening.
190	NotifyStartedFunc func()
191	// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
192	DecorateReader DecorateReader
193	// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
194	DecorateWriter DecorateWriter
195	// Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
196	MaxTCPQueries int
197	// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
198	// It is only supported on go1.11+ and when using ListenAndServe.
199	ReusePort bool
200	// AcceptMsgFunc will check the incoming message and will reject it early in the process.
201	// By default DefaultMsgAcceptFunc will be used.
202	MsgAcceptFunc MsgAcceptFunc
203
204	// Shutdown handling
205	lock     sync.RWMutex
206	started  bool
207	shutdown chan struct{}
208	conns    map[net.Conn]struct{}
209
210	// A pool for UDP message buffers.
211	udpPool sync.Pool
212}
213
214func (srv *Server) isStarted() bool {
215	srv.lock.RLock()
216	started := srv.started
217	srv.lock.RUnlock()
218	return started
219}
220
// makeUDPBuffer returns a constructor, suitable for sync.Pool.New, that
// allocates a new zeroed byte slice of exactly size bytes on every call.
func makeUDPBuffer(size int) func() interface{} {
	alloc := func() interface{} {
		buf := make([]byte, size)
		return buf
	}
	return alloc
}
226
227func (srv *Server) init() {
228	srv.shutdown = make(chan struct{})
229	srv.conns = make(map[net.Conn]struct{})
230
231	if srv.UDPSize == 0 {
232		srv.UDPSize = MinMsgSize
233	}
234	if srv.MsgAcceptFunc == nil {
235		srv.MsgAcceptFunc = DefaultMsgAcceptFunc
236	}
237	if srv.Handler == nil {
238		srv.Handler = DefaultServeMux
239	}
240
241	srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
242}
243
// unlockOnce returns a function that unlocks l the first time it is called
// and does nothing on later calls. This lets a deferred unlock coexist
// with an explicit early unlock of the same lock.
func unlockOnce(l sync.Locker) func() {
	once := new(sync.Once)
	return func() {
		once.Do(l.Unlock)
	}
}
248
249// ListenAndServe starts a nameserver on the configured address in *Server.
250func (srv *Server) ListenAndServe() error {
251	unlock := unlockOnce(&srv.lock)
252	srv.lock.Lock()
253	defer unlock()
254
255	if srv.started {
256		return &Error{err: "server already started"}
257	}
258
259	addr := srv.Addr
260	if addr == "" {
261		addr = ":domain"
262	}
263
264	srv.init()
265
266	switch srv.Net {
267	case "tcp", "tcp4", "tcp6":
268		l, err := listenTCP(srv.Net, addr, srv.ReusePort)
269		if err != nil {
270			return err
271		}
272		srv.Listener = l
273		srv.started = true
274		unlock()
275		return srv.serveTCP(l)
276	case "tcp-tls", "tcp4-tls", "tcp6-tls":
277		if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
278			return errors.New("dns: neither Certificates nor GetCertificate set in Config")
279		}
280		network := strings.TrimSuffix(srv.Net, "-tls")
281		l, err := listenTCP(network, addr, srv.ReusePort)
282		if err != nil {
283			return err
284		}
285		l = tls.NewListener(l, srv.TLSConfig)
286		srv.Listener = l
287		srv.started = true
288		unlock()
289		return srv.serveTCP(l)
290	case "udp", "udp4", "udp6":
291		l, err := listenUDP(srv.Net, addr, srv.ReusePort)
292		if err != nil {
293			return err
294		}
295		u := l.(*net.UDPConn)
296		if e := setUDPSocketOptions(u); e != nil {
297			return e
298		}
299		srv.PacketConn = l
300		srv.started = true
301		unlock()
302		return srv.serveUDP(u)
303	}
304	return &Error{err: "bad network"}
305}
306
307// ActivateAndServe starts a nameserver with the PacketConn or Listener
308// configured in *Server. Its main use is to start a server from systemd.
309func (srv *Server) ActivateAndServe() error {
310	unlock := unlockOnce(&srv.lock)
311	srv.lock.Lock()
312	defer unlock()
313
314	if srv.started {
315		return &Error{err: "server already started"}
316	}
317
318	srv.init()
319
320	pConn := srv.PacketConn
321	l := srv.Listener
322	if pConn != nil {
323		// Check PacketConn interface's type is valid and value
324		// is not nil
325		if t, ok := pConn.(*net.UDPConn); ok && t != nil {
326			if e := setUDPSocketOptions(t); e != nil {
327				return e
328			}
329			srv.started = true
330			unlock()
331			return srv.serveUDP(t)
332		}
333	}
334	if l != nil {
335		srv.started = true
336		unlock()
337		return srv.serveTCP(l)
338	}
339	return &Error{err: "bad listeners"}
340}
341
342// Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and
343// ActivateAndServe will return.
344func (srv *Server) Shutdown() error {
345	return srv.ShutdownContext(context.Background())
346}
347
// ShutdownContext shuts down a server. After a call to ShutdownContext,
// ListenAndServe and ActivateAndServe will return.
//
// A context.Context may be passed to limit how long to wait for connections
// to terminate.
//
// While holding the server lock it does the following:
//   - marks the server as stopped;
//   - sets an already-expired read deadline on the PacketConn and on every
//     tracked TCP connection, so blocked reads return;
//   - closes the Listener, so Accept returns.
//
// It then waits until srv.shutdown is closed or until ctx is done,
// whichever happens first. The serve loop closes srv.shutdown only after
// all of its handler goroutines have finished. Finally it closes the
// PacketConn.
//
// It returns an error if the server is not running. It returns ctx.Err() if
// ctx finished first; in that case handlers may still be running.
// Otherwise it returns nil. Errors from SetReadDeadline and Close are
// ignored.
func (srv *Server) ShutdownContext(ctx context.Context) error {
	srv.lock.Lock()
	if !srv.started {
		srv.lock.Unlock()
		return &Error{err: "server not started"}
	}

	// readTCP/readUDP check this flag under the same lock before arming a
	// fresh deadline, so they cannot overwrite the expired deadlines below.
	srv.started = false

	if srv.PacketConn != nil {
		srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
	}

	if srv.Listener != nil {
		srv.Listener.Close()
	}

	for rw := range srv.conns {
		rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
	}

	srv.lock.Unlock()

	if testShutdownNotify != nil {
		testShutdownNotify.Broadcast()
	}

	var ctxErr error
	select {
	case <-srv.shutdown:
	case <-ctx.Done():
		ctxErr = ctx.Err()
	}

	// The PacketConn is closed only now: UDP handlers reply through the same
	// connection, so closing it earlier would break in-flight responses.
	if srv.PacketConn != nil {
		srv.PacketConn.Close()
	}

	return ctxErr
}

// testShutdownNotify, when non-nil, is broadcast by ShutdownContext right
// after it releases the server lock. Nothing in this file sets it;
// presumably it is a hook for tests to observe that point — confirm.
var testShutdownNotify *sync.Cond
395
396// getReadTimeout is a helper func to use system timeout if server did not intend to change it.
397func (srv *Server) getReadTimeout() time.Duration {
398	if srv.ReadTimeout != 0 {
399		return srv.ReadTimeout
400	}
401	return dnsTimeout
402}
403
404// serveTCP starts a TCP listener for the server.
405func (srv *Server) serveTCP(l net.Listener) error {
406	defer l.Close()
407
408	if srv.NotifyStartedFunc != nil {
409		srv.NotifyStartedFunc()
410	}
411
412	var wg sync.WaitGroup
413	defer func() {
414		wg.Wait()
415		close(srv.shutdown)
416	}()
417
418	for srv.isStarted() {
419		rw, err := l.Accept()
420		if err != nil {
421			if !srv.isStarted() {
422				return nil
423			}
424			if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
425				continue
426			}
427			return err
428		}
429		srv.lock.Lock()
430		// Track the connection to allow unblocking reads on shutdown.
431		srv.conns[rw] = struct{}{}
432		srv.lock.Unlock()
433		wg.Add(1)
434		go srv.serveTCPConn(&wg, rw)
435	}
436
437	return nil
438}
439
440// serveUDP starts a UDP listener for the server.
441func (srv *Server) serveUDP(l *net.UDPConn) error {
442	defer l.Close()
443
444	if srv.NotifyStartedFunc != nil {
445		srv.NotifyStartedFunc()
446	}
447
448	reader := Reader(defaultReader{srv})
449	if srv.DecorateReader != nil {
450		reader = srv.DecorateReader(reader)
451	}
452
453	var wg sync.WaitGroup
454	defer func() {
455		wg.Wait()
456		close(srv.shutdown)
457	}()
458
459	rtimeout := srv.getReadTimeout()
460	// deadline is not used here
461	for srv.isStarted() {
462		m, s, err := reader.ReadUDP(l, rtimeout)
463		if err != nil {
464			if !srv.isStarted() {
465				return nil
466			}
467			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
468				continue
469			}
470			return err
471		}
472		if len(m) < headerSize {
473			if cap(m) == srv.UDPSize {
474				srv.udpPool.Put(m[:srv.UDPSize])
475			}
476			continue
477		}
478		wg.Add(1)
479		go srv.serveUDPPacket(&wg, m, l, s)
480	}
481
482	return nil
483}
484
485// Serve a new TCP connection.
486func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
487	w := &response{tsigSecret: srv.TsigSecret, tcp: rw}
488	if srv.DecorateWriter != nil {
489		w.writer = srv.DecorateWriter(w)
490	} else {
491		w.writer = w
492	}
493
494	reader := Reader(defaultReader{srv})
495	if srv.DecorateReader != nil {
496		reader = srv.DecorateReader(reader)
497	}
498
499	idleTimeout := tcpIdleTimeout
500	if srv.IdleTimeout != nil {
501		idleTimeout = srv.IdleTimeout()
502	}
503
504	timeout := srv.getReadTimeout()
505
506	limit := srv.MaxTCPQueries
507	if limit == 0 {
508		limit = maxTCPQueries
509	}
510
511	for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
512		m, err := reader.ReadTCP(w.tcp, timeout)
513		if err != nil {
514			// TODO(tmthrgd): handle error
515			break
516		}
517		srv.serveDNS(m, w)
518		if w.closed {
519			break // Close() was called
520		}
521		if w.hijacked {
522			break // client will call Close() themselves
523		}
524		// The first read uses the read timeout, the rest use the
525		// idle timeout.
526		timeout = idleTimeout
527	}
528
529	if !w.hijacked {
530		w.Close()
531	}
532
533	srv.lock.Lock()
534	delete(srv.conns, w.tcp)
535	srv.lock.Unlock()
536
537	wg.Done()
538}
539
540// Serve a new UDP request.
541func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
542	w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
543	if srv.DecorateWriter != nil {
544		w.writer = srv.DecorateWriter(w)
545	} else {
546		w.writer = w
547	}
548
549	srv.serveDNS(m, w)
550	wg.Done()
551}
552
// serveDNS parses the raw message m received via w and dispatches it.
//
// If the header cannot be decoded, nothing is sent and the handler is not
// called. Otherwise srv.MsgAcceptFunc sees the header and chooses one of:
//   - MsgAccept: the rest of m is unpacked and the request is passed to
//     srv.Handler. If unpacking fails, the request is treated as MsgReject.
//   - MsgReject / MsgRejectNotImplemented: a reply with empty Answer, Ns and
//     Extra sections is written. It is built with SetRcodeFormatError, and
//     for NotImplemented the original opcode is restored and Rcode is set
//     to RcodeNotImplemented. The handler is not called.
//   - MsgIgnore: nothing is written and the handler is not called.
//
// NOTE(review): any other action value falls out of the switch and reaches
// the handler with only the header populated.
//
// For UDP, m is a pooled buffer of cap srv.UDPSize. It is returned to
// srv.udpPool before the handler runs, or on the reject/ignore paths.
func (srv *Server) serveDNS(m []byte, w *response) {
	dh, off, err := unpackMsgHdr(m, 0)
	if err != nil {
		// Let client hang, they are sending crap; any reply can be used to amplify.
		return
	}

	req := new(Msg)
	req.setHdr(dh)

	switch action := srv.MsgAcceptFunc(dh); action {
	case MsgAccept:
		if req.unpack(dh, m, off) == nil {
			break
		}

		// Unpacking failed: answer as if the message had been rejected.
		fallthrough
	case MsgReject, MsgRejectNotImplemented:
		opcode := req.Opcode
		req.SetRcodeFormatError(req)
		req.Zero = false
		if action == MsgRejectNotImplemented {
			req.Opcode = opcode
			req.Rcode = RcodeNotImplemented
		}

		// Are we allowed to delete any OPT records here?
		req.Ns, req.Answer, req.Extra = nil, nil, nil

		// Best effort; a failed error reply is not reported anywhere.
		w.WriteMsg(req)
		fallthrough
	case MsgIgnore:
		if w.udp != nil && cap(m) == srv.UDPSize {
			srv.udpPool.Put(m[:srv.UDPSize])
		}

		return
	}

	// Verify the request's TSIG (while m is still intact) and remember the
	// request MAC, which WriteMsg passes to TsigGenerate when signing the reply.
	w.tsigStatus = nil
	if w.tsigSecret != nil {
		if t := req.IsTsig(); t != nil {
			if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
				w.tsigStatus = TsigVerify(m, secret, "", false)
			} else {
				w.tsigStatus = ErrSecret
			}
			w.tsigTimersOnly = false
			w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
		}
	}

	// m is not referenced after this point, so the UDP buffer is recycled
	// before the handler runs.
	// NOTE(review): this assumes req.unpack copied everything it needs out of m.
	if w.udp != nil && cap(m) == srv.UDPSize {
		srv.udpPool.Put(m[:srv.UDPSize])
	}

	srv.Handler.ServeDNS(w, req) // Writes back to the client
}
611
612func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
613	// If we race with ShutdownContext, the read deadline may
614	// have been set in the distant past to unblock the read
615	// below. We must not override it, otherwise we may block
616	// ShutdownContext.
617	srv.lock.RLock()
618	if srv.started {
619		conn.SetReadDeadline(time.Now().Add(timeout))
620	}
621	srv.lock.RUnlock()
622
623	var length uint16
624	if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
625		return nil, err
626	}
627
628	m := make([]byte, length)
629	if _, err := io.ReadFull(conn, m); err != nil {
630		return nil, err
631	}
632
633	return m, nil
634}
635
636func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
637	srv.lock.RLock()
638	if srv.started {
639		// See the comment in readTCP above.
640		conn.SetReadDeadline(time.Now().Add(timeout))
641	}
642	srv.lock.RUnlock()
643
644	m := srv.udpPool.Get().([]byte)
645	n, s, err := ReadFromSessionUDP(conn, m)
646	if err != nil {
647		srv.udpPool.Put(m)
648		return nil, nil, err
649	}
650	m = m[:n]
651	return m, s, nil
652}
653
654// WriteMsg implements the ResponseWriter.WriteMsg method.
655func (w *response) WriteMsg(m *Msg) (err error) {
656	if w.closed {
657		return &Error{err: "WriteMsg called after Close"}
658	}
659
660	var data []byte
661	if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
662		if t := m.IsTsig(); t != nil {
663			data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
664			if err != nil {
665				return err
666			}
667			_, err = w.writer.Write(data)
668			return err
669		}
670	}
671	data, err = m.Pack()
672	if err != nil {
673		return err
674	}
675	_, err = w.writer.Write(data)
676	return err
677}
678
679// Write implements the ResponseWriter.Write method.
680func (w *response) Write(m []byte) (int, error) {
681	if w.closed {
682		return 0, &Error{err: "Write called after Close"}
683	}
684
685	switch {
686	case w.udp != nil:
687		return WriteToSessionUDP(w.udp, m, w.udpSession)
688	case w.tcp != nil:
689		if len(m) > MaxMsgSize {
690			return 0, &Error{err: "message too large"}
691		}
692
693		l := make([]byte, 2)
694		binary.BigEndian.PutUint16(l, uint16(len(m)))
695
696		n, err := (&net.Buffers{l, m}).WriteTo(w.tcp)
697		return int(n), err
698	default:
699		panic("dns: internal error: udp and tcp both nil")
700	}
701}
702
703// LocalAddr implements the ResponseWriter.LocalAddr method.
704func (w *response) LocalAddr() net.Addr {
705	switch {
706	case w.udp != nil:
707		return w.udp.LocalAddr()
708	case w.tcp != nil:
709		return w.tcp.LocalAddr()
710	default:
711		panic("dns: internal error: udp and tcp both nil")
712	}
713}
714
715// RemoteAddr implements the ResponseWriter.RemoteAddr method.
716func (w *response) RemoteAddr() net.Addr {
717	switch {
718	case w.udpSession != nil:
719		return w.udpSession.RemoteAddr()
720	case w.tcp != nil:
721		return w.tcp.RemoteAddr()
722	default:
723		panic("dns: internal error: udpSession and tcp both nil")
724	}
725}
726
727// TsigStatus implements the ResponseWriter.TsigStatus method.
728func (w *response) TsigStatus() error { return w.tsigStatus }
729
730// TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method.
731func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
732
733// Hijack implements the ResponseWriter.Hijack method.
734func (w *response) Hijack() { w.hijacked = true }
735
736// Close implements the ResponseWriter.Close method
737func (w *response) Close() error {
738	if w.closed {
739		return &Error{err: "connection already closed"}
740	}
741	w.closed = true
742
743	switch {
744	case w.udp != nil:
745		// Can't close the udp conn, as that is actually the listener.
746		return nil
747	case w.tcp != nil:
748		return w.tcp.Close()
749	default:
750		panic("dns: internal error: udp and tcp both nil")
751	}
752}
753
754// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
755func (w *response) ConnectionState() *tls.ConnectionState {
756	type tlsConnectionStater interface {
757		ConnectionState() tls.ConnectionState
758	}
759	if v, ok := w.tcp.(tlsConnectionStater); ok {
760		t := v.ConnectionState()
761		return &t
762	}
763	return nil
764}
765