// DNS packet assembly, see RFC 1035: converting messages from wire format
// (Unpack) and to wire format (Pack).
// The packers and unpackers generally take a (msg []byte, off int) pair and
// return (off1 int, err error). When they fail they usually also return
// off1 == len(msg), so that the next unpacker will fail as well; this lets a
// packing sequence postpone its error checks until the end.
8
9package dns
10
11//go:generate go run msg_generate.go
12//go:generate go run compress_generate.go
13
14import (
15	crand "crypto/rand"
16	"encoding/binary"
17	"fmt"
18	"math/big"
19	"math/rand"
20	"strconv"
21	"sync"
22)
23
const (
	// maxCompressionOffset is the first message offset a compression pointer
	// can no longer reach, because pointers carry a 14-bit offset.
	maxCompressionOffset = 1 << 14

	// maxDomainNameWireOctets is the maximum length of a domain name in wire
	// format, see RFC 1035 section 2.3.4.
	maxDomainNameWireOctets = 255
)
28
29// Errors defined in this package.
30var (
31	ErrAlg           error = &Error{err: "bad algorithm"}                  // ErrAlg indicates an error with the (DNSSEC) algorithm.
32	ErrAuth          error = &Error{err: "bad authentication"}             // ErrAuth indicates an error in the TSIG authentication.
33	ErrBuf           error = &Error{err: "buffer size too small"}          // ErrBuf indicates that the buffer used is too small for the message.
34	ErrConnEmpty     error = &Error{err: "conn has no connection"}         // ErrConnEmpty indicates a connection is being used before it is initialized.
35	ErrExtendedRcode error = &Error{err: "bad extended rcode"}             // ErrExtendedRcode ...
36	ErrFqdn          error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot.
37	ErrId            error = &Error{err: "id mismatch"}                    // ErrId indicates there is a mismatch with the message's ID.
38	ErrKeyAlg        error = &Error{err: "bad key algorithm"}              // ErrKeyAlg indicates that the algorithm in the key is not valid.
39	ErrKey           error = &Error{err: "bad key"}
40	ErrKeySize       error = &Error{err: "bad key size"}
41	ErrLongDomain    error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)}
42	ErrNoSig         error = &Error{err: "no signature found"}
43	ErrPrivKey       error = &Error{err: "bad private key"}
44	ErrRcode         error = &Error{err: "bad rcode"}
45	ErrRdata         error = &Error{err: "bad rdata"}
46	ErrRRset         error = &Error{err: "bad rrset"}
47	ErrSecret        error = &Error{err: "no secrets defined"}
48	ErrShortRead     error = &Error{err: "short read"}
49	ErrSig           error = &Error{err: "bad signature"}                      // ErrSig indicates that a signature can not be cryptographically validated.
50	ErrSoa           error = &Error{err: "no SOA"}                             // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
51	ErrTime          error = &Error{err: "bad time"}                           // ErrTime indicates a timing error in TSIG authentication.
52	ErrTruncated     error = &Error{err: "failed to unpack truncated message"} // ErrTruncated indicates that we failed to unpack a truncated message. We unpacked as much as we had so Msg can still be used, if desired.
53)
54
// Id returns a 16-bit random number to be used as a message id. The default
// implementation draws from a lazily seeded math/rand source, which should
// be good enough. Because Id is a variable it can be replaced by a custom
// function, for instance one that always returns the same value:
//
//	dns.Id = func() uint16 { return 3 }
var Id = id

var (
	idLock sync.Mutex // guards idRand
	idRand *rand.Rand // created by the first call to id
)

