// DNS packet assembly, see RFC 1035. Converting from - Unpack() -
// and to - Pack() - wire format.
// All the packers and unpackers take a (msg []byte, off int)
// and return (off1 int, err error). When they return a non-nil error they
// generally also return off1 == len(msg), so that the next unpacker will
// also fail. This lets us avoid checking err until the end of a
// packing sequence.
8
9package dns
10
11//go:generate go run msg_generate.go
12
13import (
14	crand "crypto/rand"
15	"encoding/binary"
16	"fmt"
17	"math/big"
18	"math/rand"
19	"strconv"
20	"strings"
21	"sync"
22)
23
const (
	// maxCompressionOffset is the first message offset a compression pointer
	// can no longer reach: a pointer only has 14 bits for its offset.
	maxCompressionOffset = 1 << 14

	// maxDomainNameWireOctets is the longest a domain name may be in wire
	// format (RFC 1035, section 2.3.4).
	maxDomainNameWireOctets = 255

	// maxCompressionPointers is the most compression pointers that should be
	// needed to decode one name of a semantically valid message. Every label
	// takes at least one octet and is separated from the next, which caps the
	// label count at half the wire length. The root label is never reached
	// through a pointer to a pointer, so 2 is subtracted to exclude the
	// smallest valid root label.
	//
	// A message can exceed this without looping, by pointing at a previous
	// pointer. A well-written implementation never does that, so such
	// messages are deliberately left to trip the limit.
	maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2

	// maxDomainNamePresentationLength is the longest a domain name can be in
	// presentation format. The wire form is at most 255 octets, with labels
	// of at most 63 octets, and needs one octet more than the presentation
	// form. Labels are separated by a single period, and each label octet
	// expands to at most 4 bytes (\DDD). With every other label at its
	// maximum length, the final label can only be 61 octets long before the
	// wire limit is exceeded.
	maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1
)
50
51// Errors defined in this package.
52var (
53	ErrAlg           error = &Error{err: "bad algorithm"}                  // ErrAlg indicates an error with the (DNSSEC) algorithm.
54	ErrAuth          error = &Error{err: "bad authentication"}             // ErrAuth indicates an error in the TSIG authentication.
55	ErrBuf           error = &Error{err: "buffer size too small"}          // ErrBuf indicates that the buffer used is too small for the message.
56	ErrConnEmpty     error = &Error{err: "conn has no connection"}         // ErrConnEmpty indicates a connection is being used before it is initialized.
57	ErrExtendedRcode error = &Error{err: "bad extended rcode"}             // ErrExtendedRcode ...
58	ErrFqdn          error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot.
59	ErrId            error = &Error{err: "id mismatch"}                    // ErrId indicates there is a mismatch with the message's ID.
60	ErrKeyAlg        error = &Error{err: "bad key algorithm"}              // ErrKeyAlg indicates that the algorithm in the key is not valid.
61	ErrKey           error = &Error{err: "bad key"}
62	ErrKeySize       error = &Error{err: "bad key size"}
63	ErrLongDomain    error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)}
64	ErrNoSig         error = &Error{err: "no signature found"}
65	ErrPrivKey       error = &Error{err: "bad private key"}
66	ErrRcode         error = &Error{err: "bad rcode"}
67	ErrRdata         error = &Error{err: "bad rdata"}
68	ErrRRset         error = &Error{err: "bad rrset"}
69	ErrSecret        error = &Error{err: "no secrets defined"}
70	ErrShortRead     error = &Error{err: "short read"}
71	ErrSig           error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
72	ErrSoa           error = &Error{err: "no SOA"}        // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
73	ErrTime          error = &Error{err: "bad time"}      // ErrTime indicates a timing error in TSIG authentication.
74)
75
// Id returns a random 16-bit value for use as a message ID. It is a variable,
// so applications can install their own generator; for instance, to always
// return the same value:
//
//	dns.Id = func() uint16 { return 3 }
var Id = id

var (
	idLock sync.Mutex // serializes all access to idRand
	idRand *rand.Rand // ID generator, created on the first call to id
)

