1// DNS server implementation.
2
3package dns
4
5import (
6	"context"
7	"crypto/tls"
8	"encoding/binary"
9	"errors"
10	"io"
11	"net"
12	"strings"
13	"sync"
14	"time"
15)
16
const (
	// maxTCPQueries is the number of queries served on one TCP connection
	// before the server closes it, used when Server.MaxTCPQueries is zero.
	maxTCPQueries = 128
)

var (
	// aLongTimeAgo is a non-zero instant far in the past. Installing it as a
	// read deadline makes any blocked network read return immediately.
	aLongTimeAgo = time.Unix(1, 0)
)
23
24// Handler is implemented by any value that implements ServeDNS.
25type Handler interface {
26	ServeDNS(w ResponseWriter, r *Msg)
27}
28
29// The HandlerFunc type is an adapter to allow the use of
30// ordinary functions as DNS handlers.  If f is a function
31// with the appropriate signature, HandlerFunc(f) is a
32// Handler object that calls f.
33type HandlerFunc func(ResponseWriter, *Msg)
34
35// ServeDNS calls f(w, r).
36func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
37	f(w, r)
38}
39
40// A ResponseWriter interface is used by an DNS handler to
41// construct an DNS response.
42type ResponseWriter interface {
43	// LocalAddr returns the net.Addr of the server
44	LocalAddr() net.Addr
45	// RemoteAddr returns the net.Addr of the client that sent the current request.
46	RemoteAddr() net.Addr
47	// WriteMsg writes a reply back to the client.
48	WriteMsg(*Msg) error
49	// Write writes a raw buffer back to the client.
50	Write([]byte) (int, error)
51	// Close closes the connection.
52	Close() error
53	// TsigStatus returns the status of the Tsig.
54	TsigStatus() error
55	// TsigTimersOnly sets the tsig timers only boolean.
56	TsigTimersOnly(bool)
57	// Hijack lets the caller take over the connection.
58	// After a call to Hijack(), the DNS package will not do anything with the connection.
59	Hijack()
60}
61
// ConnectionStater is an optional interface a ResponseWriter may implement so
// that a DNS Handler can inspect the TLS state of the underlying connection.
type ConnectionStater interface {
	// ConnectionState returns the TLS connection state; it may be nil when
	// no TLS state is available.
	ConnectionState() *tls.ConnectionState
}
67
// response is the ResponseWriter given to handlers. Exactly one of udp or tcp
// is set, depending on the transport the request arrived on.
type response struct {
	closed         bool           // connection has been closed
	hijacked       bool           // connection has been hijacked by handler
	tsigTimersOnly bool           // passed to tsigGenerateProvider when signing replies
	tsigStatus     error          // result of verifying the request's TSIG; nil if none was checked
	tsigRequestMAC string         // request MAC, replaced by tsigGenerateProvider's result when signing
	tsigProvider   TsigProvider   // TSIG provider; nil disables TSIG verification and signing
	udp            net.PacketConn // i/o connection if UDP was used
	tcp            net.Conn       // i/o connection if TCP was used
	udpSession     *SessionUDP    // oob data to get egress interface right
	pcSession      net.Addr       // address to use when writing to a generic net.PacketConn
	writer         Writer         // writer to output the raw DNS bits
}
81
82// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
83func handleRefused(w ResponseWriter, r *Msg) {
84	m := new(Msg)
85	m.SetRcode(r, RcodeRefused)
86	w.WriteMsg(m)
87}
88
89// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
90// Deprecated: This function is going away.
91func HandleFailed(w ResponseWriter, r *Msg) {
92	m := new(Msg)
93	m.SetRcode(r, RcodeServerFailure)
94	// does not matter if this write fails
95	w.WriteMsg(m)
96}
97
98// ListenAndServe Starts a server on address and network specified Invoke handler
99// for incoming queries.
100func ListenAndServe(addr string, network string, handler Handler) error {
101	server := &Server{Addr: addr, Net: network, Handler: handler}
102	return server.ListenAndServe()
103}
104
105// ListenAndServeTLS acts like http.ListenAndServeTLS, more information in
106// http://golang.org/pkg/net/http/#ListenAndServeTLS
107func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
108	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
109	if err != nil {
110		return err
111	}
112
113	config := tls.Config{
114		Certificates: []tls.Certificate{cert},
115	}
116
117	server := &Server{
118		Addr:      addr,
119		Net:       "tcp-tls",
120		TLSConfig: &config,
121		Handler:   handler,
122	}
123
124	return server.ListenAndServe()
125}
126
127// ActivateAndServe activates a server with a listener from systemd,
128// l and p should not both be non-nil.
129// If both l and p are not nil only p will be used.
130// Invoke handler for incoming queries.
131func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
132	server := &Server{Listener: l, PacketConn: p, Handler: handler}
133	return server.ActivateAndServe()
134}
135
// Writer is the sink for raw, packed DNS messages. Every call to Write must
// carry exactly one complete message.
type Writer interface {
	io.Writer // Write(p []byte) (n int, err error)
}
140
141// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
142type Reader interface {
143	// ReadTCP reads a raw message from a TCP connection. Implementations may alter
144	// connection properties, for example the read-deadline.
145	ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
146	// ReadUDP reads a raw message from a UDP connection. Implementations may alter
147	// connection properties, for example the read-deadline.
148	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
149}
150
151// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
152type PacketConnReader interface {
153	Reader
154
155	// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
156	// alter connection properties, for example the read-deadline.
157	ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
158}
159
160// defaultReader is an adapter for the Server struct that implements the Reader and
161// PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
162// of the embedded Server.
163type defaultReader struct {
164	*Server
165}
166
167var _ PacketConnReader = defaultReader{}
168
169func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
170	return dr.readTCP(conn, timeout)
171}
172
173func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
174	return dr.readUDP(conn, timeout)
175}
176
177func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
178	return dr.readPacketConn(conn, timeout)
179}
180
181// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
182// Implementations should never return a nil Reader.
183// Readers should also implement the optional PacketConnReader interface.
184// PacketConnReader is required to use a generic net.PacketConn.
185type DecorateReader func(Reader) Reader
186
187// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
188// Implementations should never return a nil Writer.
189type DecorateWriter func(Writer) Writer
190
191// A Server defines parameters for running an DNS server.
192type Server struct {
193	// Address to listen on, ":dns" if empty.
194	Addr string
195	// if "tcp" or "tcp-tls" (DNS over TLS) it will invoke a TCP listener, otherwise an UDP one
196	Net string
197	// TCP Listener to use, this is to aid in systemd's socket activation.
198	Listener net.Listener
199	// TLS connection configuration
200	TLSConfig *tls.Config
201	// UDP "Listener" to use, this is to aid in systemd's socket activation.
202	PacketConn net.PacketConn
203	// Handler to invoke, dns.DefaultServeMux if nil.
204	Handler Handler
205	// Default buffer size to use to read incoming UDP messages. If not set
206	// it defaults to MinMsgSize (512 B).
207	UDPSize int
208	// The net.Conn.SetReadTimeout value for new connections, defaults to 2 * time.Second.
209	ReadTimeout time.Duration
210	// The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second.
211	WriteTimeout time.Duration
212	// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
213	IdleTimeout func() time.Duration
214	// An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
215	TsigProvider TsigProvider
216	// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
217	TsigSecret map[string]string
218	// If NotifyStartedFunc is set it is called once the server has started listening.
219	NotifyStartedFunc func()
220	// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
221	DecorateReader DecorateReader
222	// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
223	DecorateWriter DecorateWriter
224	// Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
225	MaxTCPQueries int
226	// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
227	// It is only supported on go1.11+ and when using ListenAndServe.
228	ReusePort bool
229	// AcceptMsgFunc will check the incoming message and will reject it early in the process.
230	// By default DefaultMsgAcceptFunc will be used.
231	MsgAcceptFunc MsgAcceptFunc
232
233	// Shutdown handling
234	lock     sync.RWMutex
235	started  bool
236	shutdown chan struct{}
237	conns    map[net.Conn]struct{}
238
239	// A pool for UDP message buffers.
240	udpPool sync.Pool
241}
242
243func (srv *Server) tsigProvider() TsigProvider {
244	if srv.TsigProvider != nil {
245		return srv.TsigProvider
246	}
247	if srv.TsigSecret != nil {
248		return tsigSecretProvider(srv.TsigSecret)
249	}
250	return nil
251}
252
253func (srv *Server) isStarted() bool {
254	srv.lock.RLock()
255	started := srv.started
256	srv.lock.RUnlock()
257	return started
258}
259
// makeUDPBuffer returns a sync.Pool New function that allocates a fresh
// []byte of exactly size bytes on every call.
func makeUDPBuffer(size int) func() interface{} {
	newBuf := func() interface{} {
		buf := make([]byte, size)
		return buf
	}
	return newBuf
}
265
266func (srv *Server) init() {
267	srv.shutdown = make(chan struct{})
268	srv.conns = make(map[net.Conn]struct{})
269
270	if srv.UDPSize == 0 {
271		srv.UDPSize = MinMsgSize
272	}
273	if srv.MsgAcceptFunc == nil {
274		srv.MsgAcceptFunc = DefaultMsgAcceptFunc
275	}
276	if srv.Handler == nil {
277		srv.Handler = DefaultServeMux
278	}
279
280	srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
281}
282
// unlockOnce returns a function that unlocks l on its first call and does
// nothing afterwards. A lock can therefore be released early on one path and
// still be released safely by a deferred call on every other path.
func unlockOnce(l sync.Locker) func() {
	once := new(sync.Once)
	return func() {
		once.Do(l.Unlock)
	}
}
287
288// ListenAndServe starts a nameserver on the configured address in *Server.
289func (srv *Server) ListenAndServe() error {
290	unlock := unlockOnce(&srv.lock)
291	srv.lock.Lock()
292	defer unlock()
293
294	if srv.started {
295		return &Error{err: "server already started"}
296	}
297
298	addr := srv.Addr
299	if addr == "" {
300		addr = ":domain"
301	}
302
303	srv.init()
304
305	switch srv.Net {
306	case "tcp", "tcp4", "tcp6":
307		l, err := listenTCP(srv.Net, addr, srv.ReusePort)
308		if err != nil {
309			return err
310		}
311		srv.Listener = l
312		srv.started = true
313		unlock()
314		return srv.serveTCP(l)
315	case "tcp-tls", "tcp4-tls", "tcp6-tls":
316		if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
317			return errors.New("dns: neither Certificates nor GetCertificate set in Config")
318		}
319		network := strings.TrimSuffix(srv.Net, "-tls")
320		l, err := listenTCP(network, addr, srv.ReusePort)
321		if err != nil {
322			return err
323		}
324		l = tls.NewListener(l, srv.TLSConfig)
325		srv.Listener = l
326		srv.started = true
327		unlock()
328		return srv.serveTCP(l)
329	case "udp", "udp4", "udp6":
330		l, err := listenUDP(srv.Net, addr, srv.ReusePort)
331		if err != nil {
332			return err
333		}
334		u := l.(*net.UDPConn)
335		if e := setUDPSocketOptions(u); e != nil {
336			u.Close()
337			return e
338		}
339		srv.PacketConn = l
340		srv.started = true
341		unlock()
342		return srv.serveUDP(u)
343	}
344	return &Error{err: "bad network"}
345}
346
347// ActivateAndServe starts a nameserver with the PacketConn or Listener
348// configured in *Server. Its main use is to start a server from systemd.
349func (srv *Server) ActivateAndServe() error {
350	unlock := unlockOnce(&srv.lock)
351	srv.lock.Lock()
352	defer unlock()
353
354	if srv.started {
355		return &Error{err: "server already started"}
356	}
357
358	srv.init()
359
360	if srv.PacketConn != nil {
361		// Check PacketConn interface's type is valid and value
362		// is not nil
363		if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
364			if e := setUDPSocketOptions(t); e != nil {
365				return e
366			}
367		}
368		srv.started = true
369		unlock()
370		return srv.serveUDP(srv.PacketConn)
371	}
372	if srv.Listener != nil {
373		srv.started = true
374		unlock()
375		return srv.serveTCP(srv.Listener)
376	}
377	return &Error{err: "bad listeners"}
378}
379
380// Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and
381// ActivateAndServe will return.
382func (srv *Server) Shutdown() error {
383	return srv.ShutdownContext(context.Background())
384}
385
// ShutdownContext shuts down a server. After a call to ShutdownContext,
// ListenAndServe and ActivateAndServe will return.
//
// A context.Context may be passed to limit how long to wait for connections
// to terminate.
//
// It returns an error if the server is not started. Otherwise it returns
// ctx.Err() if ctx is done before the serve loop finishes, or nil; in both of
// those cases the PacketConn, if any, is closed before returning.
func (srv *Server) ShutdownContext(ctx context.Context) error {
	srv.lock.Lock()
	if !srv.started {
		srv.lock.Unlock()
		return &Error{err: "server not started"}
	}

	// Clearing started under the lock stops the serve loops and makes the
	// read helpers skip refreshing deadlines, so the past deadlines set
	// below are not overwritten.
	srv.started = false

	if srv.PacketConn != nil {
		srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
	}

	// Closing the listener makes a blocked Accept return an error.
	if srv.Listener != nil {
		srv.Listener.Close()
	}

	for rw := range srv.conns {
		rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
	}

	srv.lock.Unlock()

	// Test hook; presumably only set by tests.
	if testShutdownNotify != nil {
		testShutdownNotify.Broadcast()
	}

	// srv.shutdown is closed by serveTCP/serveUDP once all of their
	// in-flight handler goroutines have returned.
	var ctxErr error
	select {
	case <-srv.shutdown:
	case <-ctx.Done():
		ctxErr = ctx.Err()
	}

	if srv.PacketConn != nil {
		srv.PacketConn.Close()
	}

	return ctxErr
}
431
var (
	// testShutdownNotify is a test hook: when non-nil, ShutdownContext
	// broadcasts on it after releasing the server lock.
	testShutdownNotify *sync.Cond
)
433
434// getReadTimeout is a helper func to use system timeout if server did not intend to change it.
435func (srv *Server) getReadTimeout() time.Duration {
436	if srv.ReadTimeout != 0 {
437		return srv.ReadTimeout
438	}
439	return dnsTimeout
440}
441
442// serveTCP starts a TCP listener for the server.
443func (srv *Server) serveTCP(l net.Listener) error {
444	defer l.Close()
445
446	if srv.NotifyStartedFunc != nil {
447		srv.NotifyStartedFunc()
448	}
449
450	var wg sync.WaitGroup
451	defer func() {
452		wg.Wait()
453		close(srv.shutdown)
454	}()
455
456	for srv.isStarted() {
457		rw, err := l.Accept()
458		if err != nil {
459			if !srv.isStarted() {
460				return nil
461			}
462			if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
463				continue
464			}
465			return err
466		}
467		srv.lock.Lock()
468		// Track the connection to allow unblocking reads on shutdown.
469		srv.conns[rw] = struct{}{}
470		srv.lock.Unlock()
471		wg.Add(1)
472		go srv.serveTCPConn(&wg, rw)
473	}
474
475	return nil
476}
477
478// serveUDP starts a UDP listener for the server.
479func (srv *Server) serveUDP(l net.PacketConn) error {
480	defer l.Close()
481
482	reader := Reader(defaultReader{srv})
483	if srv.DecorateReader != nil {
484		reader = srv.DecorateReader(reader)
485	}
486
487	lUDP, isUDP := l.(*net.UDPConn)
488	readerPC, canPacketConn := reader.(PacketConnReader)
489	if !isUDP && !canPacketConn {
490		return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
491	}
492
493	if srv.NotifyStartedFunc != nil {
494		srv.NotifyStartedFunc()
495	}
496
497	var wg sync.WaitGroup
498	defer func() {
499		wg.Wait()
500		close(srv.shutdown)
501	}()
502
503	rtimeout := srv.getReadTimeout()
504	// deadline is not used here
505	for srv.isStarted() {
506		var (
507			m    []byte
508			sPC  net.Addr
509			sUDP *SessionUDP
510			err  error
511		)
512		if isUDP {
513			m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
514		} else {
515			m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
516		}
517		if err != nil {
518			if !srv.isStarted() {
519				return nil
520			}
521			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
522				continue
523			}
524			return err
525		}
526		if len(m) < headerSize {
527			if cap(m) == srv.UDPSize {
528				srv.udpPool.Put(m[:srv.UDPSize])
529			}
530			continue
531		}
532		wg.Add(1)
533		go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
534	}
535
536	return nil
537}
538
539// Serve a new TCP connection.
540func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
541	w := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
542	if srv.DecorateWriter != nil {
543		w.writer = srv.DecorateWriter(w)
544	} else {
545		w.writer = w
546	}
547
548	reader := Reader(defaultReader{srv})
549	if srv.DecorateReader != nil {
550		reader = srv.DecorateReader(reader)
551	}
552
553	idleTimeout := tcpIdleTimeout
554	if srv.IdleTimeout != nil {
555		idleTimeout = srv.IdleTimeout()
556	}
557
558	timeout := srv.getReadTimeout()
559
560	limit := srv.MaxTCPQueries
561	if limit == 0 {
562		limit = maxTCPQueries
563	}
564
565	for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
566		m, err := reader.ReadTCP(w.tcp, timeout)
567		if err != nil {
568			// TODO(tmthrgd): handle error
569			break
570		}
571		srv.serveDNS(m, w)
572		if w.closed {
573			break // Close() was called
574		}
575		if w.hijacked {
576			break // client will call Close() themselves
577		}
578		// The first read uses the read timeout, the rest use the
579		// idle timeout.
580		timeout = idleTimeout
581	}
582
583	if !w.hijacked {
584		w.Close()
585	}
586
587	srv.lock.Lock()
588	delete(srv.conns, w.tcp)
589	srv.lock.Unlock()
590
591	wg.Done()
592}
593
594// Serve a new UDP request.
595func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
596	w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
597	if srv.DecorateWriter != nil {
598		w.writer = srv.DecorateWriter(w)
599	} else {
600		w.writer = w
601	}
602
603	srv.serveDNS(m, w)
604	wg.Done()
605}
606
// serveDNS processes one raw request m:
//  1. It decodes the header.
//  2. It asks srv.MsgAcceptFunc what to do with it.
//  3. For accepted messages it decodes the body and, if a TSIG provider is
//     configured, verifies any TSIG record.
//  4. It hands the request to srv.Handler.
//
// Pooled UDP buffers are returned to srv.udpPool before the handler runs.
func (srv *Server) serveDNS(m []byte, w *response) {
	dh, off, err := unpackMsgHdr(m, 0)
	if err != nil {
		// Let client hang, they are sending crap; any reply can be used to amplify.
		return
	}

	req := new(Msg)
	req.setHdr(dh)

	// Action values not listed below fall out of the switch and are served
	// like MsgAccept.
	switch action := srv.MsgAcceptFunc(dh); action {
	case MsgAccept:
		// Decode the rest of the message; on success leave the switch and
		// serve it.
		if req.unpack(dh, m, off) == nil {
			break
		}

		// The body did not decode: answer it the same way as a rejected message.
		fallthrough
	case MsgReject, MsgRejectNotImplemented:
		// Saved because SetRcodeFormatError presumably resets the opcode.
		opcode := req.Opcode
		req.SetRcodeFormatError(req)
		req.Zero = false
		if action == MsgRejectNotImplemented {
			req.Opcode = opcode
			req.Rcode = RcodeNotImplemented
		}

		// Are we allowed to delete any OPT records here?
		req.Ns, req.Answer, req.Extra = nil, nil, nil

		w.WriteMsg(req)
		fallthrough
	case MsgIgnore:
		if w.udp != nil && cap(m) == srv.UDPSize {
			srv.udpPool.Put(m[:srv.UDPSize])
		}

		return
	}

	// The verification result is exposed via w.TsigStatus(); the request is
	// still handed to the handler either way.
	w.tsigStatus = nil
	if w.tsigProvider != nil {
		if t := req.IsTsig(); t != nil {
			w.tsigStatus = tsigVerifyProvider(m, w.tsigProvider, "", false)
			w.tsigTimersOnly = false
			w.tsigRequestMAC = t.MAC
		}
	}

	// m is not read again below, so recycle a pooled UDP buffer before the
	// handler runs. NOTE(review): this assumes req does not alias m's bytes
	// after unpack; confirm unpack copies everything it keeps.
	if w.udp != nil && cap(m) == srv.UDPSize {
		srv.udpPool.Put(m[:srv.UDPSize])
	}

	srv.Handler.ServeDNS(w, req) // Writes back to the client
}
661
662func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
663	// If we race with ShutdownContext, the read deadline may
664	// have been set in the distant past to unblock the read
665	// below. We must not override it, otherwise we may block
666	// ShutdownContext.
667	srv.lock.RLock()
668	if srv.started {
669		conn.SetReadDeadline(time.Now().Add(timeout))
670	}
671	srv.lock.RUnlock()
672
673	var length uint16
674	if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
675		return nil, err
676	}
677
678	m := make([]byte, length)
679	if _, err := io.ReadFull(conn, m); err != nil {
680		return nil, err
681	}
682
683	return m, nil
684}
685
686func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
687	srv.lock.RLock()
688	if srv.started {
689		// See the comment in readTCP above.
690		conn.SetReadDeadline(time.Now().Add(timeout))
691	}
692	srv.lock.RUnlock()
693
694	m := srv.udpPool.Get().([]byte)
695	n, s, err := ReadFromSessionUDP(conn, m)
696	if err != nil {
697		srv.udpPool.Put(m)
698		return nil, nil, err
699	}
700	m = m[:n]
701	return m, s, nil
702}
703
704func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
705	srv.lock.RLock()
706	if srv.started {
707		// See the comment in readTCP above.
708		conn.SetReadDeadline(time.Now().Add(timeout))
709	}
710	srv.lock.RUnlock()
711
712	m := srv.udpPool.Get().([]byte)
713	n, addr, err := conn.ReadFrom(m)
714	if err != nil {
715		srv.udpPool.Put(m)
716		return nil, nil, err
717	}
718	m = m[:n]
719	return m, addr, nil
720}
721
722// WriteMsg implements the ResponseWriter.WriteMsg method.
723func (w *response) WriteMsg(m *Msg) (err error) {
724	if w.closed {
725		return &Error{err: "WriteMsg called after Close"}
726	}
727
728	var data []byte
729	if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
730		if t := m.IsTsig(); t != nil {
731			data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
732			if err != nil {
733				return err
734			}
735			_, err = w.writer.Write(data)
736			return err
737		}
738	}
739	data, err = m.Pack()
740	if err != nil {
741		return err
742	}
743	_, err = w.writer.Write(data)
744	return err
745}
746
747// Write implements the ResponseWriter.Write method.
748func (w *response) Write(m []byte) (int, error) {
749	if w.closed {
750		return 0, &Error{err: "Write called after Close"}
751	}
752
753	switch {
754	case w.udp != nil:
755		if u, ok := w.udp.(*net.UDPConn); ok {
756			return WriteToSessionUDP(u, m, w.udpSession)
757		}
758		return w.udp.WriteTo(m, w.pcSession)
759	case w.tcp != nil:
760		if len(m) > MaxMsgSize {
761			return 0, &Error{err: "message too large"}
762		}
763
764		msg := make([]byte, 2+len(m))
765		binary.BigEndian.PutUint16(msg, uint16(len(m)))
766		copy(msg[2:], m)
767		return w.tcp.Write(msg)
768	default:
769		panic("dns: internal error: udp and tcp both nil")
770	}
771}
772
773// LocalAddr implements the ResponseWriter.LocalAddr method.
774func (w *response) LocalAddr() net.Addr {
775	switch {
776	case w.udp != nil:
777		return w.udp.LocalAddr()
778	case w.tcp != nil:
779		return w.tcp.LocalAddr()
780	default:
781		panic("dns: internal error: udp and tcp both nil")
782	}
783}
784
785// RemoteAddr implements the ResponseWriter.RemoteAddr method.
786func (w *response) RemoteAddr() net.Addr {
787	switch {
788	case w.udpSession != nil:
789		return w.udpSession.RemoteAddr()
790	case w.pcSession != nil:
791		return w.pcSession
792	case w.tcp != nil:
793		return w.tcp.RemoteAddr()
794	default:
795		panic("dns: internal error: udpSession, pcSession and tcp are all nil")
796	}
797}
798
799// TsigStatus implements the ResponseWriter.TsigStatus method.
800func (w *response) TsigStatus() error { return w.tsigStatus }
801
802// TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method.
803func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
804
805// Hijack implements the ResponseWriter.Hijack method.
806func (w *response) Hijack() { w.hijacked = true }
807
808// Close implements the ResponseWriter.Close method
809func (w *response) Close() error {
810	if w.closed {
811		return &Error{err: "connection already closed"}
812	}
813	w.closed = true
814
815	switch {
816	case w.udp != nil:
817		// Can't close the udp conn, as that is actually the listener.
818		return nil
819	case w.tcp != nil:
820		return w.tcp.Close()
821	default:
822		panic("dns: internal error: udp and tcp both nil")
823	}
824}
825
826// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
827func (w *response) ConnectionState() *tls.ConnectionState {
828	type tlsConnectionStater interface {
829		ConnectionState() tls.ConnectionState
830	}
831	if v, ok := w.tcp.(tlsConnectionStater); ok {
832		t := v.ConnectionState()
833		return &t
834	}
835	return nil
836}
837