// id is the default implementation of Id. It is safe for concurrent use.
func id() uint16 {
	idLock.Lock()
	defer idLock.Unlock()

	if idRand == nil {
		// Creating the source on first use instead of at package init
		// (partially) works around https://github.com/golang/go/issues/11833.
		// Prefer a seed from crypto/rand; fall back to the global math/rand
		// source only if that fails.
		var raw [8]byte
		var seed int64
		if _, err := crand.Read(raw[:]); err != nil {
			seed = rand.Int63()
		} else {
			seed = int64(binary.LittleEndian.Uint64(raw[:]))
		}
		idRand = rand.New(rand.NewSource(seed))
	}

	// *rand.Rand is not safe for concurrent use, so it is only touched while
	// idLock is held. This costs no more than using the global math/rand
	// functions, which are internally protected by a mutex as well.
	return uint16(idRand.Uint32())
}
103
// MsgHdr is a manually-unpacked version of (id, bits): the fixed header
// fields of a DNS message. The four section counts are not stored here;
// Pack computes them from the lengths of the Msg sections.
type MsgHdr struct {
	Id                 uint16 // message ID
	Response           bool   // QR bit
	Opcode             int    // OPCODE, a 4-bit field in the header (see OpcodeToString)
	Authoritative      bool   // AA bit
	Truncated          bool   // TC bit
	RecursionDesired   bool   // RD bit
	RecursionAvailable bool   // RA bit
	Zero               bool   // Z bit
	AuthenticatedData  bool   // AD bit
	CheckingDisabled   bool   // CD bit
	Rcode              int    // RCODE; only the low 4 bits fit in the header, Pack stores bits above that (up to 0xFFF) in the OPT RR
}
118
// Msg contains the layout of a DNS message: the header followed by the
// question, answer, authority and additional sections.
type Msg struct {
	MsgHdr
	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format. Not encoded to JSON.
	Question []Question // Holds the entries of the question section.
	Answer   []RR       // Holds the RR(s) of the answer section.
	Ns       []RR       // Holds the RR(s) of the authority section.
	Extra    []RR       // Holds the RR(s) of the additional section.
}
128
129// ClassToString is a maps Classes to strings for each CLASS wire type.
130var ClassToString = map[uint16]string{
131	ClassINET:   "IN",
132	ClassCSNET:  "CS",
133	ClassCHAOS:  "CH",
134	ClassHESIOD: "HS",
135	ClassNONE:   "NONE",
136	ClassANY:    "ANY",
137}
138
139// OpcodeToString maps Opcodes to strings.
140var OpcodeToString = map[int]string{
141	OpcodeQuery:  "QUERY",
142	OpcodeIQuery: "IQUERY",
143	OpcodeStatus: "STATUS",
144	OpcodeNotify: "NOTIFY",
145	OpcodeUpdate: "UPDATE",
146}
147
148// RcodeToString maps Rcodes to strings.
149var RcodeToString = map[int]string{
150	RcodeSuccess:        "NOERROR",
151	RcodeFormatError:    "FORMERR",
152	RcodeServerFailure:  "SERVFAIL",
153	RcodeNameError:      "NXDOMAIN",
154	RcodeNotImplemented: "NOTIMPL",
155	RcodeRefused:        "REFUSED",
156	RcodeYXDomain:       "YXDOMAIN", // See RFC 2136
157	RcodeYXRrset:        "YXRRSET",
158	RcodeNXRrset:        "NXRRSET",
159	RcodeNotAuth:        "NOTAUTH",
160	RcodeNotZone:        "NOTZONE",
161	RcodeBadSig:         "BADSIG", // Also known as RcodeBadVers, see RFC 6891
162	//	RcodeBadVers:        "BADVERS",
163	RcodeBadKey:    "BADKEY",
164	RcodeBadTime:   "BADTIME",
165	RcodeBadMode:   "BADMODE",
166	RcodeBadName:   "BADNAME",
167	RcodeBadAlg:    "BADALG",
168	RcodeBadTrunc:  "BADTRUNC",
169	RcodeBadCookie: "BADCOOKIE",
170}
171
172// Domain names are a sequence of counted strings
173// split at the dots. They end with a zero-length string.
174
175// PackDomainName packs a domain name s into msg[off:].
176// If compression is wanted compress must be true and the compression
177// map needs to hold a mapping between domain names and offsets
178// pointing into msg.
179func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
180	off1, _, err = packDomainName(s, msg, off, compression, compress)
181	return
182}
183
184func packDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, labels int, err error) {
185	// special case if msg == nil
186	lenmsg := 256
187	if msg != nil {
188		lenmsg = len(msg)
189	}
190	ls := len(s)
191	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
192		return off, 0, nil
193	}
194	// If not fully qualified, error out, but only if msg == nil #ugly
195	switch {
196	case msg == nil:
197		if s[ls-1] != '.' {
198			s += "."
199			ls++
200		}
201	case msg != nil:
202		if s[ls-1] != '.' {
203			return lenmsg, 0, ErrFqdn
204		}
205	}
206	// Each dot ends a segment of the name.
207	// We trade each dot byte for a length byte.
208	// Except for escaped dots (\.), which are normal dots.
209	// There is also a trailing zero.
210
211	// Compression
212	nameoffset := -1
213	pointer := -1
214	// Emit sequence of counted strings, chopping at dots.
215	begin := 0
216	bs := []byte(s)
217	roBs, bsFresh, escapedDot := s, true, false
218	for i := 0; i < ls; i++ {
219		if bs[i] == '\\' {
220			for j := i; j < ls-1; j++ {
221				bs[j] = bs[j+1]
222			}
223			ls--
224			if off+1 > lenmsg {
225				return lenmsg, labels, ErrBuf
226			}
227			// check for \DDD
228			if i+2 < ls && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
229				bs[i] = dddToByte(bs[i:])
230				for j := i + 1; j < ls-2; j++ {
231					bs[j] = bs[j+2]
232				}
233				ls -= 2
234			}
235			escapedDot = bs[i] == '.'
236			bsFresh = false
237			continue
238		}
239
240		if bs[i] == '.' {
241			if i > 0 && bs[i-1] == '.' && !escapedDot {
242				// two dots back to back is not legal
243				return lenmsg, labels, ErrRdata
244			}
245			if i-begin >= 1<<6 { // top two bits of length must be clear
246				return lenmsg, labels, ErrRdata
247			}
248			// off can already (we're in a loop) be bigger than len(msg)
249			// this happens when a name isn't fully qualified
250			if off+1 > lenmsg {
251				return lenmsg, labels, ErrBuf
252			}
253			if msg != nil {
254				msg[off] = byte(i - begin)
255			}
256			offset := off
257			off++
258			for j := begin; j < i; j++ {
259				if off+1 > lenmsg {
260					return lenmsg, labels, ErrBuf
261				}
262				if msg != nil {
263					msg[off] = bs[j]
264				}
265				off++
266			}
267			if compress && !bsFresh {
268				roBs = string(bs)
269				bsFresh = true
270			}
271			// Don't try to compress '.'
272			// We should only compress when compress it true, but we should also still pick
273			// up names that can be used for *future* compression(s).
274			if compression != nil && roBs[begin:] != "." {
275				if p, ok := compression[roBs[begin:]]; !ok {
276					// Only offsets smaller than this can be used.
277					if offset < maxCompressionOffset {
278						compression[roBs[begin:]] = offset
279					}
280				} else {
281					// The first hit is the longest matching dname
282					// keep the pointer offset we get back and store
283					// the offset of the current name, because that's
284					// where we need to insert the pointer later
285
286					// If compress is true, we're allowed to compress this dname
287					if pointer == -1 && compress {
288						pointer = p         // Where to point to
289						nameoffset = offset // Where to point from
290						break
291					}
292				}
293			}
294			labels++
295			begin = i + 1
296		}
297		escapedDot = false
298	}
299	// Root label is special
300	if len(bs) == 1 && bs[0] == '.' {
301		return off, labels, nil
302	}
303	// If we did compression and we find something add the pointer here
304	if pointer != -1 {
305		// We have two bytes (14 bits) to put the pointer in
306		// if msg == nil, we will never do compression
307		binary.BigEndian.PutUint16(msg[nameoffset:], uint16(pointer^0xC000))
308		off = nameoffset + 1
309		goto End
310	}
311	if msg != nil && off < len(msg) {
312		msg[off] = 0
313	}
314End:
315	off++
316	return off, labels, nil
317}
318
// UnpackDomainName unpacks the wire-format domain name starting at msg[off]
// and returns it in presentation format, together with the offset at which
// the next field of the message starts.
//
// A name is a sequence of length-prefixed labels ending with a zero-length
// label. To avoid repeating common suffixes when returning many entries in
// a single domain, a length byte with its top two bits set is instead a
// compression pointer: ignoring those two bits, that byte and the next give
// a 14-bit offset from msg[0] where the name continues. Once a pointer has
// been followed, the returned offset is the one just past the first pointer,
// since that is how many bytes the name occupies at its original position.
// In theory pointers may only jump backward; here they may jump anywhere,
// and unpacking fails when an 11th pointer is encountered, which also breaks
// pointer loops.
//
// Bytes that are special in presentation format ('.', '(', ')', ';', ' ',
// '@', '"', '\\') are escaped as \X and unprintable bytes (below 32, or 127
// and above) as \DDD. The root name is returned as ".".
//
// Errors: ErrBuf when the name runs past the end of msg, ErrRdata for a
// length byte whose top two bits are 01 or 10 (reserved), an *Error for too
// many pointers, and ErrLongDomain - together with the full unpacked name -
// when the name is longer than maxDomainNameWireOctets octets in wire
// format. On error the returned offset is len(msg).
func UnpackDomainName(msg []byte, off int) (string, int, error) {
	s := make([]byte, 0, 64)
	off1 := 0
	lenmsg := len(msg)
	// Limit on len(s): the wire-format limit, raised by every escape byte
	// added below so that escapes do not count against the name length.
	maxLen := maxDomainNameWireOctets
	ptr := 0 // number of pointers followed
Loop:
	for {
		if off >= lenmsg {
			return "", lenmsg, ErrBuf
		}
		c := int(msg[off])
		off++
		switch c & 0xC0 {
		case 0x00:
			if c == 0x00 {
				// end of name
				break Loop
			}
			// literal string
			if off+c > lenmsg {
				return "", lenmsg, ErrBuf
			}
			for j := off; j < off+c; j++ {
				switch b := msg[j]; b {
				case '.', '(', ')', ';', ' ', '@':
					fallthrough
				case '"', '\\':
					s = append(s, '\\', b)
					// presentation-format \X escapes add an extra byte
					maxLen++
				default:
					if b < 32 || b >= 127 { // unprintable, use \DDD
						var buf [3]byte
						bufs := strconv.AppendInt(buf[:0], int64(b), 10)
						s = append(s, '\\')
						// zero-pad the decimal value to three digits
						for i := 0; i < 3-len(bufs); i++ {
							s = append(s, '0')
						}
						for _, r := range bufs {
							s = append(s, r)
						}
						// presentation-format \DDD escapes add 3 extra bytes
						maxLen += 3
					} else {
						s = append(s, b)
					}
				}
			}
			s = append(s, '.')
			off += c
		case 0xC0:
			// pointer to somewhere else in msg.
			// remember location after first ptr,
			// since that's how many bytes we consumed.
			// also, don't follow too many pointers --
			// maybe there's a loop.
			if off >= lenmsg {
				return "", lenmsg, ErrBuf
			}
			c1 := msg[off]
			off++
			if ptr == 0 {
				off1 = off
			}
			if ptr++; ptr > 10 {
				return "", lenmsg, &Error{err: "too many compression pointers"}
			}
			// Pointers are not required to move backward; the ptr limit
			// above is what guarantees that this loop terminates.
			off = (c^0xC0)<<8 | int(c1)
		default:
			// 0x80 and 0x40 are reserved
			return "", lenmsg, ErrRdata
		}
	}
	if ptr == 0 {
		off1 = off
	}
	if len(s) == 0 {
		s = []byte(".")
	} else if len(s) >= maxLen {
		// error if the name is too long, but don't throw it away
		return string(s), lenmsg, ErrLongDomain
	}
	return string(s), off1, nil
}
422
423func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
424	if len(txt) == 0 {
425		if offset >= len(msg) {
426			return offset, ErrBuf
427		}
428		msg[offset] = 0
429		return offset, nil
430	}
431	var err error
432	for i := range txt {
433		if len(txt[i]) > len(tmp) {
434			return offset, ErrBuf
435		}
436		offset, err = packTxtString(txt[i], msg, offset, tmp)
437		if err != nil {
438			return offset, err
439		}
440	}
441	return offset, nil
442}
443
444func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
445	lenByteOffset := offset
446	if offset >= len(msg) || len(s) > len(tmp) {
447		return offset, ErrBuf
448	}
449	offset++
450	bs := tmp[:len(s)]
451	copy(bs, s)
452	for i := 0; i < len(bs); i++ {
453		if len(msg) <= offset {
454			return offset, ErrBuf
455		}
456		if bs[i] == '\\' {
457			i++
458			if i == len(bs) {
459				break
460			}
461			// check for \DDD
462			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
463				msg[offset] = dddToByte(bs[i:])
464				i += 2
465			} else {
466				msg[offset] = bs[i]
467			}
468		} else {
469			msg[offset] = bs[i]
470		}
471		offset++
472	}
473	l := offset - lenByteOffset - 1
474	if l > 255 {
475		return offset, &Error{err: "string exceeded 255 bytes in txt"}
476	}
477	msg[lenByteOffset] = byte(l)
478	return offset, nil
479}
480
481func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) {
482	if offset >= len(msg) || len(s) > len(tmp) {
483		return offset, ErrBuf
484	}
485	bs := tmp[:len(s)]
486	copy(bs, s)
487	for i := 0; i < len(bs); i++ {
488		if len(msg) <= offset {
489			return offset, ErrBuf
490		}
491		if bs[i] == '\\' {
492			i++
493			if i == len(bs) {
494				break
495			}
496			// check for \DDD
497			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
498				msg[offset] = dddToByte(bs[i:])
499				i += 2
500			} else {
501				msg[offset] = bs[i]
502			}
503		} else {
504			msg[offset] = bs[i]
505		}
506		offset++
507	}
508	return offset, nil
509}
510
511func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
512	off = off0
513	var s string
514	for off < len(msg) && err == nil {
515		s, off, err = unpackTxtString(msg, off)
516		if err == nil {
517			ss = append(ss, s)
518		}
519	}
520	return
521}
522
523func unpackTxtString(msg []byte, offset int) (string, int, error) {
524	if offset+1 > len(msg) {
525		return "", offset, &Error{err: "overflow unpacking txt"}
526	}
527	l := int(msg[offset])
528	if offset+l+1 > len(msg) {
529		return "", offset, &Error{err: "overflow unpacking txt"}
530	}
531	s := make([]byte, 0, l)
532	for _, b := range msg[offset+1 : offset+1+l] {
533		switch b {
534		case '"', '\\':
535			s = append(s, '\\', b)
536		default:
537			if b < 32 || b > 127 { // unprintable
538				var buf [3]byte
539				bufs := strconv.AppendInt(buf[:0], int64(b), 10)
540				s = append(s, '\\')
541				for i := 0; i < 3-len(bufs); i++ {
542					s = append(s, '0')
543				}
544				for _, r := range bufs {
545					s = append(s, r)
546				}
547			} else {
548				s = append(s, b)
549			}
550		}
551	}
552	offset += 1 + l
553	return string(s), offset, nil
554}
555
// isDigit reports whether b is an ASCII decimal digit.
func isDigit(b byte) bool { return '0' <= b && b <= '9' }