// id is the default implementation behind Id.
func id() uint16 {
	idLock.Lock()
	defer idLock.Unlock()

	if idRand == nil {
		// Seeding lazily, on first use rather than at package init, is a
		// (partial) workaround for https://github.com/golang/go/issues/11833.
		var raw [8]byte
		var seed int64
		if _, err := crand.Read(raw[:]); err != nil {
			seed = rand.Int63()
		} else {
			seed = int64(binary.LittleEndian.Uint64(raw[:]))
		}
		idRand = rand.New(rand.NewSource(seed))
	}

	// *rand.Rand is not safe for concurrent use, so the draw must happen while
	// the lock is held. This costs nothing over the global math/rand source,
	// which is itself guarded by a mutex.
	return uint16(idRand.Uint32())
}
124
// MsgHdr is the DNS message header with the ID and the flags word (bits)
// unpacked into individual fields.
type MsgHdr struct {
	Id                 uint16 // message ID
	Response           bool   // QR: message is a response
	Opcode             int    // OPCODE: kind of query
	Authoritative      bool   // AA: authoritative answer
	Truncated          bool   // TC: message was truncated
	RecursionDesired   bool   // RD: recursion desired
	RecursionAvailable bool   // RA: recursion available
	Zero               bool   // Z: reserved bit
	AuthenticatedData  bool   // AD: authenticated data
	CheckingDisabled   bool   // CD: checking disabled
	Rcode              int    // RCODE; values above 0xF need an OPT RR for the upper bits
}
139
140// Msg contains the layout of a DNS message.
141type Msg struct {
142	MsgHdr
143	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format.
144	Question []Question // Holds the RR(s) of the question section.
145	Answer   []RR       // Holds the RR(s) of the answer section.
146	Ns       []RR       // Holds the RR(s) of the authority section.
147	Extra    []RR       // Holds the RR(s) of the additional section.
148}
149
150// ClassToString is a maps Classes to strings for each CLASS wire type.
151var ClassToString = map[uint16]string{
152	ClassINET:   "IN",
153	ClassCSNET:  "CS",
154	ClassCHAOS:  "CH",
155	ClassHESIOD: "HS",
156	ClassNONE:   "NONE",
157	ClassANY:    "ANY",
158}
159
160// OpcodeToString maps Opcodes to strings.
161var OpcodeToString = map[int]string{
162	OpcodeQuery:  "QUERY",
163	OpcodeIQuery: "IQUERY",
164	OpcodeStatus: "STATUS",
165	OpcodeNotify: "NOTIFY",
166	OpcodeUpdate: "UPDATE",
167}
168
169// RcodeToString maps Rcodes to strings.
170var RcodeToString = map[int]string{
171	RcodeSuccess:        "NOERROR",
172	RcodeFormatError:    "FORMERR",
173	RcodeServerFailure:  "SERVFAIL",
174	RcodeNameError:      "NXDOMAIN",
175	RcodeNotImplemented: "NOTIMP",
176	RcodeRefused:        "REFUSED",
177	RcodeYXDomain:       "YXDOMAIN", // See RFC 2136
178	RcodeYXRrset:        "YXRRSET",
179	RcodeNXRrset:        "NXRRSET",
180	RcodeNotAuth:        "NOTAUTH",
181	RcodeNotZone:        "NOTZONE",
182	RcodeBadSig:         "BADSIG", // Also known as RcodeBadVers, see RFC 6891
183	//	RcodeBadVers:        "BADVERS",
184	RcodeBadKey:    "BADKEY",
185	RcodeBadTime:   "BADTIME",
186	RcodeBadMode:   "BADMODE",
187	RcodeBadName:   "BADNAME",
188	RcodeBadAlg:    "BADALG",
189	RcodeBadTrunc:  "BADTRUNC",
190	RcodeBadCookie: "BADCOOKIE",
191}
192
// compressionMap lets internal packDomainName calls use a cheaper map type
// without changing the signature or behavior of the public API, which works
// with map[string]int.
//
// A map[string]uint16 uses about 25% less memory per entry than a
// map[string]int.
type compressionMap struct {
	ext map[string]int    // supplied by external callers of the public API
	int map[string]uint16 // used by internal callers
}

// valid reports whether a backing map is present.
func (m compressionMap) valid() bool {
	return !(m.ext == nil && m.int == nil)
}

// insert records that the name s was packed at offset pos. The external map
// takes precedence when both are set.
func (m compressionMap) insert(s string, pos int) {
	if m.ext == nil {
		m.int[s] = uint16(pos)
		return
	}
	m.ext[s] = pos
}

