1// DNS packet assembly, see RFC 1035. Converting from - Unpack() -
2// and to - Pack() - wire format.
3// All the packers and unpackers take a (msg []byte, off int)
4// and return (off1 int, ok bool).  If they return ok==false, they
5// also return off1==len(msg), so that the next unpacker will
6// also fail.  This lets us avoid checks of ok until the end of a
7// packing sequence.
8
9package dns
10
11//go:generate go run msg_generate.go
12
13import (
14	"crypto/rand"
15	"encoding/binary"
16	"fmt"
17	"math/big"
18	"strconv"
19	"strings"
20)
21
const (
	maxCompressionOffset    = 2 << 13 // 2<<13 == 1<<14: we have 14 bits for the compression pointer, so offsets must stay below this value
	maxDomainNameWireOctets = 255     // Maximum wire length of a domain name, see RFC 1035 section 2.3.4

	// This is the maximum number of compression pointers that should occur in a
	// semantically valid message. Each label in a domain name must be at least one
	// octet and is separated by a period. The root label won't be represented by a
	// compression pointer to a compression pointer, hence the -2 to exclude the
	// smallest valid root label.
	//
	// It is possible to construct a valid message that has more compression pointers
	// than this, and still doesn't loop, by pointing to a previous pointer. This is
	// not something a well written implementation should ever do, so we leave them
	// to trip the maximum compression pointer check.
	//
	// With maxDomainNameWireOctets == 255 this evaluates to 126.
	maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2

	// This is the maximum length of a domain name in presentation format. The
	// maximum wire length of a domain name is 255 octets (see above), with the
	// maximum label length being 63. The wire format requires one extra byte over
	// the presentation format, reducing the number of octets by 1. Each label in
	// the name will be separated by a single period, with each octet in the label
	// expanding to at most 4 bytes (\DDD). If all other labels are of the maximum
	// length, then the final label can only be 61 octets long to not exceed the
	// maximum allowed wire length.
	//
	// The expression below evaluates to 1004 bytes; UnpackDomainName uses it as
	// the initial capacity of its output buffer.
	maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1
)
48
// Errors defined in this package. Each one is a distinct *Error value, so
// callers can compare a returned error against these variables.
var (
	ErrAlg           error = &Error{err: "bad algorithm"}                  // ErrAlg indicates an error with the (DNSSEC) algorithm.
	ErrAuth          error = &Error{err: "bad authentication"}             // ErrAuth indicates an error in the TSIG authentication.
	ErrBuf           error = &Error{err: "buffer size too small"}          // ErrBuf indicates that the buffer used is too small for the message.
	ErrConnEmpty     error = &Error{err: "conn has no connection"}         // ErrConnEmpty indicates a connection is being used before it is initialized.
	ErrExtendedRcode error = &Error{err: "bad extended rcode"}             // ErrExtendedRcode indicates that an rcode above 0xF was set on a message for which IsEdns0 returns nil.
	ErrFqdn          error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot.
	ErrId            error = &Error{err: "id mismatch"}                    // ErrId indicates there is a mismatch with the message's ID.
	ErrKeyAlg        error = &Error{err: "bad key algorithm"}              // ErrKeyAlg indicates that the algorithm in the key is not valid.
	ErrKey           error = &Error{err: "bad key"}                        // ErrKey indicates a bad key.
	ErrKeySize       error = &Error{err: "bad key size"}                   // ErrKeySize indicates a bad key size.
	ErrLongDomain    error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)} // ErrLongDomain indicates that a domain name is longer than maxDomainNameWireOctets on the wire.
	ErrNoSig         error = &Error{err: "no signature found"}             // ErrNoSig indicates that no signature was found.
	ErrPrivKey       error = &Error{err: "bad private key"}                // ErrPrivKey indicates a bad private key.
	ErrRcode         error = &Error{err: "bad rcode"}                      // ErrRcode indicates an rcode outside the range accepted when packing (0 through 0xFFF).
	ErrRdata         error = &Error{err: "bad rdata"}                      // ErrRdata indicates malformed rdata or a malformed domain name.
	ErrRRset         error = &Error{err: "bad rrset"}                      // ErrRRset indicates a bad RRset.
	ErrSecret        error = &Error{err: "no secrets defined"}             // ErrSecret indicates that no secrets are defined.
	ErrShortRead     error = &Error{err: "short read"}                     // ErrShortRead indicates a short read.
	ErrSig           error = &Error{err: "bad signature"}                  // ErrSig indicates that a signature can not be cryptographically validated.
	ErrSoa           error = &Error{err: "no SOA"}                         // ErrSoa indicates that no SOA RR was seen when doing zone transfers.
	ErrTime          error = &Error{err: "bad time"}                       // ErrTime indicates a timing error in TSIG authentication.
)
73
// Id is the function used to obtain a message id. By default it returns a
// 16-bit random number drawn from a cryptographically secure random number
// generator (crypto/rand, via id).
// Because Id is a variable, it can be reassigned to a custom function.
// For instance, to make it return a static value for testing:
//
//	dns.Id = func() uint16 { return 3 }
var Id = id
81
// id reads a big-endian uint16 from crypto/rand's Reader and returns it for
// use as a message id. It panics when the random source cannot be read.
func id() uint16 {
	var v uint16
	if err := binary.Read(rand.Reader, binary.BigEndian, &v); err != nil {
		panic("dns: reading random id failed: " + err.Error())
	}
	return v
}
92
// MsgHdr is a manually-unpacked version of (id, bits).
//
// When a Msg is packed, the boolean fields are folded into the header's flag
// bits, Opcode is shifted into place, and the low four bits of Rcode are
// stored in the header.
type MsgHdr struct {
	Id                 uint16 // message id
	Response           bool   // printed as "qr" by String
	Opcode             int    // printed via OpcodeToString by String
	Authoritative      bool   // printed as "aa" by String
	Truncated          bool   // printed as "tc" by String
	RecursionDesired   bool   // printed as "rd" by String
	RecursionAvailable bool   // printed as "ra" by String
	Zero               bool   // printed as "z" by String
	AuthenticatedData  bool   // printed as "ad" by String
	CheckingDisabled   bool   // printed as "cd" by String
	Rcode              int    // printed via RcodeToString by String; packing rejects values outside 0..0xFFF
}
107
// Msg contains the layout of a DNS message: the embedded header fields
// followed by the question, answer, authority and additional sections.
type Msg struct {
	MsgHdr
	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format.
	Question []Question // Holds the RR(s) of the question section.
	Answer   []RR       // Holds the RR(s) of the answer section.
	Ns       []RR       // Holds the RR(s) of the authority section.
	Extra    []RR       // Holds the RR(s) of the additional section.
}
117
// ClassToString maps each CLASS wire value to its string representation.
var ClassToString = map[uint16]string{
	ClassINET:   "IN",
	ClassCSNET:  "CS",
	ClassCHAOS:  "CH",
	ClassHESIOD: "HS",
	ClassNONE:   "NONE",
	ClassANY:    "ANY",
}
127
// OpcodeToString maps Opcodes to strings. MsgHdr.String uses it to render the
// opcode; an opcode that is not in the map yields the empty string there.
var OpcodeToString = map[int]string{
	OpcodeQuery:  "QUERY",
	OpcodeIQuery: "IQUERY",
	OpcodeStatus: "STATUS",
	OpcodeNotify: "NOTIFY",
	OpcodeUpdate: "UPDATE",
}
136
// RcodeToString maps Rcodes to strings. MsgHdr.String uses it to render the
// status; an rcode that is not in the map yields the empty string there.
var RcodeToString = map[int]string{
	RcodeSuccess:        "NOERROR",
	RcodeFormatError:    "FORMERR",
	RcodeServerFailure:  "SERVFAIL",
	RcodeNameError:      "NXDOMAIN",
	RcodeNotImplemented: "NOTIMP",
	RcodeRefused:        "REFUSED",
	RcodeYXDomain:       "YXDOMAIN", // See RFC 2136
	RcodeYXRrset:        "YXRRSET",
	RcodeNXRrset:        "NXRRSET",
	RcodeNotAuth:        "NOTAUTH",
	RcodeNotZone:        "NOTZONE",
	RcodeBadSig:         "BADSIG", // Also known as RcodeBadVers, see RFC 6891
	// The BADVERS entry is intentionally left disabled (see the note on
	// RcodeBadSig above):
	//	RcodeBadVers:        "BADVERS",
	RcodeBadKey:    "BADKEY",
	RcodeBadTime:   "BADTIME",
	RcodeBadMode:   "BADMODE",
	RcodeBadName:   "BADNAME",
	RcodeBadAlg:    "BADALG",
	RcodeBadTrunc:  "BADTRUNC",
	RcodeBadCookie: "BADCOOKIE",
}
160
// compressionMap wraps the two flavours of compression map so that internal
// packDomainName calls can use a more compact map without changing the
// signature or functionality of the public API.
//
// In particular, map[string]uint16 uses 25% less per-entry memory
// than does map[string]int.
//
// When both maps are set, the external one takes precedence.
type compressionMap struct {
	ext map[string]int    // external callers
	int map[string]uint16 // internal callers
}

