1// DNS packet assembly, see RFC 1035. Converting from - Unpack() -
2// and to - Pack() - wire format.
// All the packers and unpackers take a (msg []byte, off int)
// and return (off1 int, err error). If they return a non-nil error, they
// generally also return off1==len(msg), so that the next unpacker will
// also fail. This lets callers defer error checks until the end of a
// packing sequence.
8
9package dns
10
11//go:generate go run msg_generate.go
12//go:generate go run compress_generate.go
13
14import (
15	crand "crypto/rand"
16	"encoding/binary"
17	"fmt"
18	"math/big"
19	"math/rand"
20	"strconv"
21	"sync"
22)
23
const (
	// maxCompressionOffset bounds the offsets a compression pointer can
	// reference: a pointer carries a 14-bit offset, so only offsets below
	// 1<<14 are usable.
	maxCompressionOffset = 1 << 14
	// maxDomainNameWireOctets is the maximum length of a domain name in
	// wire format, see RFC 1035 section 2.3.4.
	maxDomainNameWireOctets = 255
)
28
29// Errors defined in this package.
30var (
31	ErrAlg           error = &Error{err: "bad algorithm"}                  // ErrAlg indicates an error with the (DNSSEC) algorithm.
32	ErrAuth          error = &Error{err: "bad authentication"}             // ErrAuth indicates an error in the TSIG authentication.
33	ErrBuf           error = &Error{err: "buffer size too small"}          // ErrBuf indicates that the buffer used is too small for the message.
34	ErrConnEmpty     error = &Error{err: "conn has no connection"}         // ErrConnEmpty indicates a connection is being used before it is initialized.
35	ErrExtendedRcode error = &Error{err: "bad extended rcode"}             // ErrExtendedRcode ...
36	ErrFqdn          error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot.
37	ErrId            error = &Error{err: "id mismatch"}                    // ErrId indicates there is a mismatch with the message's ID.
38	ErrKeyAlg        error = &Error{err: "bad key algorithm"}              // ErrKeyAlg indicates that the algorithm in the key is not valid.
39	ErrKey           error = &Error{err: "bad key"}
40	ErrKeySize       error = &Error{err: "bad key size"}
41	ErrLongDomain    error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)}
42	ErrNoSig         error = &Error{err: "no signature found"}
43	ErrPrivKey       error = &Error{err: "bad private key"}
44	ErrRcode         error = &Error{err: "bad rcode"}
45	ErrRdata         error = &Error{err: "bad rdata"}
46	ErrRRset         error = &Error{err: "bad rrset"}
47	ErrSecret        error = &Error{err: "no secrets defined"}
48	ErrShortRead     error = &Error{err: "short read"}
49	ErrSig           error = &Error{err: "bad signature"}                      // ErrSig indicates that a signature can not be cryptographically validated.
50	ErrSoa           error = &Error{err: "no SOA"}                             // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
51	ErrTime          error = &Error{err: "bad time"}                           // ErrTime indicates a timing error in TSIG authentication.
52	ErrTruncated     error = &Error{err: "failed to unpack truncated message"} // ErrTruncated indicates that we failed to unpack a truncated message. We unpacked as much as we had so Msg can still be used, if desired.
53)
54
55// Id by default, returns a 16 bits random number to be used as a
56// message id. The random provided should be good enough. This being a
57// variable the function can be reassigned to a custom function.
58// For instance, to make it return a static value:
59//
60//	dns.Id = func() uint16 { return 3 }
61var Id = id
62
63var (
64	idLock sync.Mutex
65	idRand *rand.Rand
66)
67
68// id returns a 16 bits random number to be used as a
69// message id. The random provided should be good enough.
70func id() uint16 {
71	idLock.Lock()
72
73	if idRand == nil {
74		// This (partially) works around
75		// https://github.com/golang/go/issues/11833 by only
76		// seeding idRand upon the first call to id.
77
78		var seed int64
79		var buf [8]byte
80
81		if _, err := crand.Read(buf[:]); err == nil {
82			seed = int64(binary.LittleEndian.Uint64(buf[:]))
83		} else {
84			seed = rand.Int63()
85		}
86
87		idRand = rand.New(rand.NewSource(seed))
88	}
89
90	// The call to idRand.Uint32 must be within the
91	// mutex lock because *rand.Rand is not safe for
92	// concurrent use.
93	//
94	// There is no added performance overhead to calling
95	// idRand.Uint32 inside a mutex lock over just
96	// calling rand.Uint32 as the global math/rand rng
97	// is internally protected by a sync.Mutex.
98	id := uint16(idRand.Uint32())
99
100	idLock.Unlock()
101	return id
102}
103
// MsgHdr is a manually-unpacked version of (id, bits): the DNS message
// header with every flag of the bits word split out into its own field.
type MsgHdr struct {
	Id                 uint16 // Message identifier.
	Response           bool   // QR bit.
	Opcode             int    // Packed into the 4 bits starting at bit 11.
	Authoritative      bool   // AA bit.
	Truncated          bool   // TC bit.
	RecursionDesired   bool   // RD bit.
	RecursionAvailable bool   // RA bit.
	Zero               bool   // Z bit.
	AuthenticatedData  bool   // AD bit.
	CheckingDisabled   bool   // CD bit.
	Rcode              int    // Low 4 bits go in the header; higher bits need an OPT RR.
}
118
119// Msg contains the layout of a DNS message.
120type Msg struct {
121	MsgHdr
122	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format.
123	Question []Question // Holds the RR(s) of the question section.
124	Answer   []RR       // Holds the RR(s) of the answer section.
125	Ns       []RR       // Holds the RR(s) of the authority section.
126	Extra    []RR       // Holds the RR(s) of the additional section.
127}
128
129// ClassToString is a maps Classes to strings for each CLASS wire type.
130var ClassToString = map[uint16]string{
131	ClassINET:   "IN",
132	ClassCSNET:  "CS",
133	ClassCHAOS:  "CH",
134	ClassHESIOD: "HS",
135	ClassNONE:   "NONE",
136	ClassANY:    "ANY",
137}
138
139// OpcodeToString maps Opcodes to strings.
140var OpcodeToString = map[int]string{
141	OpcodeQuery:  "QUERY",
142	OpcodeIQuery: "IQUERY",
143	OpcodeStatus: "STATUS",
144	OpcodeNotify: "NOTIFY",
145	OpcodeUpdate: "UPDATE",
146}
147
148// RcodeToString maps Rcodes to strings.
149var RcodeToString = map[int]string{
150	RcodeSuccess:        "NOERROR",
151	RcodeFormatError:    "FORMERR",
152	RcodeServerFailure:  "SERVFAIL",
153	RcodeNameError:      "NXDOMAIN",
154	RcodeNotImplemented: "NOTIMPL",
155	RcodeRefused:        "REFUSED",
156	RcodeYXDomain:       "YXDOMAIN", // See RFC 2136
157	RcodeYXRrset:        "YXRRSET",
158	RcodeNXRrset:        "NXRRSET",
159	RcodeNotAuth:        "NOTAUTH",
160	RcodeNotZone:        "NOTZONE",
161	RcodeBadSig:         "BADSIG", // Also known as RcodeBadVers, see RFC 6891
162	//	RcodeBadVers:        "BADVERS",
163	RcodeBadKey:    "BADKEY",
164	RcodeBadTime:   "BADTIME",
165	RcodeBadMode:   "BADMODE",
166	RcodeBadName:   "BADNAME",
167	RcodeBadAlg:    "BADALG",
168	RcodeBadTrunc:  "BADTRUNC",
169	RcodeBadCookie: "BADCOOKIE",
170}
171
172// Domain names are a sequence of counted strings
173// split at the dots. They end with a zero-length string.
174
175// PackDomainName packs a domain name s into msg[off:].
176// If compression is wanted compress must be true and the compression
177// map needs to hold a mapping between domain names and offsets
178// pointing into msg.
179func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
180	off1, _, err = packDomainName(s, msg, off, compression, compress)
181	return
182}
183
184func packDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, labels int, err error) {
185	// special case if msg == nil
186	lenmsg := 256
187	if msg != nil {
188		lenmsg = len(msg)
189	}
190	ls := len(s)
191	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
192		return off, 0, nil
193	}
194	// If not fully qualified, error out, but only if msg == nil #ugly
195	switch {
196	case msg == nil:
197		if s[ls-1] != '.' {
198			s += "."
199			ls++
200		}
201	case msg != nil:
202		if s[ls-1] != '.' {
203			return lenmsg, 0, ErrFqdn
204		}
205	}
206	// Each dot ends a segment of the name.
207	// We trade each dot byte for a length byte.
208	// Except for escaped dots (\.), which are normal dots.
209	// There is also a trailing zero.
210
211	// Compression
212	nameoffset := -1
213	pointer := -1
214	// Emit sequence of counted strings, chopping at dots.
215	begin := 0
216	bs := []byte(s)
217	roBs, bsFresh, escapedDot := s, true, false
218	for i := 0; i < ls; i++ {
219		if bs[i] == '\\' {
220			for j := i; j < ls-1; j++ {
221				bs[j] = bs[j+1]
222			}
223			ls--
224			if off+1 > lenmsg {
225				return lenmsg, labels, ErrBuf
226			}
227			// check for \DDD
228			if i+2 < ls && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
229				bs[i] = dddToByte(bs[i:])
230				for j := i + 1; j < ls-2; j++ {
231					bs[j] = bs[j+2]
232				}
233				ls -= 2
234			}
235			escapedDot = bs[i] == '.'
236			bsFresh = false
237			continue
238		}
239
240		if bs[i] == '.' {
241			if i > 0 && bs[i-1] == '.' && !escapedDot {
242				// two dots back to back is not legal
243				return lenmsg, labels, ErrRdata
244			}
245			if i-begin >= 1<<6 { // top two bits of length must be clear
246				return lenmsg, labels, ErrRdata
247			}
248			// off can already (we're in a loop) be bigger than len(msg)
249			// this happens when a name isn't fully qualified
250			if off+1 > lenmsg {
251				return lenmsg, labels, ErrBuf
252			}
253			if msg != nil {
254				msg[off] = byte(i - begin)
255			}
256			offset := off
257			off++
258			for j := begin; j < i; j++ {
259				if off+1 > lenmsg {
260					return lenmsg, labels, ErrBuf
261				}
262				if msg != nil {
263					msg[off] = bs[j]
264				}
265				off++
266			}
267			if compress && !bsFresh {
268				roBs = string(bs)
269				bsFresh = true
270			}
271			// Don't try to compress '.'
272			// We should only compress when compress it true, but we should also still pick
273			// up names that can be used for *future* compression(s).
274			if compression != nil && roBs[begin:] != "." {
275				if p, ok := compression[roBs[begin:]]; !ok {
276					// Only offsets smaller than this can be used.
277					if offset < maxCompressionOffset {
278						compression[roBs[begin:]] = offset
279					}
280				} else {
281					// The first hit is the longest matching dname
282					// keep the pointer offset we get back and store
283					// the offset of the current name, because that's
284					// where we need to insert the pointer later
285
286					// If compress is true, we're allowed to compress this dname
287					if pointer == -1 && compress {
288						pointer = p         // Where to point to
289						nameoffset = offset // Where to point from
290						break
291					}
292				}
293			}
294			labels++
295			begin = i + 1
296		}
297		escapedDot = false
298	}
299	// Root label is special
300	if len(bs) == 1 && bs[0] == '.' {
301		return off, labels, nil
302	}
303	// If we did compression and we find something add the pointer here
304	if pointer != -1 {
305		// Clear the msg buffer after the pointer location, otherwise
306		// packDataNsec writes the wrong data to msg.
307		tainted := msg[nameoffset:off]
308		for i := range tainted {
309			tainted[i] = 0
310		}
311		// We have two bytes (14 bits) to put the pointer in
312		// if msg == nil, we will never do compression
313		binary.BigEndian.PutUint16(msg[nameoffset:], uint16(pointer^0xC000))
314		off = nameoffset + 1
315		goto End
316	}
317	if msg != nil && off < len(msg) {
318		msg[off] = 0
319	}
320End:
321	off++
322	return off, labels, nil
323}
324
325// Unpack a domain name.
326// In addition to the simple sequences of counted strings above,
327// domain names are allowed to refer to strings elsewhere in the
328// packet, to avoid repeating common suffixes when returning
329// many entries in a single domain.  The pointers are marked
330// by a length byte with the top two bits set.  Ignoring those
331// two bits, that byte and the next give a 14 bit offset from msg[0]
332// where we should pick up the trail.
333// Note that if we jump elsewhere in the packet,
334// we return off1 == the offset after the first pointer we found,
335// which is where the next record will start.
336// In theory, the pointers are only allowed to jump backward.
337// We let them jump anywhere and stop jumping after a while.
338
339// UnpackDomainName unpacks a domain name into a string.
340func UnpackDomainName(msg []byte, off int) (string, int, error) {
341	s := make([]byte, 0, 64)
342	off1 := 0
343	lenmsg := len(msg)
344	maxLen := maxDomainNameWireOctets
345	ptr := 0 // number of pointers followed
346Loop:
347	for {
348		if off >= lenmsg {
349			return "", lenmsg, ErrBuf
350		}
351		c := int(msg[off])
352		off++
353		switch c & 0xC0 {
354		case 0x00:
355			if c == 0x00 {
356				// end of name
357				break Loop
358			}
359			// literal string
360			if off+c > lenmsg {
361				return "", lenmsg, ErrBuf
362			}
363			for j := off; j < off+c; j++ {
364				switch b := msg[j]; b {
365				case '.', '(', ')', ';', ' ', '@':
366					fallthrough
367				case '"', '\\':
368					s = append(s, '\\', b)
369					// presentation-format \X escapes add an extra byte
370					maxLen++
371				default:
372					if b < 32 || b >= 127 { // unprintable, use \DDD
373						var buf [3]byte
374						bufs := strconv.AppendInt(buf[:0], int64(b), 10)
375						s = append(s, '\\')
376						for i := len(bufs); i < 3; i++ {
377							s = append(s, '0')
378						}
379						s = append(s, bufs...)
380						// presentation-format \DDD escapes add 3 extra bytes
381						maxLen += 3
382					} else {
383						s = append(s, b)
384					}
385				}
386			}
387			s = append(s, '.')
388			off += c
389		case 0xC0:
390			// pointer to somewhere else in msg.
391			// remember location after first ptr,
392			// since that's how many bytes we consumed.
393			// also, don't follow too many pointers --
394			// maybe there's a loop.
395			if off >= lenmsg {
396				return "", lenmsg, ErrBuf
397			}
398			c1 := msg[off]
399			off++
400			if ptr == 0 {
401				off1 = off
402			}
403			if ptr++; ptr > 10 {
404				return "", lenmsg, &Error{err: "too many compression pointers"}
405			}
406			// pointer should guarantee that it advances and points forwards at least
407			// but the condition on previous three lines guarantees that it's
408			// at least loop-free
409			off = (c^0xC0)<<8 | int(c1)
410		default:
411			// 0x80 and 0x40 are reserved
412			return "", lenmsg, ErrRdata
413		}
414	}
415	if ptr == 0 {
416		off1 = off
417	}
418	if len(s) == 0 {
419		s = []byte(".")
420	} else if len(s) >= maxLen {
421		// error if the name is too long, but don't throw it away
422		return string(s), lenmsg, ErrLongDomain
423	}
424	return string(s), off1, nil
425}
426
427func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
428	if len(txt) == 0 {
429		if offset >= len(msg) {
430			return offset, ErrBuf
431		}
432		msg[offset] = 0
433		return offset, nil
434	}
435	var err error
436	for i := range txt {
437		if len(txt[i]) > len(tmp) {
438			return offset, ErrBuf
439		}
440		offset, err = packTxtString(txt[i], msg, offset, tmp)
441		if err != nil {
442			return offset, err
443		}
444	}
445	return offset, nil
446}
447
448func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
449	lenByteOffset := offset
450	if offset >= len(msg) || len(s) > len(tmp) {
451		return offset, ErrBuf
452	}
453	offset++
454	bs := tmp[:len(s)]
455	copy(bs, s)
456	for i := 0; i < len(bs); i++ {
457		if len(msg) <= offset {
458			return offset, ErrBuf
459		}
460		if bs[i] == '\\' {
461			i++
462			if i == len(bs) {
463				break
464			}
465			// check for \DDD
466			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
467				msg[offset] = dddToByte(bs[i:])
468				i += 2
469			} else {
470				msg[offset] = bs[i]
471			}
472		} else {
473			msg[offset] = bs[i]
474		}
475		offset++
476	}
477	l := offset - lenByteOffset - 1
478	if l > 255 {
479		return offset, &Error{err: "string exceeded 255 bytes in txt"}
480	}
481	msg[lenByteOffset] = byte(l)
482	return offset, nil
483}
484
485func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) {
486	if offset >= len(msg) || len(s) > len(tmp) {
487		return offset, ErrBuf
488	}
489	bs := tmp[:len(s)]
490	copy(bs, s)
491	for i := 0; i < len(bs); i++ {
492		if len(msg) <= offset {
493			return offset, ErrBuf
494		}
495		if bs[i] == '\\' {
496			i++
497			if i == len(bs) {
498				break
499			}
500			// check for \DDD
501			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
502				msg[offset] = dddToByte(bs[i:])
503				i += 2
504			} else {
505				msg[offset] = bs[i]
506			}
507		} else {
508			msg[offset] = bs[i]
509		}
510		offset++
511	}
512	return offset, nil
513}
514
515func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
516	off = off0
517	var s string
518	for off < len(msg) && err == nil {
519		s, off, err = unpackString(msg, off)
520		if err == nil {
521			ss = append(ss, s)
522		}
523	}
524	return
525}
526
// Helpers for dealing with escaped bytes.