// find returns the offset recorded for s and whether one was recorded,
// consulting the same map insert would write to.
func (m compressionMap) find(s string) (int, bool) {
	if m.ext == nil {
		pos, ok := m.int[s]
		return int(pos), ok
	}
	pos, ok := m.ext[s]
	return pos, ok
}
225
226// Domain names are a sequence of counted strings
227// split at the dots. They end with a zero-length string.
228
229// PackDomainName packs a domain name s into msg[off:].
230// If compression is wanted compress must be true and the compression
231// map needs to hold a mapping between domain names and offsets
232// pointing into msg.
233func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
234	return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
235}
236
237func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
238	// XXX: A logical copy of this function exists in IsDomainName and
239	// should be kept in sync with this function.
240
241	ls := len(s)
242	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
243		return off, nil
244	}
245
246	// If not fully qualified, error out.
247	if !IsFqdn(s) {
248		return len(msg), ErrFqdn
249	}
250
251	// Each dot ends a segment of the name.
252	// We trade each dot byte for a length byte.
253	// Except for escaped dots (\.), which are normal dots.
254	// There is also a trailing zero.
255
256	// Compression
257	pointer := -1
258
259	// Emit sequence of counted strings, chopping at dots.
260	var (
261		begin     int
262		compBegin int
263		compOff   int
264		bs        []byte
265		wasDot    bool
266	)
267loop:
268	for i := 0; i < ls; i++ {
269		var c byte
270		if bs == nil {
271			c = s[i]
272		} else {
273			c = bs[i]
274		}
275
276		switch c {
277		case '\\':
278			if off+1 > len(msg) {
279				return len(msg), ErrBuf
280			}
281
282			if bs == nil {
283				bs = []byte(s)
284			}
285
286			// check for \DDD
287			if i+3 < ls && isDigit(bs[i+1]) && isDigit(bs[i+2]) && isDigit(bs[i+3]) {
288				bs[i] = dddToByte(bs[i+1:])
289				copy(bs[i+1:ls-3], bs[i+4:])
290				ls -= 3
291				compOff += 3
292			} else {
293				copy(bs[i:ls-1], bs[i+1:])
294				ls--
295				compOff++
296			}
297
298			wasDot = false
299		case '.':
300			if wasDot {
301				// two dots back to back is not legal
302				return len(msg), ErrRdata
303			}
304			wasDot = true
305
306			labelLen := i - begin
307			if labelLen >= 1<<6 { // top two bits of length must be clear
308				return len(msg), ErrRdata
309			}
310
311			// off can already (we're in a loop) be bigger than len(msg)
312			// this happens when a name isn't fully qualified
313			if off+1+labelLen > len(msg) {
314				return len(msg), ErrBuf
315			}
316
317			// Don't try to compress '.'
318			// We should only compress when compress is true, but we should also still pick
319			// up names that can be used for *future* compression(s).
320			if compression.valid() && !isRootLabel(s, bs, begin, ls) {
321				if p, ok := compression.find(s[compBegin:]); ok {
322					// The first hit is the longest matching dname
323					// keep the pointer offset we get back and store
324					// the offset of the current name, because that's
325					// where we need to insert the pointer later
326
327					// If compress is true, we're allowed to compress this dname
328					if compress {
329						pointer = p // Where to point to
330						break loop
331					}
332				} else if off < maxCompressionOffset {
333					// Only offsets smaller than maxCompressionOffset can be used.
334					compression.insert(s[compBegin:], off)
335				}
336			}
337
338			// The following is covered by the length check above.
339			msg[off] = byte(labelLen)
340
341			if bs == nil {
342				copy(msg[off+1:], s[begin:i])
343			} else {
344				copy(msg[off+1:], bs[begin:i])
345			}
346			off += 1 + labelLen
347
348			begin = i + 1
349			compBegin = begin + compOff
350		default:
351			wasDot = false
352		}
353	}
354
355	// Root label is special
356	if isRootLabel(s, bs, 0, ls) {
357		return off, nil
358	}
359
360	// If we did compression and we find something add the pointer here
361	if pointer != -1 {
362		// We have two bytes (14 bits) to put the pointer in
363		binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
364		return off + 2, nil
365	}
366
367	if off < len(msg) {
368		msg[off] = 0
369	}
370
371	return off + 1, nil
372}
373
// isRootLabel reports whether the span [off, end) of a name is exactly the
// root label ".". When bs is non-nil it is inspected; otherwise s is.
func isRootLabel(s string, bs []byte, off, end int) bool {
	if bs != nil {
		return end-off == 1 && bs[off] == '.'
	}
	return s[off:end] == "."
}
385
// UnpackDomainName unpacks the domain name at msg[off:] into its
// presentation-format string. It returns the name, the offset just past the
// name, and any error that occurred.
//
// A wire-format name is a sequence of length-prefixed labels ending in a
// zero-length label. To avoid repeating common suffixes, a name may also
// refer to labels elsewhere in the packet. Such a compression pointer is
// marked by a length byte with the top two bits set. The remaining 14 bits
// of that byte and the next give the offset from msg[0] at which to
// continue.
//
// If a pointer is followed, the returned offset is the one just past the
// first pointer, which is where the next record starts. Pointers are only
// meant to jump backward, but here they may jump anywhere. At most
// maxCompressionPointers are followed, so that loops terminate.
//
// Label bytes that are special in presentation format are backslash-escaped,
// and unprintable bytes are written as \DDD. The root name is returned as ".".
//
// When an error is encountered, the unpacked name will be discarded and
// len(msg) will be returned as the offset. ErrBuf means msg ended
// prematurely. ErrLongDomain means the name exceeds maxDomainNameWireOctets.
// ErrRdata is returned for the reserved 0x40/0x80 label types.
func UnpackDomainName(msg []byte, off int) (string, int, error) {
	s := make([]byte, 0, maxDomainNamePresentationLength)
	off1 := 0 // offset to return; fixed at the first pointer, if any
	lenmsg := len(msg)
	budget := maxDomainNameWireOctets // wire octets still allowed for this name
	ptr := 0 // number of pointers followed
Loop:
	for {
		if off >= lenmsg {
			return "", lenmsg, ErrBuf
		}
		c := int(msg[off])
		off++
		switch c & 0xC0 {
		case 0x00:
			if c == 0x00 {
				// end of name
				break Loop
			}
			// literal string
			if off+c > lenmsg {
				return "", lenmsg, ErrBuf
			}
			// +1 for the length octet (the label separator). Requiring the
			// budget to stay positive leaves room for the terminating zero
			// octet.
			budget -= c + 1
			if budget <= 0 {
				return "", lenmsg, ErrLongDomain
			}
			for _, b := range msg[off : off+c] {
				switch b {
				case '.', '(', ')', ';', ' ', '@':
					fallthrough
				case '"', '\\':
					s = append(s, '\\', b)
				default:
					if b < ' ' || b > '~' { // unprintable, use \DDD
						s = append(s, escapeByte(b)...)
					} else {
						s = append(s, b)
					}
				}
			}
			s = append(s, '.')
			off += c
		case 0xC0:
			// pointer to somewhere else in msg.
			// remember location after first ptr,
			// since that's how many bytes we consumed.
			// also, don't follow too many pointers --
			// maybe there's a loop.
			if off >= lenmsg {
				return "", lenmsg, ErrBuf
			}
			c1 := msg[off]
			off++
			if ptr == 0 {
				off1 = off
			}
			if ptr++; ptr > maxCompressionPointers {
				return "", lenmsg, &Error{err: "too many compression pointers"}
			}
			// The pointer target is not required to lie before the pointer.
			// The pointer-count limit just above is what guarantees that
			// decoding terminates.
			off = (c^0xC0)<<8 | int(c1)
		default:
			// 0x80 and 0x40 are reserved
			return "", lenmsg, ErrRdata
		}
	}
	if ptr == 0 {
		off1 = off
	}
	if len(s) == 0 {
		return ".", off1, nil
	}
	return string(s), off1, nil
}
482
483func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
484	if len(txt) == 0 {
485		if offset >= len(msg) {
486			return offset, ErrBuf
487		}
488		msg[offset] = 0
489		return offset, nil
490	}
491	var err error
492	for _, s := range txt {
493		if len(s) > len(tmp) {
494			return offset, ErrBuf
495		}
496		offset, err = packTxtString(s, msg, offset, tmp)
497		if err != nil {
498			return offset, err
499		}
500	}
501	return offset, nil
502}
503
504func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
505	lenByteOffset := offset
506	if offset >= len(msg) || len(s) > len(tmp) {
507		return offset, ErrBuf
508	}
509	offset++
510	bs := tmp[:len(s)]
511	copy(bs, s)
512	for i := 0; i < len(bs); i++ {
513		if len(msg) <= offset {
514			return offset, ErrBuf
515		}
516		if bs[i] == '\\' {
517			i++
518			if i == len(bs) {
519				break
520			}
521			// check for \DDD
522			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
523				msg[offset] = dddToByte(bs[i:])
524				i += 2
525			} else {
526				msg[offset] = bs[i]
527			}
528		} else {
529			msg[offset] = bs[i]
530		}
531		offset++
532	}
533	l := offset - lenByteOffset - 1
534	if l > 255 {
535		return offset, &Error{err: "string exceeded 255 bytes in txt"}
536	}
537	msg[lenByteOffset] = byte(l)
538	return offset, nil
539}
540
541func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) {
542	if offset >= len(msg) || len(s) > len(tmp) {
543		return offset, ErrBuf
544	}
545	bs := tmp[:len(s)]
546	copy(bs, s)
547	for i := 0; i < len(bs); i++ {
548		if len(msg) <= offset {
549			return offset, ErrBuf
550		}
551		if bs[i] == '\\' {
552			i++
553			if i == len(bs) {
554				break
555			}
556			// check for \DDD
557			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
558				msg[offset] = dddToByte(bs[i:])
559				i += 2
560			} else {
561				msg[offset] = bs[i]
562			}
563		} else {
564			msg[offset] = bs[i]
565		}
566		offset++
567	}
568	return offset, nil
569}
570
571func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
572	off = off0
573	var s string
574	for off < len(msg) && err == nil {
575		s, off, err = unpackString(msg, off)
576		if err == nil {
577			ss = append(ss, s)
578		}
579	}
580	return
581}
582
// Helpers for dealing with escaped bytes