// valid reports whether at least one of the underlying maps is non-nil.
func (m compressionMap) valid() bool {
	return !(m.int == nil && m.ext == nil)
}

// insert records pos as the offset for name s. The internal map stores the
// offset as a uint16, so pos is truncated to 16 bits on that path.
// It must only be called on a valid map.
func (m compressionMap) insert(s string, pos int) {
	if m.ext == nil {
		m.int[s] = uint16(pos)
		return
	}
	m.ext[s] = pos
}

// find looks up the offset recorded for name s and reports whether an entry
// exists.
func (m compressionMap) find(s string) (int, bool) {
	if m.ext == nil {
		off, found := m.int[s]
		return int(off), found
	}

	off, found := m.ext[s]
	return off, found
}
193
194// Domain names are a sequence of counted strings
195// split at the dots. They end with a zero-length string.
196
197// PackDomainName packs a domain name s into msg[off:].
198// If compression is wanted compress must be true and the compression
199// map needs to hold a mapping between domain names and offsets
200// pointing into msg.
201func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
202	return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
203}
204
// packDomainName packs the presentation-format domain name s into msg
// starting at off and returns the offset just past what was written.
//
// An empty s writes nothing and returns off unchanged. A name that is not
// fully qualified yields ErrFqdn; empty labels (two consecutive dots) and
// labels longer than 63 octets yield ErrRdata; running out of room in msg
// yields ErrBuf. Every error return uses len(msg) as the offset.
//
// When compression is valid, each suffix of the name that is not yet known
// is recorded in it at its offset (only offsets below maxCompressionOffset),
// regardless of compress. When compress is also true and a suffix is already
// known, the remainder of the name is replaced by a two-byte pointer to the
// recorded offset.
func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
	// XXX: A logical copy of this function exists in IsDomainName and
	// should be kept in sync with this function.

	ls := len(s)
	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
		return off, nil
	}

	// If not fully qualified, error out.
	if !IsFqdn(s) {
		return len(msg), ErrFqdn
	}

	// Each dot ends a segment of the name.
	// We trade each dot byte for a length byte.
	// Except for escaped dots (\.), which are normal dots.
	// There is also a trailing zero.

	// Compression
	pointer := -1 // offset to point at, or -1 when no compression pointer is emitted

	// Emit sequence of counted strings, chopping at dots.
	var (
		begin     int    // start of the current label in s (or bs once it exists)
		compBegin int    // start of the current label in the original, still escaped, s
		compOff   int    // number of bytes removed by unescaping so far
		bs        []byte // mutable copy of s, allocated lazily on the first backslash
		wasDot    bool   // whether the previous byte was an unescaped dot
	)