// isDigit reports whether b is an ASCII decimal digit.
func isDigit(b byte) bool {
	return '0' <= b && b <= '9'
}

// dddToByte converts the three ASCII digits at the start of s (a \DDD escape
// without its backslash) into the byte they denote. Values above 255 wrap
// modulo 256; the callers in this file only check that the bytes are digits.
func dddToByte(s []byte) byte {
	return byte(int(s[0]-'0')*100 + int(s[1]-'0')*10 + int(s[2]-'0'))
}

// dddStringToByte is dddToByte for a string argument.
func dddStringToByte(s string) byte {
	return byte(int(s[0]-'0')*100 + int(s[1]-'0')*10 + int(s[2]-'0'))
}
537
// intToBytes returns the big-endian bytes of i, left-padded with zeros to
// length bytes. If i needs more than length bytes, its full representation
// is returned unpadded and untruncated.
func intToBytes(i *big.Int, length int) []byte {
	raw := i.Bytes()
	pad := length - len(raw)
	if pad <= 0 {
		return raw
	}
	out := make([]byte, length)
	copy(out[pad:], raw)
	return out
}
548
549// PackRR packs a resource record rr into msg[off:].
550// See PackDomainName for documentation about the compression.
551func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
552	if rr == nil {
553		return len(msg), &Error{err: "nil rr"}
554	}
555
556	off1, err = rr.pack(msg, off, compression, compress)
557	if err != nil {
558		return len(msg), err
559	}
560	// TODO(miek): Not sure if this is needed? If removed we can remove rawmsg.go as well.
561	if rawSetRdlength(msg, off, off1) {
562		return off1, nil
563	}
564	return off, ErrRdata
565}
566
567// UnpackRR unpacks msg[off:] into an RR.
568func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
569	h, off, msg, err := unpackHeader(msg, off)
570	if err != nil {
571		return nil, len(msg), err
572	}
573
574	return UnpackRRWithHeader(h, msg, off)
575}
576
577// UnpackRRWithHeader unpacks the record type specific payload given an existing
578// RR_Header.
579func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
580	end := off + int(h.Rdlength)
581
582	if fn, known := typeToUnpack[h.Rrtype]; !known {
583		rr, off, err = unpackRFC3597(h, msg, off)
584	} else {
585		rr, off, err = fn(h, msg, off)
586	}
587	if off != end {
588		return &h, end, &Error{err: "bad rdlength"}
589	}
590	return rr, off, err
591}
592
593// unpackRRslice unpacks msg[off:] into an []RR.
594// If we cannot unpack the whole array, then it will return nil
595func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
596	var r RR
597	// Don't pre-allocate, l may be under attacker control
598	var dst []RR
599	for i := 0; i < l; i++ {
600		off1 := off
601		r, off, err = UnpackRR(msg, off)
602		if err != nil {
603			off = len(msg)
604			break
605		}
606		// If offset does not increase anymore, l is a lie
607		if off1 == off {
608			l = i
609			break
610		}
611		dst = append(dst, r)
612	}
613	if err != nil && off == len(msg) {
614		dst = nil
615	}
616	return dst, off, err
617}
618
// String converts a MsgHdr to a string with dig-like headers, for example:
//
//	;; opcode: QUERY, status: NOERROR, id: 48404
//	;; flags: qr aa rd ra;
func (h *MsgHdr) String() string {
	if h == nil {
		return "<nil> MsgHdr"
	}

	flags := [...]struct {
		set  bool
		name string
	}{
		{h.Response, "qr"},
		{h.Authoritative, "aa"},
		{h.Truncated, "tc"},
		{h.RecursionDesired, "rd"},
		{h.RecursionAvailable, "ra"},
		{h.Zero, "z"},
		{h.AuthenticatedData, "ad"},
		{h.CheckingDisabled, "cd"},
	}

	s := ";; opcode: " + OpcodeToString[h.Opcode] +
		", status: " + RcodeToString[h.Rcode] +
		", id: " + strconv.Itoa(int(h.Id)) + "\n" +
		";; flags:"
	for _, f := range flags {
		if f.set {
			s += " " + f.name
		}
	}
	return s + ";"
}
662
663// Pack packs a Msg: it is converted to to wire format.
664// If the dns.Compress is true the message will be in compressed wire format.
665func (dns *Msg) Pack() (msg []byte, err error) {
666	return dns.PackBuffer(nil)
667}
668
669// PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
670func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
671	var compression map[string]int
672	if dns.Compress {
673		compression = make(map[string]int) // Compression pointer mappings.
674	}
675	return dns.packBufferWithCompressionMap(buf, compression)
676}
677
678// packBufferWithCompressionMap packs a Msg, using the given buffer buf.
679func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression map[string]int) (msg []byte, err error) {
680	// We use a similar function in tsig.go's stripTsig.
681
682	var dh Header
683
684	if dns.Rcode < 0 || dns.Rcode > 0xFFF {
685		return nil, ErrRcode
686	}
687	if dns.Rcode > 0xF {
688		// Regular RCODE field is 4 bits
689		opt := dns.IsEdns0()
690		if opt == nil {
691			return nil, ErrExtendedRcode
692		}
693		opt.SetExtendedRcode(uint8(dns.Rcode >> 4))
694	}
695
696	// Convert convenient Msg into wire-like Header.
697	dh.Id = dns.Id
698	dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
699	if dns.Response {
700		dh.Bits |= _QR
701	}
702	if dns.Authoritative {
703		dh.Bits |= _AA
704	}
705	if dns.Truncated {
706		dh.Bits |= _TC
707	}
708	if dns.RecursionDesired {
709		dh.Bits |= _RD
710	}
711	if dns.RecursionAvailable {
712		dh.Bits |= _RA
713	}
714	if dns.Zero {
715		dh.Bits |= _Z
716	}
717	if dns.AuthenticatedData {
718		dh.Bits |= _AD
719	}
720	if dns.CheckingDisabled {
721		dh.Bits |= _CD
722	}
723
724	// Prepare variable sized arrays.
725	question := dns.Question
726	answer := dns.Answer
727	ns := dns.Ns
728	extra := dns.Extra
729
730	dh.Qdcount = uint16(len(question))
731	dh.Ancount = uint16(len(answer))
732	dh.Nscount = uint16(len(ns))
733	dh.Arcount = uint16(len(extra))
734
735	// We need the uncompressed length here, because we first pack it and then compress it.
736	msg = buf
737	uncompressedLen := compressedLen(dns, false)
738	if packLen := uncompressedLen + 1; len(msg) < packLen {
739		msg = make([]byte, packLen)
740	}
741
742	// Pack it in: header and then the pieces.
743	off := 0
744	off, err = dh.pack(msg, off, compression, dns.Compress)
745	if err != nil {
746		return nil, err
747	}
748	for i := 0; i < len(question); i++ {
749		off, err = question[i].pack(msg, off, compression, dns.Compress)
750		if err != nil {
751			return nil, err
752		}
753	}
754	for i := 0; i < len(answer); i++ {
755		off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
756		if err != nil {
757			return nil, err
758		}
759	}
760	for i := 0; i < len(ns); i++ {
761		off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
762		if err != nil {
763			return nil, err
764		}
765	}
766	for i := 0; i < len(extra); i++ {
767		off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
768		if err != nil {
769			return nil, err
770		}
771	}
772	return msg[:off], nil
773}
774
775// Unpack unpacks a binary message to a Msg structure.
776func (dns *Msg) Unpack(msg []byte) (err error) {
777	var (
778		dh  Header
779		off int
780	)
781	if dh, off, err = unpackMsgHdr(msg, off); err != nil {
782		return err
783	}
784
785	dns.Id = dh.Id
786	dns.Response = dh.Bits&_QR != 0
787	dns.Opcode = int(dh.Bits>>11) & 0xF
788	dns.Authoritative = dh.Bits&_AA != 0
789	dns.Truncated = dh.Bits&_TC != 0
790	dns.RecursionDesired = dh.Bits&_RD != 0
791	dns.RecursionAvailable = dh.Bits&_RA != 0
792	dns.Zero = dh.Bits&_Z != 0
793	dns.AuthenticatedData = dh.Bits&_AD != 0
794	dns.CheckingDisabled = dh.Bits&_CD != 0
795	dns.Rcode = int(dh.Bits & 0xF)
796
797	// If we are at the end of the message we should return *just* the
798	// header. This can still be useful to the caller. 9.9.9.9 sends these
799	// when responding with REFUSED for instance.
800	if off == len(msg) {
801		// reset sections before returning
802		dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil
803		return nil
804	}
805
806	// Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are
807	// attacker controlled. This means we can't use them to pre-allocate
808	// slices.
809	dns.Question = nil
810	for i := 0; i < int(dh.Qdcount); i++ {
811		off1 := off
812		var q Question
813		q, off, err = unpackQuestion(msg, off)
814		if err != nil {
815			// Even if Truncated is set, we only will set ErrTruncated if we
816			// actually got the questions
817			return err
818		}
819		if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
820			dh.Qdcount = uint16(i)
821			break
822		}
823		dns.Question = append(dns.Question, q)
824	}
825
826	dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
827	// The header counts might have been wrong so we need to update it
828	dh.Ancount = uint16(len(dns.Answer))
829	if err == nil {
830		dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
831	}
832	// The header counts might have been wrong so we need to update it
833	dh.Nscount = uint16(len(dns.Ns))
834	if err == nil {
835		dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off)
836	}
837	// The header counts might have been wrong so we need to update it
838	dh.Arcount = uint16(len(dns.Extra))
839
840	if off != len(msg) {
841		// TODO(miek) make this an error?
842		// use PackOpt to let people tell how detailed the error reporting should be?
843		// println("dns: extra bytes in dns packet", off, "<", len(msg))
844	} else if dns.Truncated {
845		// Whether we ran into a an error or not, we want to return that it
846		// was truncated
847		err = ErrTruncated
848	}
849	return err
850}
851
// String converts a complete message to a string with dig-like output:
// the header, the section counts, and each non-empty section. Nil RRs
// are skipped.
func (dns *Msg) String() string {
	if dns == nil {
		return "<nil> MsgHdr"
	}

	s := dns.MsgHdr.String() + " " +
		"QUERY: " + strconv.Itoa(len(dns.Question)) + ", " +
		"ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", " +
		"AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", " +
		"ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"

	if len(dns.Question) > 0 {
		s += "\n;; QUESTION SECTION:\n"
		for i := range dns.Question {
			s += dns.Question[i].String() + "\n"
		}
	}

	section := func(title string, rrs []RR) {
		if len(rrs) == 0 {
			return
		}
		s += "\n;; " + title + " SECTION:\n"
		for _, rr := range rrs {
			if rr != nil {
				s += rr.String() + "\n"
			}
		}
	}
	section("ANSWER", dns.Answer)
	section("AUTHORITY", dns.Ns)
	section("ADDITIONAL", dns.Extra)

	return s
}
894
895// Len returns the message length when in (un)compressed wire format.
896// If dns.Compress is true compression it is taken into account. Len()
897// is provided to be a faster way to get the size of the resulting packet,
898// than packing it, measuring the size and discarding the buffer.
899func (dns *Msg) Len() int { return compressedLen(dns, dns.Compress) }
900
901func compressedLenWithCompressionMap(dns *Msg, compression map[string]int) int {
902	l := 12 // Message header is always 12 bytes
903	for _, r := range dns.Question {
904		compressionLenHelper(compression, r.Name, l)
905		l += r.len()
906	}
907	l += compressionLenSlice(l, compression, dns.Answer)
908	l += compressionLenSlice(l, compression, dns.Ns)
909	l += compressionLenSlice(l, compression, dns.Extra)
910	return l
911}
912
913// compressedLen returns the message length when in compressed wire format
914// when compress is true, otherwise the uncompressed length is returned.
915func compressedLen(dns *Msg, compress bool) int {
916	// We always return one more than needed.
917	if compress {
918		compression := map[string]int{}
919		return compressedLenWithCompressionMap(dns, compression)
920	}
921	l := 12 // Message header is always 12 bytes
922
923	for _, r := range dns.Question {
924		l += r.len()
925	}
926	for _, r := range dns.Answer {
927		if r != nil {
928			l += r.len()
929		}
930	}
931	for _, r := range dns.Ns {
932		if r != nil {
933			l += r.len()
934		}
935	}
936	for _, r := range dns.Extra {
937		if r != nil {
938			l += r.len()
939		}
940	}
941
942	return l
943}
944
945func compressionLenSlice(lenp int, c map[string]int, rs []RR) int {
946	initLen := lenp
947	for _, r := range rs {
948		if r == nil {
949			continue
950		}
951		// TmpLen is to track len of record at 14bits boudaries
952		tmpLen := lenp
953
954		x := r.len()
955		// track this length, and the global length in len, while taking compression into account for both.
956		k, ok, _ := compressionLenSearch(c, r.Header().Name)
957		if ok {
958			// Size of x is reduced by k, but we add 1 since k includes the '.' and label descriptor take 2 bytes
959			// so, basically x:= x - k - 1 + 2
960			x += 1 - k
961		}
962
963		tmpLen += compressionLenHelper(c, r.Header().Name, tmpLen)
964		k, ok, _ = compressionLenSearchType(c, r)
965		if ok {
966			x += 1 - k
967		}
968		lenp += x
969		tmpLen = lenp
970		tmpLen += compressionLenHelperType(c, r, tmpLen)
971
972	}
973	return lenp - initLen
974}
975
976// Put the parts of the name in the compression map, return the size in bytes added in payload
977func compressionLenHelper(c map[string]int, s string, currentLen int) int {
978	if currentLen > maxCompressionOffset {
979		// We won't be able to add any label that could be re-used later anyway
980		return 0
981	}
982	if _, ok := c[s]; ok {
983		return 0
984	}
985	initLen := currentLen
986	pref := ""
987	prev := s
988	lbs := Split(s)
989	for j := 0; j < len(lbs); j++ {
990		pref = s[lbs[j]:]
991		currentLen += len(prev) - len(pref)
992		prev = pref
993		if _, ok := c[pref]; !ok {
994			// If first byte label is within the first 14bits, it might be re-used later
995			if currentLen < maxCompressionOffset {
996				c[pref] = currentLen
997			}
998		} else {
999			added := currentLen - initLen
1000			if j > 0 {
1001				// We added a new PTR
1002				added += 2
1003			}
1004			return added
1005		}
1006	}
1007	return currentLen - initLen
1008}
1009
1010// Look for each part in the compression map and returns its length,
1011// keep on searching so we get the longest match.
1012// Will return the size of compression found, whether a match has been
1013// found and the size of record if added in payload
1014func compressionLenSearch(c map[string]int, s string) (int, bool, int) {
1015	off := 0
1016	end := false
1017	if s == "" { // don't bork on bogus data
1018		return 0, false, 0
1019	}
1020	fullSize := 0
1021	for {
1022		if _, ok := c[s[off:]]; ok {
1023			return len(s[off:]), true, fullSize + off
1024		}
1025		if end {
1026			break
1027		}
1028		// Each label descriptor takes 2 bytes, add it
1029		fullSize += 2
1030		off, end = NextLabel(s, off)
1031	}
1032	return 0, false, fullSize + len(s)
1033}
1034
1035// Copy returns a new RR which is a deep-copy of r.
1036func Copy(r RR) RR { r1 := r.copy(); return r1 }
1037
1038// Len returns the length (in octets) of the uncompressed RR in wire format.
1039func Len(r RR) int { return r.len() }
1040
1041// Copy returns a new *Msg which is a deep-copy of dns.
1042func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
1043
1044// CopyTo copies the contents to the provided message using a deep-copy and returns the copy.
1045func (dns *Msg) CopyTo(r1 *Msg) *Msg {
1046	r1.MsgHdr = dns.MsgHdr
1047	r1.Compress = dns.Compress
1048
1049	if len(dns.Question) > 0 {
1050		r1.Question = make([]Question, len(dns.Question))
1051		copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy
1052	}
1053
1054	rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
1055	var rri int
1056
1057	if len(dns.Answer) > 0 {
1058		rrbegin := rri
1059		for i := 0; i < len(dns.Answer); i++ {
1060			rrArr[rri] = dns.Answer[i].copy()
1061			rri++
1062		}
1063		r1.Answer = rrArr[rrbegin:rri:rri]
1064	}
1065
1066	if len(dns.Ns) > 0 {
1067		rrbegin := rri
1068		for i := 0; i < len(dns.Ns); i++ {
1069			rrArr[rri] = dns.Ns[i].copy()
1070			rri++
1071		}
1072		r1.Ns = rrArr[rrbegin:rri:rri]
1073	}
1074
1075	if len(dns.Extra) > 0 {
1076		rrbegin := rri
1077		for i := 0; i < len(dns.Extra); i++ {
1078			rrArr[rri] = dns.Extra[i].copy()
1079			rri++
1080		}
1081		r1.Extra = rrArr[rrbegin:rri:rri]
1082	}
1083
1084	return r1
1085}
1086
1087func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
1088	off, err := PackDomainName(q.Name, msg, off, compression, compress)
1089	if err != nil {
1090		return off, err
1091	}
1092	off, err = packUint16(q.Qtype, msg, off)
1093	if err != nil {
1094		return off, err
1095	}
1096	off, err = packUint16(q.Qclass, msg, off)
1097	if err != nil {
1098		return off, err
1099	}
1100	return off, nil
1101}
1102
1103func unpackQuestion(msg []byte, off int) (Question, int, error) {
1104	var (
1105		q   Question
1106		err error
1107	)
1108	q.Name, off, err = UnpackDomainName(msg, off)
1109	if err != nil {
1110		return q, off, err
1111	}
1112	if off == len(msg) {
1113		return q, off, nil
1114	}
1115	q.Qtype, off, err = unpackUint16(msg, off)
1116	if err != nil {
1117		return q, off, err
1118	}
1119	if off == len(msg) {
1120		return q, off, nil
1121	}
1122	q.Qclass, off, err = unpackUint16(msg, off)
1123	if off == len(msg) {
1124		return q, off, nil
1125	}
1126	return q, off, err
1127}
1128
1129func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
1130	off, err := packUint16(dh.Id, msg, off)
1131	if err != nil {
1132		return off, err
1133	}
1134	off, err = packUint16(dh.Bits, msg, off)
1135	if err != nil {
1136		return off, err
1137	}
1138	off, err = packUint16(dh.Qdcount, msg, off)
1139	if err != nil {
1140		return off, err
1141	}
1142	off, err = packUint16(dh.Ancount, msg, off)
1143	if err != nil {
1144		return off, err
1145	}
1146	off, err = packUint16(dh.Nscount, msg, off)
1147	if err != nil {
1148		return off, err
1149	}
1150	off, err = packUint16(dh.Arcount, msg, off)
1151	return off, err
1152}
1153
1154func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
1155	var (
1156		dh  Header
1157		err error
1158	)
1159	dh.Id, off, err = unpackUint16(msg, off)
1160	if err != nil {
1161		return dh, off, err
1162	}
1163	dh.Bits, off, err = unpackUint16(msg, off)
1164	if err != nil {
1165		return dh, off, err
1166	}
1167	dh.Qdcount, off, err = unpackUint16(msg, off)
1168	if err != nil {
1169		return dh, off, err
1170	}
1171	dh.Ancount, off, err = unpackUint16(msg, off)
1172	if err != nil {
1173		return dh, off, err
1174	}
1175	dh.Nscount, off, err = unpackUint16(msg, off)
1176	if err != nil {
1177		return dh, off, err
1178	}
1179	dh.Arcount, off, err = unpackUint16(msg, off)
1180	return dh, off, err
1181}
1182