// isDigit reports whether b is an ASCII decimal digit.
func isDigit(b byte) bool {
	return '0' <= b && b <= '9'
}
585
// dddToByte converts the three decimal digits at the start of s (the DDD of
// a \DDD escape) to a byte. The arithmetic is done in byte, so values above
// 255 wrap. It panics if s is shorter than three bytes.
func dddToByte(s []byte) byte {
	_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
	hundreds, tens, ones := s[0]-'0', s[1]-'0', s[2]-'0'
	return hundreds*100 + tens*10 + ones
}
590
// dddStringToByte is the string counterpart of dddToByte: it converts the
// three decimal digits at the start of s to a byte (wrapping above 255) and
// panics if s is shorter than three bytes.
func dddStringToByte(s string) byte {
	_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
	hundreds, tens, ones := s[0]-'0', s[1]-'0', s[2]-'0'
	return hundreds*100 + tens*10 + ones
}
595
// intToBytes returns the big-endian bytes of i, left-padded with zeros to
// length bytes. Values needing more than length bytes are returned unpadded
// and untruncated.
func intToBytes(i *big.Int, length int) []byte {
	raw := i.Bytes()
	pad := length - len(raw)
	if pad <= 0 {
		return raw
	}
	out := make([]byte, length)
	copy(out[pad:], raw)
	return out
}
606
607// PackRR packs a resource record rr into msg[off:].
608// See PackDomainName for documentation about the compression.
609func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
610	headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
611	if err == nil {
612		// packRR no longer sets the Rdlength field on the rr, but
613		// callers might be expecting it so we set it here.
614		rr.Header().Rdlength = uint16(off1 - headerEnd)
615	}
616	return off1, err
617}
618
619func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
620	if rr == nil {
621		return len(msg), len(msg), &Error{err: "nil rr"}
622	}
623
624	headerEnd, err = rr.Header().packHeader(msg, off, compression, compress)
625	if err != nil {
626		return headerEnd, len(msg), err
627	}
628
629	off1, err = rr.pack(msg, headerEnd, compression, compress)
630	if err != nil {
631		return headerEnd, len(msg), err
632	}
633
634	rdlength := off1 - headerEnd
635	if int(uint16(rdlength)) != rdlength { // overflow
636		return headerEnd, len(msg), ErrRdata
637	}
638
639	// The RDLENGTH field is the last field in the header and we set it here.
640	binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
641	return headerEnd, off1, nil
642}
643
644// UnpackRR unpacks msg[off:] into an RR.
645func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
646	h, off, msg, err := unpackHeader(msg, off)
647	if err != nil {
648		return nil, len(msg), err
649	}
650
651	return UnpackRRWithHeader(h, msg, off)
652}
653
654// UnpackRRWithHeader unpacks the record type specific payload given an existing
655// RR_Header.
656func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
657	if newFn, ok := TypeToRR[h.Rrtype]; ok {
658		rr = newFn()
659		*rr.Header() = h
660	} else {
661		rr = &RFC3597{Hdr: h}
662	}
663
664	if noRdata(h) {
665		return rr, off, nil
666	}
667
668	end := off + int(h.Rdlength)
669
670	off, err = rr.unpack(msg, off)
671	if err != nil {
672		return nil, end, err
673	}
674	if off != end {
675		return &h, end, &Error{err: "bad rdlength"}
676	}
677
678	return rr, off, nil
679}
680
681// unpackRRslice unpacks msg[off:] into an []RR.
682// If we cannot unpack the whole array, then it will return nil
683func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
684	var r RR
685	// Don't pre-allocate, l may be under attacker control
686	var dst []RR
687	for i := 0; i < l; i++ {
688		off1 := off
689		r, off, err = UnpackRR(msg, off)
690		if err != nil {
691			off = len(msg)
692			break
693		}
694		// If offset does not increase anymore, l is a lie
695		if off1 == off {
696			l = i
697			break
698		}
699		dst = append(dst, r)
700	}
701	if err != nil && off == len(msg) {
702		dst = nil
703	}
704	return dst, off, err
705}
706
707// Convert a MsgHdr to a string, with dig-like headers:
708//
709//;; opcode: QUERY, status: NOERROR, id: 48404
710//
711//;; flags: qr aa rd ra;
712func (h *MsgHdr) String() string {
713	if h == nil {
714		return "<nil> MsgHdr"
715	}
716
717	s := ";; opcode: " + OpcodeToString[h.Opcode]
718	s += ", status: " + RcodeToString[h.Rcode]
719	s += ", id: " + strconv.Itoa(int(h.Id)) + "\n"
720
721	s += ";; flags:"
722	if h.Response {
723		s += " qr"
724	}
725	if h.Authoritative {
726		s += " aa"
727	}
728	if h.Truncated {
729		s += " tc"
730	}
731	if h.RecursionDesired {
732		s += " rd"
733	}
734	if h.RecursionAvailable {
735		s += " ra"
736	}
737	if h.Zero { // Hmm
738		s += " z"
739	}
740	if h.AuthenticatedData {
741		s += " ad"
742	}
743	if h.CheckingDisabled {
744		s += " cd"
745	}
746
747	s += ";"
748	return s
749}
750
751// Pack packs a Msg: it is converted to to wire format.
752// If the dns.Compress is true the message will be in compressed wire format.
753func (dns *Msg) Pack() (msg []byte, err error) {
754	return dns.PackBuffer(nil)
755}
756
757// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
758func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
759	// If this message can't be compressed, avoid filling the
760	// compression map and creating garbage.
761	if dns.Compress && dns.isCompressible() {
762		compression := make(map[string]uint16) // Compression pointer mappings.
763		return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
764	}
765
766	return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
767}
768
769// packBufferWithCompressionMap packs a Msg, using the given buffer buf.
770func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) {
771	if dns.Rcode < 0 || dns.Rcode > 0xFFF {
772		return nil, ErrRcode
773	}
774
775	// Set extended rcode unconditionally if we have an opt, this will allow
776	// reseting the extended rcode bits if they need to.
777	if opt := dns.IsEdns0(); opt != nil {
778		opt.SetExtendedRcode(uint16(dns.Rcode))
779	} else if dns.Rcode > 0xF {
780		// If Rcode is an extended one and opt is nil, error out.
781		return nil, ErrExtendedRcode
782	}
783
784	// Convert convenient Msg into wire-like Header.
785	var dh Header
786	dh.Id = dns.Id
787	dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
788	if dns.Response {
789		dh.Bits |= _QR
790	}
791	if dns.Authoritative {
792		dh.Bits |= _AA
793	}
794	if dns.Truncated {
795		dh.Bits |= _TC
796	}
797	if dns.RecursionDesired {
798		dh.Bits |= _RD
799	}
800	if dns.RecursionAvailable {
801		dh.Bits |= _RA
802	}
803	if dns.Zero {
804		dh.Bits |= _Z
805	}
806	if dns.AuthenticatedData {
807		dh.Bits |= _AD
808	}
809	if dns.CheckingDisabled {
810		dh.Bits |= _CD
811	}
812
813	dh.Qdcount = uint16(len(dns.Question))
814	dh.Ancount = uint16(len(dns.Answer))
815	dh.Nscount = uint16(len(dns.Ns))
816	dh.Arcount = uint16(len(dns.Extra))
817
818	// We need the uncompressed length here, because we first pack it and then compress it.
819	msg = buf
820	uncompressedLen := msgLenWithCompressionMap(dns, nil)
821	if packLen := uncompressedLen + 1; len(msg) < packLen {
822		msg = make([]byte, packLen)
823	}
824
825	// Pack it in: header and then the pieces.
826	off := 0
827	off, err = dh.pack(msg, off, compression, compress)
828	if err != nil {
829		return nil, err
830	}
831	for _, r := range dns.Question {
832		off, err = r.pack(msg, off, compression, compress)
833		if err != nil {
834			return nil, err
835		}
836	}
837	for _, r := range dns.Answer {
838		_, off, err = packRR(r, msg, off, compression, compress)
839		if err != nil {
840			return nil, err
841		}
842	}
843	for _, r := range dns.Ns {
844		_, off, err = packRR(r, msg, off, compression, compress)
845		if err != nil {
846			return nil, err
847		}
848	}
849	for _, r := range dns.Extra {
850		_, off, err = packRR(r, msg, off, compression, compress)
851		if err != nil {
852			return nil, err
853		}
854	}
855	return msg[:off], nil
856}
857
858func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
859	// If we are at the end of the message we should return *just* the
860	// header. This can still be useful to the caller. 9.9.9.9 sends these
861	// when responding with REFUSED for instance.
862	if off == len(msg) {
863		// reset sections before returning
864		dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil
865		return nil
866	}
867
868	// Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are
869	// attacker controlled. This means we can't use them to pre-allocate
870	// slices.
871	dns.Question = nil
872	for i := 0; i < int(dh.Qdcount); i++ {
873		off1 := off
874		var q Question
875		q, off, err = unpackQuestion(msg, off)
876		if err != nil {
877			return err
878		}
879		if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
880			dh.Qdcount = uint16(i)
881			break
882		}
883		dns.Question = append(dns.Question, q)
884	}
885
886	dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
887	// The header counts might have been wrong so we need to update it
888	dh.Ancount = uint16(len(dns.Answer))
889	if err == nil {
890		dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
891	}
892	// The header counts might have been wrong so we need to update it
893	dh.Nscount = uint16(len(dns.Ns))
894	if err == nil {
895		dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off)
896	}
897	// The header counts might have been wrong so we need to update it
898	dh.Arcount = uint16(len(dns.Extra))
899
900	// Set extended Rcode
901	if opt := dns.IsEdns0(); opt != nil {
902		dns.Rcode |= opt.ExtendedRcode()
903	}
904
905	if off != len(msg) {
906		// TODO(miek) make this an error?
907		// use PackOpt to let people tell how detailed the error reporting should be?
908		// println("dns: extra bytes in dns packet", off, "<", len(msg))
909	}
910	return err
911
912}
913
914// Unpack unpacks a binary message to a Msg structure.
915func (dns *Msg) Unpack(msg []byte) (err error) {
916	dh, off, err := unpackMsgHdr(msg, 0)
917	if err != nil {
918		return err
919	}
920
921	dns.setHdr(dh)
922	return dns.unpack(dh, msg, off)
923}
924
925// Convert a complete message to a string with dig-like output.
926func (dns *Msg) String() string {
927	if dns == nil {
928		return "<nil> MsgHdr"
929	}
930	s := dns.MsgHdr.String() + " "
931	s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", "
932	s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
933	s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
934	s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
935	if len(dns.Question) > 0 {
936		s += "\n;; QUESTION SECTION:\n"
937		for _, r := range dns.Question {
938			s += r.String() + "\n"
939		}
940	}
941	if len(dns.Answer) > 0 {
942		s += "\n;; ANSWER SECTION:\n"
943		for _, r := range dns.Answer {
944			if r != nil {
945				s += r.String() + "\n"
946			}
947		}
948	}
949	if len(dns.Ns) > 0 {
950		s += "\n;; AUTHORITY SECTION:\n"
951		for _, r := range dns.Ns {
952			if r != nil {
953				s += r.String() + "\n"
954			}
955		}
956	}
957	if len(dns.Extra) > 0 {
958		s += "\n;; ADDITIONAL SECTION:\n"
959		for _, r := range dns.Extra {
960			if r != nil {
961				s += r.String() + "\n"
962			}
963		}
964	}
965	return s
966}
967
968// isCompressible returns whether the msg may be compressible.
969func (dns *Msg) isCompressible() bool {
970	// If we only have one question, there is nothing we can ever compress.
971	return len(dns.Question) > 1 || len(dns.Answer) > 0 ||
972		len(dns.Ns) > 0 || len(dns.Extra) > 0
973}
974
975// Len returns the message length when in (un)compressed wire format.
976// If dns.Compress is true compression it is taken into account. Len()
977// is provided to be a faster way to get the size of the resulting packet,
978// than packing it, measuring the size and discarding the buffer.
979func (dns *Msg) Len() int {
980	// If this message can't be compressed, avoid filling the
981	// compression map and creating garbage.
982	if dns.Compress && dns.isCompressible() {
983		compression := make(map[string]struct{})
984		return msgLenWithCompressionMap(dns, compression)
985	}
986
987	return msgLenWithCompressionMap(dns, nil)
988}
989
990func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
991	l := headerSize
992
993	for _, r := range dns.Question {
994		l += r.len(l, compression)
995	}
996	for _, r := range dns.Answer {
997		if r != nil {
998			l += r.len(l, compression)
999		}
1000	}
1001	for _, r := range dns.Ns {
1002		if r != nil {
1003			l += r.len(l, compression)
1004		}
1005	}
1006	for _, r := range dns.Extra {
1007		if r != nil {
1008			l += r.len(l, compression)
1009		}
1010	}
1011
1012	return l
1013}
1014
1015func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int {
1016	if s == "" || s == "." {
1017		return 1
1018	}
1019
1020	escaped := strings.Contains(s, "\\")
1021
1022	if compression != nil && (compress || off < maxCompressionOffset) {
1023		// compressionLenSearch will insert the entry into the compression
1024		// map if it doesn't contain it.
1025		if l, ok := compressionLenSearch(compression, s, off); ok && compress {
1026			if escaped {
1027				return escapedNameLen(s[:l]) + 2
1028			}
1029
1030			return l + 2
1031		}
1032	}
1033
1034	if escaped {
1035		return escapedNameLen(s) + 1
1036	}
1037
1038	return len(s) + 1
1039}
1040
1041func escapedNameLen(s string) int {
1042	nameLen := len(s)
1043	for i := 0; i < len(s); i++ {
1044		if s[i] != '\\' {
1045			continue
1046		}
1047
1048		if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) {
1049			nameLen -= 3
1050			i += 3
1051		} else {
1052			nameLen--
1053			i++
1054		}
1055	}
1056
1057	return nameLen
1058}
1059
1060func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) {
1061	for off, end := 0, false; !end; off, end = NextLabel(s, off) {
1062		if _, ok := c[s[off:]]; ok {
1063			return off, true
1064		}
1065
1066		if msgOff+off < maxCompressionOffset {
1067			c[s[off:]] = struct{}{}
1068		}
1069	}
1070
1071	return 0, false
1072}
1073
1074// Copy returns a new RR which is a deep-copy of r.
1075func Copy(r RR) RR { return r.copy() }
1076
1077// Len returns the length (in octets) of the uncompressed RR in wire format.
1078func Len(r RR) int { return r.len(0, nil) }
1079
1080// Copy returns a new *Msg which is a deep-copy of dns.
1081func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
1082
1083// CopyTo copies the contents to the provided message using a deep-copy and returns the copy.
1084func (dns *Msg) CopyTo(r1 *Msg) *Msg {
1085	r1.MsgHdr = dns.MsgHdr
1086	r1.Compress = dns.Compress
1087
1088	if len(dns.Question) > 0 {
1089		r1.Question = make([]Question, len(dns.Question))
1090		copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy
1091	}
1092
1093	rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
1094	r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):]
1095	r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):]
1096	r1.Extra = rrArr[:0:len(dns.Extra)]
1097
1098	for _, r := range dns.Answer {
1099		r1.Answer = append(r1.Answer, r.copy())
1100	}
1101
1102	for _, r := range dns.Ns {
1103		r1.Ns = append(r1.Ns, r.copy())
1104	}
1105
1106	for _, r := range dns.Extra {
1107		r1.Extra = append(r1.Extra, r.copy())
1108	}
1109
1110	return r1
1111}
1112
1113func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
1114	off, err := packDomainName(q.Name, msg, off, compression, compress)
1115	if err != nil {
1116		return off, err
1117	}
1118	off, err = packUint16(q.Qtype, msg, off)
1119	if err != nil {
1120		return off, err
1121	}
1122	off, err = packUint16(q.Qclass, msg, off)
1123	if err != nil {
1124		return off, err
1125	}
1126	return off, nil
1127}
1128
1129func unpackQuestion(msg []byte, off int) (Question, int, error) {
1130	var (
1131		q   Question
1132		err error
1133	)
1134	q.Name, off, err = UnpackDomainName(msg, off)
1135	if err != nil {
1136		return q, off, err
1137	}
1138	if off == len(msg) {
1139		return q, off, nil
1140	}
1141	q.Qtype, off, err = unpackUint16(msg, off)
1142	if err != nil {
1143		return q, off, err
1144	}
1145	if off == len(msg) {
1146		return q, off, nil
1147	}
1148	q.Qclass, off, err = unpackUint16(msg, off)
1149	if off == len(msg) {
1150		return q, off, nil
1151	}
1152	return q, off, err
1153}
1154
1155func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
1156	off, err := packUint16(dh.Id, msg, off)
1157	if err != nil {
1158		return off, err
1159	}
1160	off, err = packUint16(dh.Bits, msg, off)
1161	if err != nil {
1162		return off, err
1163	}
1164	off, err = packUint16(dh.Qdcount, msg, off)
1165	if err != nil {
1166		return off, err
1167	}
1168	off, err = packUint16(dh.Ancount, msg, off)
1169	if err != nil {
1170		return off, err
1171	}
1172	off, err = packUint16(dh.Nscount, msg, off)
1173	if err != nil {
1174		return off, err
1175	}
1176	off, err = packUint16(dh.Arcount, msg, off)
1177	if err != nil {
1178		return off, err
1179	}
1180	return off, nil
1181}
1182
1183func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
1184	var (
1185		dh  Header
1186		err error
1187	)
1188	dh.Id, off, err = unpackUint16(msg, off)
1189	if err != nil {
1190		return dh, off, err
1191	}
1192	dh.Bits, off, err = unpackUint16(msg, off)
1193	if err != nil {
1194		return dh, off, err
1195	}
1196	dh.Qdcount, off, err = unpackUint16(msg, off)
1197	if err != nil {
1198		return dh, off, err
1199	}
1200	dh.Ancount, off, err = unpackUint16(msg, off)
1201	if err != nil {
1202		return dh, off, err
1203	}
1204	dh.Nscount, off, err = unpackUint16(msg, off)
1205	if err != nil {
1206		return dh, off, err
1207	}
1208	dh.Arcount, off, err = unpackUint16(msg, off)
1209	if err != nil {
1210		return dh, off, err
1211	}
1212	return dh, off, nil
1213}
1214
1215// setHdr set the header in the dns using the binary data in dh.
1216func (dns *Msg) setHdr(dh Header) {
1217	dns.Id = dh.Id
1218	dns.Response = dh.Bits&_QR != 0
1219	dns.Opcode = int(dh.Bits>>11) & 0xF
1220	dns.Authoritative = dh.Bits&_AA != 0
1221	dns.Truncated = dh.Bits&_TC != 0
1222	dns.RecursionDesired = dh.Bits&_RD != 0
1223	dns.RecursionAvailable = dh.Bits&_RA != 0
1224	dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
1225	dns.AuthenticatedData = dh.Bits&_AD != 0
1226	dns.CheckingDisabled = dh.Bits&_CD != 0
1227	dns.Rcode = int(dh.Bits & 0xF)
1228}
1229