loop:
	for i := 0; i < ls; i++ {
		// Read from the unescaped copy once it exists; ls tracks its shrinking
		// logical length.
		var c byte
		if bs == nil {
			c = s[i]
		} else {
			c = bs[i]
		}

		switch c {
		case '\\':
			if off+1 > len(msg) {
				return len(msg), ErrBuf
			}

			if bs == nil {
				bs = []byte(s)
			}

			// check for \DDD
			if i+3 < ls && isDigit(bs[i+1]) && isDigit(bs[i+2]) && isDigit(bs[i+3]) {
				// Replace the four bytes "\DDD" by the single byte they encode
				// and shift the tail left by three.
				bs[i] = dddToByte(bs[i+1:])
				copy(bs[i+1:ls-3], bs[i+4:])
				ls -= 3
				compOff += 3
			} else {
				// Drop the backslash; the escaped byte moves into position i and
				// is then skipped by the loop increment, so it is taken literally.
				copy(bs[i:ls-1], bs[i+1:])
				ls--
				compOff++
			}

			wasDot = false
		case '.':
			if wasDot {
				// two dots back to back is not legal
				return len(msg), ErrRdata
			}
			wasDot = true

			labelLen := i - begin
			if labelLen >= 1<<6 { // top two bits of length must be clear
				return len(msg), ErrRdata
			}

			// off can already (we're in a loop) be bigger than len(msg)
			// this happens when a name isn't fully qualified
			if off+1+labelLen > len(msg) {
				return len(msg), ErrBuf
			}

			// Don't try to compress '.'
			// We should only compress when compress is true, but we should also still pick
			// up names that can be used for *future* compression(s).
			if compression.valid() && !isRootLabel(s, bs, begin, ls) {
				// The map is keyed by the escaped suffix of the original string,
				// hence s[compBegin:] rather than the unescaped bytes.
				if p, ok := compression.find(s[compBegin:]); ok {
					// The first hit is the longest matching dname
					// keep the pointer offset we get back and store
					// the offset of the current name, because that's
					// where we need to insert the pointer later

					// If compress is true, we're allowed to compress this dname
					if compress {
						pointer = p // Where to point to
						break loop
					}
				} else if off < maxCompressionOffset {
					// Only offsets smaller than maxCompressionOffset can be used.
					compression.insert(s[compBegin:], off)
				}
			}

			// The following is covered by the length check above.
			msg[off] = byte(labelLen)

			if bs == nil {
				copy(msg[off+1:], s[begin:i])
			} else {
				copy(msg[off+1:], bs[begin:i])
			}
			off += 1 + labelLen

			begin = i + 1
			compBegin = begin + compOff
		default:
			wasDot = false
		}
	}

	// Root label is special
	if isRootLabel(s, bs, 0, ls) {
		return off, nil
	}

	// If we did compression and we find something add the pointer here
	if pointer != -1 {
		// We have two bytes (14 bits) to put the pointer in
		// NOTE(review): the length check in the loop guarantees room for
		// 1+labelLen bytes at off, which covers these two bytes only when the
		// label being replaced is non-empty - confirm that an empty leading
		// label cannot reach this point.
		binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
		return off + 2, nil
	}

	// NOTE(review): when off == len(msg) the terminating zero is not written,
	// yet off+1 is still returned with a nil error - confirm callers detect
	// this.
	if off < len(msg) {
		msg[off] = 0
	}

	return off + 1, nil
}
341
// isRootLabel reports whether the range [off, end) of the name is exactly
// the root label ".".
//
// The range is taken from bs when bs is non-nil, and from s otherwise.
func isRootLabel(s string, bs []byte, off, end int) bool {
	if bs != nil {
		if end-off != 1 {
			return false
		}
		return bs[off] == '.'
	}

	return s[off:end] == "."
}
353
354// Unpack a domain name.
355// In addition to the simple sequences of counted strings above,
356// domain names are allowed to refer to strings elsewhere in the
357// packet, to avoid repeating common suffixes when returning
358// many entries in a single domain.  The pointers are marked
359// by a length byte with the top two bits set.  Ignoring those
360// two bits, that byte and the next give a 14 bit offset from msg[0]
361// where we should pick up the trail.
362// Note that if we jump elsewhere in the packet,
363// we return off1 == the offset after the first pointer we found,
364// which is where the next record will start.
365// In theory, the pointers are only allowed to jump backward.
366// We let them jump anywhere and stop jumping after a while.
367
368// UnpackDomainName unpacks a domain name into a string. It returns
369// the name, the new offset into msg and any error that occurred.
370//
371// When an error is encountered, the unpacked name will be discarded
372// and len(msg) will be returned as the offset.
373func UnpackDomainName(msg []byte, off int) (string, int, error) {
374	s := make([]byte, 0, maxDomainNamePresentationLength)
375	off1 := 0
376	lenmsg := len(msg)
377	budget := maxDomainNameWireOctets
378	ptr := 0 // number of pointers followed
379Loop:
380	for {
381		if off >= lenmsg {
382			return "", lenmsg, ErrBuf
383		}
384		c := int(msg[off])
385		off++
386		switch c & 0xC0 {
387		case 0x00:
388			if c == 0x00 {
389				// end of name
390				break Loop
391			}
392			// literal string
393			if off+c > lenmsg {
394				return "", lenmsg, ErrBuf
395			}
396			budget -= c + 1 // +1 for the label separator
397			if budget <= 0 {
398				return "", lenmsg, ErrLongDomain
399			}
400			for _, b := range msg[off : off+c] {
401				switch b {
402				case '.', '(', ')', ';', ' ', '@':
403					fallthrough
404				case '"', '\\':
405					s = append(s, '\\', b)
406				default:
407					if b < ' ' || b > '~' { // unprintable, use \DDD
408						s = append(s, escapeByte(b)...)
409					} else {
410						s = append(s, b)
411					}
412				}
413			}
414			s = append(s, '.')
415			off += c
416		case 0xC0:
417			// pointer to somewhere else in msg.
418			// remember location after first ptr,
419			// since that's how many bytes we consumed.
420			// also, don't follow too many pointers --
421			// maybe there's a loop.
422			if off >= lenmsg {
423				return "", lenmsg, ErrBuf
424			}
425			c1 := msg[off]
426			off++
427			if ptr == 0 {
428				off1 = off
429			}
430			if ptr++; ptr > maxCompressionPointers {
431				return "", lenmsg, &Error{err: "too many compression pointers"}
432			}
433			// pointer should guarantee that it advances and points forwards at least
434			// but the condition on previous three lines guarantees that it's
435			// at least loop-free
436			off = (c^0xC0)<<8 | int(c1)
437		default:
438			// 0x80 and 0x40 are reserved
439			return "", lenmsg, ErrRdata
440		}
441	}
442	if ptr == 0 {
443		off1 = off
444	}
445	if len(s) == 0 {
446		return ".", off1, nil
447	}
448	return string(s), off1, nil
449}
450
451func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
452	if len(txt) == 0 {
453		if offset >= len(msg) {
454			return offset, ErrBuf
455		}
456		msg[offset] = 0
457		return offset, nil
458	}
459	var err error
460	for _, s := range txt {
461		if len(s) > len(tmp) {
462			return offset, ErrBuf
463		}
464		offset, err = packTxtString(s, msg, offset, tmp)
465		if err != nil {
466			return offset, err
467		}
468	}
469	return offset, nil
470}
471
472func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
473	lenByteOffset := offset
474	if offset >= len(msg) || len(s) > len(tmp) {
475		return offset, ErrBuf
476	}
477	offset++
478	bs := tmp[:len(s)]
479	copy(bs, s)
480	for i := 0; i < len(bs); i++ {
481		if len(msg) <= offset {
482			return offset, ErrBuf
483		}
484		if bs[i] == '\\' {
485			i++
486			if i == len(bs) {
487				break
488			}
489			// check for \DDD
490			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
491				msg[offset] = dddToByte(bs[i:])
492				i += 2
493			} else {
494				msg[offset] = bs[i]
495			}
496		} else {
497			msg[offset] = bs[i]
498		}
499		offset++
500	}
501	l := offset - lenByteOffset - 1
502	if l > 255 {
503		return offset, &Error{err: "string exceeded 255 bytes in txt"}
504	}
505	msg[lenByteOffset] = byte(l)
506	return offset, nil
507}
508
509func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) {
510	if offset >= len(msg) || len(s) > len(tmp) {
511		return offset, ErrBuf
512	}
513	bs := tmp[:len(s)]
514	copy(bs, s)
515	for i := 0; i < len(bs); i++ {
516		if len(msg) <= offset {
517			return offset, ErrBuf
518		}
519		if bs[i] == '\\' {
520			i++
521			if i == len(bs) {
522				break
523			}
524			// check for \DDD
525			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
526				msg[offset] = dddToByte(bs[i:])
527				i += 2
528			} else {
529				msg[offset] = bs[i]
530			}
531		} else {
532			msg[offset] = bs[i]
533		}
534		offset++
535	}
536	return offset, nil
537}
538
539func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
540	off = off0
541	var s string
542	for off < len(msg) && err == nil {
543		s, off, err = unpackString(msg, off)
544		if err == nil {
545			ss = append(ss, s)
546		}
547	}
548	return
549}
550
// Helpers for dealing with escaped bytes