// dddToByte converts the three ASCII digits at the start of s (the DDD of a
// \DDD escape) to the byte they denote. Callers check the digits with
// isDigit first. Values above 255 are not rejected; they wrap modulo 256.
func dddToByte(s []byte) byte {
	hundreds, tens, ones := s[0]-'0', s[1]-'0', s[2]-'0'
	return hundreds*100 + tens*10 + ones
}
562
// intToBytes returns the big-endian bytes of i, left-padded with zeros to
// length bytes. If i needs more than length bytes, all of them are returned
// without truncation.
func intToBytes(i *big.Int, length int) []byte {
	raw := i.Bytes()
	pad := length - len(raw)
	if pad <= 0 {
		return raw
	}
	out := make([]byte, length)
	copy(out[pad:], raw)
	return out
}
573
574// PackRR packs a resource record rr into msg[off:].
575// See PackDomainName for documentation about the compression.
576func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
577	if rr == nil {
578		return len(msg), &Error{err: "nil rr"}
579	}
580
581	off1, err = rr.pack(msg, off, compression, compress)
582	if err != nil {
583		return len(msg), err
584	}
585	// TODO(miek): Not sure if this is needed? If removed we can remove rawmsg.go as well.
586	if rawSetRdlength(msg, off, off1) {
587		return off1, nil
588	}
589	return off, ErrRdata
590}
591
592// UnpackRR unpacks msg[off:] into an RR.
593func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
594	h, off, msg, err := unpackHeader(msg, off)
595	if err != nil {
596		return nil, len(msg), err
597	}
598
599	return UnpackRRWithHeader(h, msg, off)
600}
601
602// UnpackRRWithHeader unpacks the record type specific payload given an existing
603// RR_Header.
604func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
605	end := off + int(h.Rdlength)
606
607	if fn, known := typeToUnpack[h.Rrtype]; !known {
608		rr, off, err = unpackRFC3597(h, msg, off)
609	} else {
610		rr, off, err = fn(h, msg, off)
611	}
612	if off != end {
613		return &h, end, &Error{err: "bad rdlength"}
614	}
615	return rr, off, err
616}
617
618// unpackRRslice unpacks msg[off:] into an []RR.
619// If we cannot unpack the whole array, then it will return nil
620func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
621	var r RR
622	// Don't pre-allocate, l may be under attacker control
623	var dst []RR
624	for i := 0; i < l; i++ {
625		off1 := off
626		r, off, err = UnpackRR(msg, off)
627		if err != nil {
628			off = len(msg)
629			break
630		}
631		// If offset does not increase anymore, l is a lie
632		if off1 == off {
633			l = i
634			break
635		}
636		dst = append(dst, r)
637	}
638	if err != nil && off == len(msg) {
639		dst = nil
640	}
641	return dst, off, err
642}
643
644// Convert a MsgHdr to a string, with dig-like headers:
645//
646//;; opcode: QUERY, status: NOERROR, id: 48404
647//
648//;; flags: qr aa rd ra;
649func (h *MsgHdr) String() string {
650	if h == nil {
651		return "<nil> MsgHdr"
652	}
653
654	s := ";; opcode: " + OpcodeToString[h.Opcode]
655	s += ", status: " + RcodeToString[h.Rcode]
656	s += ", id: " + strconv.Itoa(int(h.Id)) + "\n"
657
658	s += ";; flags:"
659	if h.Response {
660		s += " qr"
661	}
662	if h.Authoritative {
663		s += " aa"
664	}
665	if h.Truncated {
666		s += " tc"
667	}
668	if h.RecursionDesired {
669		s += " rd"
670	}
671	if h.RecursionAvailable {
672		s += " ra"
673	}
674	if h.Zero { // Hmm
675		s += " z"
676	}
677	if h.AuthenticatedData {
678		s += " ad"
679	}
680	if h.CheckingDisabled {
681		s += " cd"
682	}
683
684	s += ";"
685	return s
686}
687
688// Pack packs a Msg: it is converted to to wire format.
689// If the dns.Compress is true the message will be in compressed wire format.
690func (dns *Msg) Pack() (msg []byte, err error) {
691	return dns.PackBuffer(nil)
692}
693
694// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
695func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
696	var compression map[string]int
697	if dns.Compress {
698		compression = make(map[string]int) // Compression pointer mappings.
699	}
700	return dns.packBufferWithCompressionMap(buf, compression)
701}
702
703// packBufferWithCompressionMap packs a Msg, using the given buffer buf.
704func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string]int) (msg []byte, err error) {
705	// We use a similar function in tsig.go's stripTsig.
706
707	var dh Header
708
709	if dns.Rcode < 0 || dns.Rcode > 0xFFF {
710		return nil, ErrRcode
711	}
712	if dns.Rcode > 0xF {
713		// Regular RCODE field is 4 bits
714		opt := dns.IsEdns0()
715		if opt == nil {
716			return nil, ErrExtendedRcode
717		}
718		opt.SetExtendedRcode(uint8(dns.Rcode >> 4))
719	}
720
721	// Convert convenient Msg into wire-like Header.
722	dh.Id = dns.Id
723	dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
724	if dns.Response {
725		dh.Bits |= _QR
726	}
727	if dns.Authoritative {
728		dh.Bits |= _AA
729	}
730	if dns.Truncated {
731		dh.Bits |= _TC
732	}
733	if dns.RecursionDesired {
734		dh.Bits |= _RD
735	}
736	if dns.RecursionAvailable {
737		dh.Bits |= _RA
738	}
739	if dns.Zero {
740		dh.Bits |= _Z
741	}
742	if dns.AuthenticatedData {
743		dh.Bits |= _AD
744	}
745	if dns.CheckingDisabled {
746		dh.Bits |= _CD
747	}
748
749	// Prepare variable sized arrays.
750	question := dns.Question
751	answer := dns.Answer
752	ns := dns.Ns
753	extra := dns.Extra
754
755	dh.Qdcount = uint16(len(question))
756	dh.Ancount = uint16(len(answer))
757	dh.Nscount = uint16(len(ns))
758	dh.Arcount = uint16(len(extra))
759
760	// We need the uncompressed length here, because we first pack it and then compress it.
761	msg = buf
762	uncompressedLen := compressedLen(dns, false)
763	if packLen := uncompressedLen + 1; len(msg) < packLen {
764		msg = make([]byte, packLen)
765	}
766
767	// Pack it in: header and then the pieces.
768	off := 0
769	off, err = dh.pack(msg, off, compression, dns.Compress)
770	if err != nil {
771		return nil, err
772	}
773	for i := 0; i < len(question); i++ {
774		off, err = question[i].pack(msg, off, compression, dns.Compress)
775		if err != nil {
776			return nil, err
777		}
778	}
779	for i := 0; i < len(answer); i++ {
780		off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
781		if err != nil {
782			return nil, err
783		}
784	}
785	for i := 0; i < len(ns); i++ {
786		off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
787		if err != nil {
788			return nil, err
789		}
790	}
791	for i := 0; i < len(extra); i++ {
792		off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
793		if err != nil {
794			return nil, err
795		}
796	}
797	return msg[:off], nil
798}
799
800// Unpack unpacks a binary message to a Msg structure.
801func (dns *Msg) Unpack(msg []byte) (err error) {
802	var (
803		dh  Header
804		off int
805	)
806	if dh, off, err = unpackMsgHdr(msg, off); err != nil {
807		return err
808	}
809
810	dns.Id = dh.Id
811	dns.Response = (dh.Bits & _QR) != 0
812	dns.Opcode = int(dh.Bits>>11) & 0xF
813	dns.Authoritative = (dh.Bits & _AA) != 0
814	dns.Truncated = (dh.Bits & _TC) != 0
815	dns.RecursionDesired = (dh.Bits & _RD) != 0
816	dns.RecursionAvailable = (dh.Bits & _RA) != 0
817	dns.Zero = (dh.Bits & _Z) != 0
818	dns.AuthenticatedData = (dh.Bits & _AD) != 0
819	dns.CheckingDisabled = (dh.Bits & _CD) != 0
820	dns.Rcode = int(dh.Bits & 0xF)
821
822	// If we are at the end of the message we should return *just* the
823	// header. This can still be useful to the caller. 9.9.9.9 sends these
824	// when responding with REFUSED for instance.
825	if off == len(msg) {
826		// reset sections before returning
827		dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil
828		return nil
829	}
830
831	// Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are
832	// attacker controlled. This means we can't use them to pre-allocate
833	// slices.
834	dns.Question = nil
835	for i := 0; i < int(dh.Qdcount); i++ {
836		off1 := off
837		var q Question
838		q, off, err = unpackQuestion(msg, off)
839		if err != nil {
840			// Even if Truncated is set, we only will set ErrTruncated if we
841			// actually got the questions
842			return err
843		}
844		if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
845			dh.Qdcount = uint16(i)
846			break
847		}
848		dns.Question = append(dns.Question, q)
849	}
850
851	dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
852	// The header counts might have been wrong so we need to update it
853	dh.Ancount = uint16(len(dns.Answer))
854	if err == nil {
855		dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
856	}
857	// The header counts might have been wrong so we need to update it
858	dh.Nscount = uint16(len(dns.Ns))
859	if err == nil {
860		dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off)
861	}
862	// The header counts might have been wrong so we need to update it
863	dh.Arcount = uint16(len(dns.Extra))
864
865	if off != len(msg) {
866		// TODO(miek) make this an error?
867		// use PackOpt to let people tell how detailed the error reporting should be?
868		// println("dns: extra bytes in dns packet", off, "<", len(msg))
869	} else if dns.Truncated {
870		// Whether we ran into a an error or not, we want to return that it
871		// was truncated
872		err = ErrTruncated
873	}
874	return err
875}
876
877// Convert a complete message to a string with dig-like output.
878func (dns *Msg) String() string {
879	if dns == nil {
880		return "<nil> MsgHdr"
881	}
882	s := dns.MsgHdr.String() + " "
883	s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", "
884	s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
885	s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
886	s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
887	if len(dns.Question) > 0 {
888		s += "\n;; QUESTION SECTION:\n"
889		for i := 0; i < len(dns.Question); i++ {
890			s += dns.Question[i].String() + "\n"
891		}
892	}
893	if len(dns.Answer) > 0 {
894		s += "\n;; ANSWER SECTION:\n"
895		for i := 0; i < len(dns.Answer); i++ {
896			if dns.Answer[i] != nil {
897				s += dns.Answer[i].String() + "\n"
898			}
899		}
900	}
901	if len(dns.Ns) > 0 {
902		s += "\n;; AUTHORITY SECTION:\n"
903		for i := 0; i < len(dns.Ns); i++ {
904			if dns.Ns[i] != nil {
905				s += dns.Ns[i].String() + "\n"
906			}
907		}
908	}
909	if len(dns.Extra) > 0 {
910		s += "\n;; ADDITIONAL SECTION:\n"
911		for i := 0; i < len(dns.Extra); i++ {
912			if dns.Extra[i] != nil {
913				s += dns.Extra[i].String() + "\n"
914			}
915		}
916	}
917	return s
918}
919
920// Len returns the message length when in (un)compressed wire format.
921// If dns.Compress is true compression it is taken into account. Len()
922// is provided to be a faster way to get the size of the resulting packet,
923// than packing it, measuring the size and discarding the buffer.
924func (dns *Msg) Len() int { return compressedLen(dns, dns.Compress) }
925
926func compressedLenWithCompressionMap(dns *Msg, compression map[string]int) int {
927	l := 12 // Message header is always 12 bytes
928	for _, r := range dns.Question {
929		compressionLenHelper(compression, r.Name, l)
930		l += r.len()
931	}
932	l += compressionLenSlice(l, compression, dns.Answer)
933	l += compressionLenSlice(l, compression, dns.Ns)
934	l += compressionLenSlice(l, compression, dns.Extra)
935	return l
936}
937
938// compressedLen returns the message length when in compressed wire format
939// when compress is true, otherwise the uncompressed length is returned.
940func compressedLen(dns *Msg, compress bool) int {
941	// We always return one more than needed.
942	if compress {
943		compression := map[string]int{}
944		return compressedLenWithCompressionMap(dns, compression)
945	}
946	l := 12 // Message header is always 12 bytes
947
948	for _, r := range dns.Question {
949		l += r.len()
950	}
951	for _, r := range dns.Answer {
952		if r != nil {
953			l += r.len()
954		}
955	}
956	for _, r := range dns.Ns {
957		if r != nil {
958			l += r.len()
959		}
960	}
961	for _, r := range dns.Extra {
962		if r != nil {
963			l += r.len()
964		}
965	}
966
967	return l
968}
969
970func compressionLenSlice(lenp int, c map[string]int, rs []RR) int {
971	initLen := lenp
972	for _, r := range rs {
973		if r == nil {
974			continue
975		}
976		// TmpLen is to track len of record at 14bits boudaries
977		tmpLen := lenp
978
979		x := r.len()
980		// track this length, and the global length in len, while taking compression into account for both.
981		k, ok, _ := compressionLenSearch(c, r.Header().Name)
982		if ok {
983			// Size of x is reduced by k, but we add 1 since k includes the '.' and label descriptor take 2 bytes
984			// so, basically x:= x - k - 1 + 2
985			x += 1 - k
986		}
987
988		tmpLen += compressionLenHelper(c, r.Header().Name, tmpLen)
989		k, ok, _ = compressionLenSearchType(c, r)
990		if ok {
991			x += 1 - k
992		}
993		lenp += x
994		tmpLen = lenp
995		tmpLen += compressionLenHelperType(c, r, tmpLen)
996
997	}
998	return lenp - initLen
999}
1000
1001// Put the parts of the name in the compression map, return the size in bytes added in payload
1002func compressionLenHelper(c map[string]int, s string, currentLen int) int {
1003	if currentLen > maxCompressionOffset {
1004		// We won't be able to add any label that could be re-used later anyway
1005		return 0
1006	}
1007	if _, ok := c[s]; ok {
1008		return 0
1009	}
1010	initLen := currentLen
1011	pref := ""
1012	prev := s
1013	lbs := Split(s)
1014	for j := 0; j < len(lbs); j++ {
1015		pref = s[lbs[j]:]
1016		currentLen += len(prev) - len(pref)
1017		prev = pref
1018		if _, ok := c[pref]; !ok {
1019			// If first byte label is within the first 14bits, it might be re-used later
1020			if currentLen < maxCompressionOffset {
1021				c[pref] = currentLen
1022			}
1023		} else {
1024			added := currentLen - initLen
1025			if j > 0 {
1026				// We added a new PTR
1027				added += 2
1028			}
1029			return added
1030		}
1031	}
1032	return currentLen - initLen
1033}
1034
1035// Look for each part in the compression map and returns its length,
1036// keep on searching so we get the longest match.
1037// Will return the size of compression found, whether a match has been
1038// found and the size of record if added in payload
1039func compressionLenSearch(c map[string]int, s string) (int, bool, int) {
1040	off := 0
1041	end := false
1042	if s == "" { // don't bork on bogus data
1043		return 0, false, 0
1044	}
1045	fullSize := 0
1046	for {
1047		if _, ok := c[s[off:]]; ok {
1048			return len(s[off:]), true, fullSize + off
1049		}
1050		if end {
1051			break
1052		}
1053		// Each label descriptor takes 2 bytes, add it
1054		fullSize += 2
1055		off, end = NextLabel(s, off)
1056	}
1057	return 0, false, fullSize + len(s)
1058}
1059
1060// Copy returns a new RR which is a deep-copy of r.
1061func Copy(r RR) RR { r1 := r.copy(); return r1 }
1062
1063// Len returns the length (in octets) of the uncompressed RR in wire format.
1064func Len(r RR) int { return r.len() }
1065
1066// Copy returns a new *Msg which is a deep-copy of dns.
1067func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
1068
1069// CopyTo copies the contents to the provided message using a deep-copy and returns the copy.
1070func (dns *Msg) CopyTo(r1 *Msg) *Msg {
1071	r1.MsgHdr = dns.MsgHdr
1072	r1.Compress = dns.Compress
1073
1074	if len(dns.Question) > 0 {
1075		r1.Question = make([]Question, len(dns.Question))
1076		copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy
1077	}
1078
1079	rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
1080	var rri int
1081
1082	if len(dns.Answer) > 0 {
1083		rrbegin := rri
1084		for i := 0; i < len(dns.Answer); i++ {
1085			rrArr[rri] = dns.Answer[i].copy()
1086			rri++
1087		}
1088		r1.Answer = rrArr[rrbegin:rri:rri]
1089	}
1090
1091	if len(dns.Ns) > 0 {
1092		rrbegin := rri
1093		for i := 0; i < len(dns.Ns); i++ {
1094			rrArr[rri] = dns.Ns[i].copy()
1095			rri++
1096		}
1097		r1.Ns = rrArr[rrbegin:rri:rri]
1098	}
1099
1100	if len(dns.Extra) > 0 {
1101		rrbegin := rri
1102		for i := 0; i < len(dns.Extra); i++ {
1103			rrArr[rri] = dns.Extra[i].copy()
1104			rri++
1105		}
1106		r1.Extra = rrArr[rrbegin:rri:rri]
1107	}
1108
1109	return r1
1110}
1111
1112func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
1113	off, err := PackDomainName(q.Name, msg, off, compression, compress)
1114	if err != nil {
1115		return off, err
1116	}
1117	off, err = packUint16(q.Qtype, msg, off)
1118	if err != nil {
1119		return off, err
1120	}
1121	off, err = packUint16(q.Qclass, msg, off)
1122	if err != nil {
1123		return off, err
1124	}
1125	return off, nil
1126}
1127
1128func unpackQuestion(msg []byte, off int) (Question, int, error) {
1129	var (
1130		q   Question
1131		err error
1132	)
1133	q.Name, off, err = UnpackDomainName(msg, off)
1134	if err != nil {
1135		return q, off, err
1136	}
1137	if off == len(msg) {
1138		return q, off, nil
1139	}
1140	q.Qtype, off, err = unpackUint16(msg, off)
1141	if err != nil {
1142		return q, off, err
1143	}
1144	if off == len(msg) {
1145		return q, off, nil
1146	}
1147	q.Qclass, off, err = unpackUint16(msg, off)
1148	if off == len(msg) {
1149		return q, off, nil
1150	}
1151	return q, off, err
1152}
1153
1154func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
1155	off, err := packUint16(dh.Id, msg, off)
1156	if err != nil {
1157		return off, err
1158	}
1159	off, err = packUint16(dh.Bits, msg, off)
1160	if err != nil {
1161		return off, err
1162	}
1163	off, err = packUint16(dh.Qdcount, msg, off)
1164	if err != nil {
1165		return off, err
1166	}
1167	off, err = packUint16(dh.Ancount, msg, off)
1168	if err != nil {
1169		return off, err
1170	}
1171	off, err = packUint16(dh.Nscount, msg, off)
1172	if err != nil {
1173		return off, err
1174	}
1175	off, err = packUint16(dh.Arcount, msg, off)
1176	return off, err
1177}
1178
1179func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
1180	var (
1181		dh  Header
1182		err error
1183	)
1184	dh.Id, off, err = unpackUint16(msg, off)
1185	if err != nil {
1186		return dh, off, err
1187	}
1188	dh.Bits, off, err = unpackUint16(msg, off)
1189	if err != nil {
1190		return dh, off, err
1191	}
1192	dh.Qdcount, off, err = unpackUint16(msg, off)
1193	if err != nil {
1194		return dh, off, err
1195	}
1196	dh.Ancount, off, err = unpackUint16(msg, off)
1197	if err != nil {
1198		return dh, off, err
1199	}
1200	dh.Nscount, off, err = unpackUint16(msg, off)
1201	if err != nil {
1202		return dh, off, err
1203	}
1204	dh.Arcount, off, err = unpackUint16(msg, off)
1205	return dh, off, err
1206}
1207