// isDigit reports whether b is an ASCII decimal digit.
func isDigit(b byte) bool {
	return '0' <= b && b <= '9'
}
553
// dddToByte converts the three ASCII digits at the start of s into the byte
// they denote. All arithmetic is done on bytes, so values above 255 wrap
// around. It panics when s holds fewer than three bytes.
func dddToByte(s []byte) byte {
	_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
	hundreds := (s[0] - '0') * 100
	tens := (s[1] - '0') * 10
	ones := s[2] - '0'
	return hundreds + tens + ones
}
558
// dddStringToByte converts the three ASCII digits at the start of s into the
// byte they denote. All arithmetic is done on bytes, so values above 255 wrap
// around. It panics when s holds fewer than three bytes.
func dddStringToByte(s string) byte {
	_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
	hundreds := (s[0] - '0') * 100
	tens := (s[1] - '0') * 10
	ones := s[2] - '0'
	return hundreds + tens + ones
}
563
// intToBytes returns the big-endian bytes of the absolute value of i,
// left-padded with zero bytes to length. When the value needs length bytes
// or more, its bytes are returned as they are, without truncation.
func intToBytes(i *big.Int, length int) []byte {
	raw := i.Bytes()
	pad := length - len(raw)
	if pad <= 0 {
		return raw
	}

	padded := make([]byte, length)
	copy(padded[pad:], raw)
	return padded
}
574
575// PackRR packs a resource record rr into msg[off:].
576// See PackDomainName for documentation about the compression.
577func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
578	headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
579	if err == nil {
580		// packRR no longer sets the Rdlength field on the rr, but
581		// callers might be expecting it so we set it here.
582		rr.Header().Rdlength = uint16(off1 - headerEnd)
583	}
584	return off1, err
585}
586
587func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
588	if rr == nil {
589		return len(msg), len(msg), &Error{err: "nil rr"}
590	}
591
592	headerEnd, err = rr.Header().packHeader(msg, off, compression, compress)
593	if err != nil {
594		return headerEnd, len(msg), err
595	}
596
597	off1, err = rr.pack(msg, headerEnd, compression, compress)
598	if err != nil {
599		return headerEnd, len(msg), err
600	}
601
602	rdlength := off1 - headerEnd
603	if int(uint16(rdlength)) != rdlength { // overflow
604		return headerEnd, len(msg), ErrRdata
605	}
606
607	// The RDLENGTH field is the last field in the header and we set it here.
608	binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
609	return headerEnd, off1, nil
610}
611
612// UnpackRR unpacks msg[off:] into an RR.
613func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
614	h, off, msg, err := unpackHeader(msg, off)
615	if err != nil {
616		return nil, len(msg), err
617	}
618
619	return UnpackRRWithHeader(h, msg, off)
620}
621
622// UnpackRRWithHeader unpacks the record type specific payload given an existing
623// RR_Header.
624func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
625	if newFn, ok := TypeToRR[h.Rrtype]; ok {
626		rr = newFn()
627		*rr.Header() = h
628	} else {
629		rr = &RFC3597{Hdr: h}
630	}
631
632	if noRdata(h) {
633		return rr, off, nil
634	}
635
636	end := off + int(h.Rdlength)
637
638	off, err = rr.unpack(msg, off)
639	if err != nil {
640		return nil, end, err
641	}
642	if off != end {
643		return &h, end, &Error{err: "bad rdlength"}
644	}
645
646	return rr, off, nil
647}
648
649// unpackRRslice unpacks msg[off:] into an []RR.
650// If we cannot unpack the whole array, then it will return nil
651func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
652	var r RR
653	// Don't pre-allocate, l may be under attacker control
654	var dst []RR
655	for i := 0; i < l; i++ {
656		off1 := off
657		r, off, err = UnpackRR(msg, off)
658		if err != nil {
659			off = len(msg)
660			break
661		}
662		// If offset does not increase anymore, l is a lie
663		if off1 == off {
664			l = i
665			break
666		}
667		dst = append(dst, r)
668	}
669	if err != nil && off == len(msg) {
670		dst = nil
671	}
672	return dst, off, err
673}
674
675// Convert a MsgHdr to a string, with dig-like headers:
676//
677//;; opcode: QUERY, status: NOERROR, id: 48404
678//
679//;; flags: qr aa rd ra;
680func (h *MsgHdr) String() string {
681	if h == nil {
682		return "<nil> MsgHdr"
683	}
684
685	s := ";; opcode: " + OpcodeToString[h.Opcode]
686	s += ", status: " + RcodeToString[h.Rcode]
687	s += ", id: " + strconv.Itoa(int(h.Id)) + "\n"
688
689	s += ";; flags:"
690	if h.Response {
691		s += " qr"
692	}
693	if h.Authoritative {
694		s += " aa"
695	}
696	if h.Truncated {
697		s += " tc"
698	}
699	if h.RecursionDesired {
700		s += " rd"
701	}
702	if h.RecursionAvailable {
703		s += " ra"
704	}
705	if h.Zero { // Hmm
706		s += " z"
707	}
708	if h.AuthenticatedData {
709		s += " ad"
710	}
711	if h.CheckingDisabled {
712		s += " cd"
713	}
714
715	s += ";"
716	return s
717}
718
719// Pack packs a Msg: it is converted to to wire format.
720// If the dns.Compress is true the message will be in compressed wire format.
721func (dns *Msg) Pack() (msg []byte, err error) {
722	return dns.PackBuffer(nil)
723}
724
725// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
726func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
727	// If this message can't be compressed, avoid filling the
728	// compression map and creating garbage.
729	if dns.Compress && dns.isCompressible() {
730		compression := make(map[string]uint16) // Compression pointer mappings.
731		return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
732	}
733
734	return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
735}
736
737// packBufferWithCompressionMap packs a Msg, using the given buffer buf.
738func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) {
739	if dns.Rcode < 0 || dns.Rcode > 0xFFF {
740		return nil, ErrRcode
741	}
742
743	// Set extended rcode unconditionally if we have an opt, this will allow
744	// reseting the extended rcode bits if they need to.
745	if opt := dns.IsEdns0(); opt != nil {
746		opt.SetExtendedRcode(uint16(dns.Rcode))
747	} else if dns.Rcode > 0xF {
748		// If Rcode is an extended one and opt is nil, error out.
749		return nil, ErrExtendedRcode
750	}
751
752	// Convert convenient Msg into wire-like Header.
753	var dh Header
754	dh.Id = dns.Id
755	dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
756	if dns.Response {
757		dh.Bits |= _QR
758	}
759	if dns.Authoritative {
760		dh.Bits |= _AA
761	}
762	if dns.Truncated {
763		dh.Bits |= _TC
764	}
765	if dns.RecursionDesired {
766		dh.Bits |= _RD
767	}
768	if dns.RecursionAvailable {
769		dh.Bits |= _RA
770	}
771	if dns.Zero {
772		dh.Bits |= _Z
773	}
774	if dns.AuthenticatedData {
775		dh.Bits |= _AD
776	}
777	if dns.CheckingDisabled {
778		dh.Bits |= _CD
779	}
780
781	dh.Qdcount = uint16(len(dns.Question))
782	dh.Ancount = uint16(len(dns.Answer))
783	dh.Nscount = uint16(len(dns.Ns))
784	dh.Arcount = uint16(len(dns.Extra))
785
786	// We need the uncompressed length here, because we first pack it and then compress it.
787	msg = buf
788	uncompressedLen := msgLenWithCompressionMap(dns, nil)
789	if packLen := uncompressedLen + 1; len(msg) < packLen {
790		msg = make([]byte, packLen)
791	}
792
793	// Pack it in: header and then the pieces.
794	off := 0
795	off, err = dh.pack(msg, off, compression, compress)
796	if err != nil {
797		return nil, err
798	}
799	for _, r := range dns.Question {
800		off, err = r.pack(msg, off, compression, compress)
801		if err != nil {
802			return nil, err
803		}
804	}
805	for _, r := range dns.Answer {
806		_, off, err = packRR(r, msg, off, compression, compress)
807		if err != nil {
808			return nil, err
809		}
810	}
811	for _, r := range dns.Ns {
812		_, off, err = packRR(r, msg, off, compression, compress)
813		if err != nil {
814			return nil, err
815		}
816	}
817	for _, r := range dns.Extra {
818		_, off, err = packRR(r, msg, off, compression, compress)
819		if err != nil {
820			return nil, err
821		}
822	}
823	return msg[:off], nil
824}
825
// unpack fills the sections of dns from msg, starting at off, using the
// section counts found in the already unpacked header dh.
//
// dh is passed by value, so the count corrections made below only affect the
// local copy. An error from unpacking a question is returned immediately;
// an error from one of the record sections skips the remaining sections and
// is returned at the end, after the extended rcode has been merged in.
func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
	// If we are at the end of the message we should return *just* the
	// header. This can still be useful to the caller. 9.9.9.9 sends these
	// when responding with REFUSED for instance.
	if off == len(msg) {
		// reset sections before returning
		dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil
		return nil
	}

	// Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are
	// attacker controlled. This means we can't use them to pre-allocate
	// slices.
	dns.Question = nil
	for i := 0; i < int(dh.Qdcount); i++ {
		off1 := off
		var q Question
		q, off, err = unpackQuestion(msg, off)
		if err != nil {
			return err
		}
		if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
			dh.Qdcount = uint16(i)
			break
		}
		dns.Question = append(dns.Question, q)
	}

	dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
	// The header counts might have been wrong so we need to update it
	dh.Ancount = uint16(len(dns.Answer))
	// Ns and Extra are only assigned when the preceding section unpacked
	// without error; otherwise they keep whatever dns held before the call.
	if err == nil {
		dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
	}
	// The header counts might have been wrong so we need to update it
	dh.Nscount = uint16(len(dns.Ns))
	if err == nil {
		dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off)
	}
	// The header counts might have been wrong so we need to update it
	dh.Arcount = uint16(len(dns.Extra))

	// Set extended Rcode
	if opt := dns.IsEdns0(); opt != nil {
		dns.Rcode |= opt.ExtendedRcode()
	}

	// Trailing bytes after the last section are currently ignored.
	if off != len(msg) {
		// TODO(miek) make this an error?
		// use PackOpt to let people tell how detailed the error reporting should be?
		// println("dns: extra bytes in dns packet", off, "<", len(msg))
	}
	return err

}
881
882// Unpack unpacks a binary message to a Msg structure.
883func (dns *Msg) Unpack(msg []byte) (err error) {
884	dh, off, err := unpackMsgHdr(msg, 0)
885	if err != nil {
886		return err
887	}
888
889	dns.setHdr(dh)
890	return dns.unpack(dh, msg, off)
891}
892
893// Convert a complete message to a string with dig-like output.
894func (dns *Msg) String() string {
895	if dns == nil {
896		return "<nil> MsgHdr"
897	}
898	s := dns.MsgHdr.String() + " "
899	s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", "
900	s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
901	s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
902	s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
903	if len(dns.Question) > 0 {
904		s += "\n;; QUESTION SECTION:\n"
905		for _, r := range dns.Question {
906			s += r.String() + "\n"
907		}
908	}
909	if len(dns.Answer) > 0 {
910		s += "\n;; ANSWER SECTION:\n"
911		for _, r := range dns.Answer {
912			if r != nil {
913				s += r.String() + "\n"
914			}
915		}
916	}
917	if len(dns.Ns) > 0 {
918		s += "\n;; AUTHORITY SECTION:\n"
919		for _, r := range dns.Ns {
920			if r != nil {
921				s += r.String() + "\n"
922			}
923		}
924	}
925	if len(dns.Extra) > 0 {
926		s += "\n;; ADDITIONAL SECTION:\n"
927		for _, r := range dns.Extra {
928			if r != nil {
929				s += r.String() + "\n"
930			}
931		}
932	}
933	return s
934}
935
936// isCompressible returns whether the msg may be compressible.
937func (dns *Msg) isCompressible() bool {
938	// If we only have one question, there is nothing we can ever compress.
939	return len(dns.Question) > 1 || len(dns.Answer) > 0 ||
940		len(dns.Ns) > 0 || len(dns.Extra) > 0
941}
942
943// Len returns the message length when in (un)compressed wire format.
944// If dns.Compress is true compression it is taken into account. Len()
945// is provided to be a faster way to get the size of the resulting packet,
946// than packing it, measuring the size and discarding the buffer.
947func (dns *Msg) Len() int {
948	// If this message can't be compressed, avoid filling the
949	// compression map and creating garbage.
950	if dns.Compress && dns.isCompressible() {
951		compression := make(map[string]struct{})
952		return msgLenWithCompressionMap(dns, compression)
953	}
954
955	return msgLenWithCompressionMap(dns, nil)
956}
957
958func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
959	l := headerSize
960
961	for _, r := range dns.Question {
962		l += r.len(l, compression)
963	}
964	for _, r := range dns.Answer {
965		if r != nil {
966			l += r.len(l, compression)
967		}
968	}
969	for _, r := range dns.Ns {
970		if r != nil {
971			l += r.len(l, compression)
972		}
973	}
974	for _, r := range dns.Extra {
975		if r != nil {
976			l += r.len(l, compression)
977		}
978	}
979
980	return l
981}
982
983func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int {
984	if s == "" || s == "." {
985		return 1
986	}
987
988	escaped := strings.Contains(s, "\\")
989
990	if compression != nil && (compress || off < maxCompressionOffset) {
991		// compressionLenSearch will insert the entry into the compression
992		// map if it doesn't contain it.
993		if l, ok := compressionLenSearch(compression, s, off); ok && compress {
994			if escaped {
995				return escapedNameLen(s[:l]) + 2
996			}
997
998			return l + 2
999		}
1000	}
1001
1002	if escaped {
1003		return escapedNameLen(s) + 1
1004	}
1005
1006	return len(s) + 1
1007}
1008
1009func escapedNameLen(s string) int {
1010	nameLen := len(s)
1011	for i := 0; i < len(s); i++ {
1012		if s[i] != '\\' {
1013			continue
1014		}
1015
1016		if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) {
1017			nameLen -= 3
1018			i += 3
1019		} else {
1020			nameLen--
1021			i++
1022		}
1023	}
1024
1025	return nameLen
1026}
1027
1028func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) {
1029	for off, end := 0, false; !end; off, end = NextLabel(s, off) {
1030		if _, ok := c[s[off:]]; ok {
1031			return off, true
1032		}
1033
1034		if msgOff+off < maxCompressionOffset {
1035			c[s[off:]] = struct{}{}
1036		}
1037	}
1038
1039	return 0, false
1040}
1041
1042// Copy returns a new RR which is a deep-copy of r.
1043func Copy(r RR) RR { return r.copy() }
1044
1045// Len returns the length (in octets) of the uncompressed RR in wire format.
1046func Len(r RR) int { return r.len(0, nil) }
1047
1048// Copy returns a new *Msg which is a deep-copy of dns.
1049func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
1050
1051// CopyTo copies the contents to the provided message using a deep-copy and returns the copy.
1052func (dns *Msg) CopyTo(r1 *Msg) *Msg {
1053	r1.MsgHdr = dns.MsgHdr
1054	r1.Compress = dns.Compress
1055
1056	if len(dns.Question) > 0 {
1057		r1.Question = make([]Question, len(dns.Question))
1058		copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy
1059	}
1060
1061	rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
1062	r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):]
1063	r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):]
1064	r1.Extra = rrArr[:0:len(dns.Extra)]
1065
1066	for _, r := range dns.Answer {
1067		r1.Answer = append(r1.Answer, r.copy())
1068	}
1069
1070	for _, r := range dns.Ns {
1071		r1.Ns = append(r1.Ns, r.copy())
1072	}
1073
1074	for _, r := range dns.Extra {
1075		r1.Extra = append(r1.Extra, r.copy())
1076	}
1077
1078	return r1
1079}
1080
1081func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
1082	off, err := packDomainName(q.Name, msg, off, compression, compress)
1083	if err != nil {
1084		return off, err
1085	}
1086	off, err = packUint16(q.Qtype, msg, off)
1087	if err != nil {
1088		return off, err
1089	}
1090	off, err = packUint16(q.Qclass, msg, off)
1091	if err != nil {
1092		return off, err
1093	}
1094	return off, nil
1095}
1096
1097func unpackQuestion(msg []byte, off int) (Question, int, error) {
1098	var (
1099		q   Question
1100		err error
1101	)
1102	q.Name, off, err = UnpackDomainName(msg, off)
1103	if err != nil {
1104		return q, off, err
1105	}
1106	if off == len(msg) {
1107		return q, off, nil
1108	}
1109	q.Qtype, off, err = unpackUint16(msg, off)
1110	if err != nil {
1111		return q, off, err
1112	}
1113	if off == len(msg) {
1114		return q, off, nil
1115	}
1116	q.Qclass, off, err = unpackUint16(msg, off)
1117	if off == len(msg) {
1118		return q, off, nil
1119	}
1120	return q, off, err
1121}
1122
1123func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
1124	off, err := packUint16(dh.Id, msg, off)
1125	if err != nil {
1126		return off, err
1127	}
1128	off, err = packUint16(dh.Bits, msg, off)
1129	if err != nil {
1130		return off, err
1131	}
1132	off, err = packUint16(dh.Qdcount, msg, off)
1133	if err != nil {
1134		return off, err
1135	}
1136	off, err = packUint16(dh.Ancount, msg, off)
1137	if err != nil {
1138		return off, err
1139	}
1140	off, err = packUint16(dh.Nscount, msg, off)
1141	if err != nil {
1142		return off, err
1143	}
1144	off, err = packUint16(dh.Arcount, msg, off)
1145	if err != nil {
1146		return off, err
1147	}
1148	return off, nil
1149}
1150
1151func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
1152	var (
1153		dh  Header
1154		err error
1155	)
1156	dh.Id, off, err = unpackUint16(msg, off)
1157	if err != nil {
1158		return dh, off, err
1159	}
1160	dh.Bits, off, err = unpackUint16(msg, off)
1161	if err != nil {
1162		return dh, off, err
1163	}
1164	dh.Qdcount, off, err = unpackUint16(msg, off)
1165	if err != nil {
1166		return dh, off, err
1167	}
1168	dh.Ancount, off, err = unpackUint16(msg, off)
1169	if err != nil {
1170		return dh, off, err
1171	}
1172	dh.Nscount, off, err = unpackUint16(msg, off)
1173	if err != nil {
1174		return dh, off, err
1175	}
1176	dh.Arcount, off, err = unpackUint16(msg, off)
1177	if err != nil {
1178		return dh, off, err
1179	}
1180	return dh, off, nil
1181}
1182
1183// setHdr set the header in the dns using the binary data in dh.
1184func (dns *Msg) setHdr(dh Header) {
1185	dns.Id = dh.Id
1186	dns.Response = dh.Bits&_QR != 0
1187	dns.Opcode = int(dh.Bits>>11) & 0xF
1188	dns.Authoritative = dh.Bits&_AA != 0
1189	dns.Truncated = dh.Bits&_TC != 0
1190	dns.RecursionDesired = dh.Bits&_RD != 0
1191	dns.RecursionAvailable = dh.Bits&_RA != 0
1192	dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
1193	dns.AuthenticatedData = dh.Bits&_AD != 0
1194	dns.CheckingDisabled = dh.Bits&_CD != 0
1195	dns.Rcode = int(dh.Bits & 0xF)
1196}
1197