// DNS packet assembly, see RFC 1035. Converting from - Unpack() -
// and to - Pack() - wire format.
// All the packers and unpackers take a (msg []byte, off int)
// and return (off1 int, err error). Most of them return off1 == len(msg)
// together with a non-nil error, so that the next unpacker will also fail.
// This lets us avoid checking err until the end of a packing sequence.
8
9package dns
10
11//go:generate go run msg_generate.go
12
13import (
14	crand "crypto/rand"
15	"encoding/binary"
16	"math/big"
17	"math/rand"
18	"strconv"
19)
20
// init seeds the default math/rand source from crypto/rand, so that
// message ids are hard to predict without paying for crypto/rand on every
// call. If the cryptographic source cannot be read, math/rand keeps its
// default seed.
func init() {
	var seed [8]byte
	if _, err := crand.Read(seed[:]); err != nil {
		// No cryptographic randomness available: keep the default seed.
		return
	}
	rand.Seed(int64(binary.BigEndian.Uint64(seed[:])))
}
34
// maxCompressionOffset bounds the offsets usable as compression targets: a
// compression pointer carries a 14-bit offset, so only offsets below 1<<14
// can be pointed at.
const maxCompressionOffset = 1 << 14
36
37var (
38	ErrAlg           error = &Error{err: "bad algorithm"}                  // ErrAlg indicates an error with the (DNSSEC) algorithm.
39	ErrAuth          error = &Error{err: "bad authentication"}             // ErrAuth indicates an error in the TSIG authentication.
40	ErrBuf           error = &Error{err: "buffer size too small"}          // ErrBuf indicates that the buffer used it too small for the message.
41	ErrConnEmpty     error = &Error{err: "conn has no connection"}         // ErrConnEmpty indicates a connection is being uses before it is initialized.
42	ErrExtendedRcode error = &Error{err: "bad extended rcode"}             // ErrExtendedRcode ...
43	ErrFqdn          error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot.
44	ErrId            error = &Error{err: "id mismatch"}                    // ErrId indicates there is a mismatch with the message's ID.
45	ErrKeyAlg        error = &Error{err: "bad key algorithm"}              // ErrKeyAlg indicates that the algorithm in the key is not valid.
46	ErrKey           error = &Error{err: "bad key"}
47	ErrKeySize       error = &Error{err: "bad key size"}
48	ErrNoSig         error = &Error{err: "no signature found"}
49	ErrPrivKey       error = &Error{err: "bad private key"}
50	ErrRcode         error = &Error{err: "bad rcode"}
51	ErrRdata         error = &Error{err: "bad rdata"}
52	ErrRRset         error = &Error{err: "bad rrset"}
53	ErrSecret        error = &Error{err: "no secrets defined"}
54	ErrShortRead     error = &Error{err: "short read"}
55	ErrSig           error = &Error{err: "bad signature"}                      // ErrSig indicates that a signature can not be cryptographically validated.
56	ErrSoa           error = &Error{err: "no SOA"}                             // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
57	ErrTime          error = &Error{err: "bad time"}                           // ErrTime indicates a timing error in TSIG authentication.
58	ErrTruncated     error = &Error{err: "failed to unpack truncated message"} // ErrTruncated indicates that we failed to unpack a truncated message. We unpacked as much as we had so Msg can still be used, if desired.
59)
60
// Id by default returns a 16-bit random number to be used as a message id.
// The numbers come from math/rand (seeded from crypto/rand when this package
// is initialized), which should be good enough for message ids. Because Id is
// a variable, it can be reassigned to a custom function. For instance, to
// make it return a static value:
//
//	dns.Id = func() uint16 { return 3 }
var Id func() uint16 = id

// id is the default implementation behind Id: the low 16 bits of a 32-bit
// math/rand value.
func id() uint16 {
	r := rand.Uint32()
	return uint16(r & 0xFFFF)
}
75
// MsgHdr is a manually-unpacked version of (id, bits): the message ID plus
// the header flags and codes, which packing and unpacking convert to and
// from the 16-bit flags word of the wire header.
type MsgHdr struct {
	Id                 uint16 // Message ID.
	Response           bool   // QR flag.
	Opcode             int    // OPCODE, stored in bits 11-14 of the flags word.
	Authoritative      bool   // AA flag.
	Truncated          bool   // TC flag.
	RecursionDesired   bool   // RD flag.
	RecursionAvailable bool   // RA flag.
	Zero               bool   // Z flag.
	AuthenticatedData  bool   // AD flag.
	CheckingDisabled   bool   // CD flag.
	Rcode              int    // RCODE; values above 0xF need an OPT record when packing.
}
90
91// Msg contains the layout of a DNS message.
92type Msg struct {
93	MsgHdr
94	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format.
95	Question []Question // Holds the RR(s) of the question section.
96	Answer   []RR       // Holds the RR(s) of the answer section.
97	Ns       []RR       // Holds the RR(s) of the authority section.
98	Extra    []RR       // Holds the RR(s) of the additional section.
99}
100
101// ClassToString is a maps Classes to strings for each CLASS wire type.
102var ClassToString = map[uint16]string{
103	ClassINET:   "IN",
104	ClassCSNET:  "CS",
105	ClassCHAOS:  "CH",
106	ClassHESIOD: "HS",
107	ClassNONE:   "NONE",
108	ClassANY:    "ANY",
109}
110
111// OpcodeToString maps Opcodes to strings.
112var OpcodeToString = map[int]string{
113	OpcodeQuery:  "QUERY",
114	OpcodeIQuery: "IQUERY",
115	OpcodeStatus: "STATUS",
116	OpcodeNotify: "NOTIFY",
117	OpcodeUpdate: "UPDATE",
118}
119
120// RcodeToString maps Rcodes to strings.
121var RcodeToString = map[int]string{
122	RcodeSuccess:        "NOERROR",
123	RcodeFormatError:    "FORMERR",
124	RcodeServerFailure:  "SERVFAIL",
125	RcodeNameError:      "NXDOMAIN",
126	RcodeNotImplemented: "NOTIMPL",
127	RcodeRefused:        "REFUSED",
128	RcodeYXDomain:       "YXDOMAIN", // See RFC 2136
129	RcodeYXRrset:        "YXRRSET",
130	RcodeNXRrset:        "NXRRSET",
131	RcodeNotAuth:        "NOTAUTH",
132	RcodeNotZone:        "NOTZONE",
133	RcodeBadSig:         "BADSIG", // Also known as RcodeBadVers, see RFC 6891
134	//	RcodeBadVers:        "BADVERS",
135	RcodeBadKey:    "BADKEY",
136	RcodeBadTime:   "BADTIME",
137	RcodeBadMode:   "BADMODE",
138	RcodeBadName:   "BADNAME",
139	RcodeBadAlg:    "BADALG",
140	RcodeBadTrunc:  "BADTRUNC",
141	RcodeBadCookie: "BADCOOKIE",
142}
143
144// Domain names are a sequence of counted strings
145// split at the dots. They end with a zero-length string.
146
147// PackDomainName packs a domain name s into msg[off:].
148// If compression is wanted compress must be true and the compression
149// map needs to hold a mapping between domain names and offsets
150// pointing into msg.
151func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
152	off1, _, err = packDomainName(s, msg, off, compression, compress)
153	return
154}
155
156func packDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, labels int, err error) {
157	// special case if msg == nil
158	lenmsg := 256
159	if msg != nil {
160		lenmsg = len(msg)
161	}
162	ls := len(s)
163	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
164		return off, 0, nil
165	}
166	// If not fully qualified, error out, but only if msg == nil #ugly
167	switch {
168	case msg == nil:
169		if s[ls-1] != '.' {
170			s += "."
171			ls++
172		}
173	case msg != nil:
174		if s[ls-1] != '.' {
175			return lenmsg, 0, ErrFqdn
176		}
177	}
178	// Each dot ends a segment of the name.
179	// We trade each dot byte for a length byte.
180	// Except for escaped dots (\.), which are normal dots.
181	// There is also a trailing zero.
182
183	// Compression
184	nameoffset := -1
185	pointer := -1
186	// Emit sequence of counted strings, chopping at dots.
187	begin := 0
188	bs := []byte(s)
189	roBs, bsFresh, escapedDot := s, true, false
190	for i := 0; i < ls; i++ {
191		if bs[i] == '\\' {
192			for j := i; j < ls-1; j++ {
193				bs[j] = bs[j+1]
194			}
195			ls--
196			if off+1 > lenmsg {
197				return lenmsg, labels, ErrBuf
198			}
199			// check for \DDD
200			if i+2 < ls && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
201				bs[i] = dddToByte(bs[i:])
202				for j := i + 1; j < ls-2; j++ {
203					bs[j] = bs[j+2]
204				}
205				ls -= 2
206			} else if bs[i] == 't' {
207				bs[i] = '\t'
208			} else if bs[i] == 'r' {
209				bs[i] = '\r'
210			} else if bs[i] == 'n' {
211				bs[i] = '\n'
212			}
213			escapedDot = bs[i] == '.'
214			bsFresh = false
215			continue
216		}
217
218		if bs[i] == '.' {
219			if i > 0 && bs[i-1] == '.' && !escapedDot {
220				// two dots back to back is not legal
221				return lenmsg, labels, ErrRdata
222			}
223			if i-begin >= 1<<6 { // top two bits of length must be clear
224				return lenmsg, labels, ErrRdata
225			}
226			// off can already (we're in a loop) be bigger than len(msg)
227			// this happens when a name isn't fully qualified
228			if off+1 > lenmsg {
229				return lenmsg, labels, ErrBuf
230			}
231			if msg != nil {
232				msg[off] = byte(i - begin)
233			}
234			offset := off
235			off++
236			for j := begin; j < i; j++ {
237				if off+1 > lenmsg {
238					return lenmsg, labels, ErrBuf
239				}
240				if msg != nil {
241					msg[off] = bs[j]
242				}
243				off++
244			}
245			if compress && !bsFresh {
246				roBs = string(bs)
247				bsFresh = true
248			}
249			// Don't try to compress '.'
250			if compress && roBs[begin:] != "." {
251				if p, ok := compression[roBs[begin:]]; !ok {
252					// Only offsets smaller than this can be used.
253					if offset < maxCompressionOffset {
254						compression[roBs[begin:]] = offset
255					}
256				} else {
257					// The first hit is the longest matching dname
258					// keep the pointer offset we get back and store
259					// the offset of the current name, because that's
260					// where we need to insert the pointer later
261
262					// If compress is true, we're allowed to compress this dname
263					if pointer == -1 && compress {
264						pointer = p         // Where to point to
265						nameoffset = offset // Where to point from
266						break
267					}
268				}
269			}
270			labels++
271			begin = i + 1
272		}
273		escapedDot = false
274	}
275	// Root label is special
276	if len(bs) == 1 && bs[0] == '.' {
277		return off, labels, nil
278	}
279	// If we did compression and we find something add the pointer here
280	if pointer != -1 {
281		// We have two bytes (14 bits) to put the pointer in
282		// if msg == nil, we will never do compression
283		binary.BigEndian.PutUint16(msg[nameoffset:], uint16(pointer^0xC000))
284		off = nameoffset + 1
285		goto End
286	}
287	if msg != nil && off < len(msg) {
288		msg[off] = 0
289	}
290End:
291	off++
292	return off, labels, nil
293}
294
295// Unpack a domain name.
296// In addition to the simple sequences of counted strings above,
297// domain names are allowed to refer to strings elsewhere in the
298// packet, to avoid repeating common suffixes when returning
299// many entries in a single domain.  The pointers are marked
300// by a length byte with the top two bits set.  Ignoring those
301// two bits, that byte and the next give a 14 bit offset from msg[0]
302// where we should pick up the trail.
303// Note that if we jump elsewhere in the packet,
304// we return off1 == the offset after the first pointer we found,
305// which is where the next record will start.
306// In theory, the pointers are only allowed to jump backward.
307// We let them jump anywhere and stop jumping after a while.
308
309// UnpackDomainName unpacks a domain name into a string.
310func UnpackDomainName(msg []byte, off int) (string, int, error) {
311	s := make([]byte, 0, 64)
312	off1 := 0
313	lenmsg := len(msg)
314	ptr := 0 // number of pointers followed
315Loop:
316	for {
317		if off >= lenmsg {
318			return "", lenmsg, ErrBuf
319		}
320		c := int(msg[off])
321		off++
322		switch c & 0xC0 {
323		case 0x00:
324			if c == 0x00 {
325				// end of name
326				break Loop
327			}
328			// literal string
329			if off+c > lenmsg {
330				return "", lenmsg, ErrBuf
331			}
332			for j := off; j < off+c; j++ {
333				switch b := msg[j]; b {
334				case '.', '(', ')', ';', ' ', '@':
335					fallthrough
336				case '"', '\\':
337					s = append(s, '\\', b)
338				case '\t':
339					s = append(s, '\\', 't')
340				case '\r':
341					s = append(s, '\\', 'r')
342				default:
343					if b < 32 || b >= 127 { // unprintable use \DDD
344						var buf [3]byte
345						bufs := strconv.AppendInt(buf[:0], int64(b), 10)
346						s = append(s, '\\')
347						for i := 0; i < 3-len(bufs); i++ {
348							s = append(s, '0')
349						}
350						for _, r := range bufs {
351							s = append(s, r)
352						}
353					} else {
354						s = append(s, b)
355					}
356				}
357			}
358			s = append(s, '.')
359			off += c
360		case 0xC0:
361			// pointer to somewhere else in msg.
362			// remember location after first ptr,
363			// since that's how many bytes we consumed.
364			// also, don't follow too many pointers --
365			// maybe there's a loop.
366			if off >= lenmsg {
367				return "", lenmsg, ErrBuf
368			}
369			c1 := msg[off]
370			off++
371			if ptr == 0 {
372				off1 = off
373			}
374			if ptr++; ptr > 10 {
375				return "", lenmsg, &Error{err: "too many compression pointers"}
376			}
377			off = (c^0xC0)<<8 | int(c1)
378		default:
379			// 0x80 and 0x40 are reserved
380			return "", lenmsg, ErrRdata
381		}
382	}
383	if ptr == 0 {
384		off1 = off
385	}
386	if len(s) == 0 {
387		s = []byte(".")
388	}
389	return string(s), off1, nil
390}
391
392func packTxt(txt []string, msg []byte, offset int, tmp []byte) (int, error) {
393	if len(txt) == 0 {
394		if offset >= len(msg) {
395			return offset, ErrBuf
396		}
397		msg[offset] = 0
398		return offset, nil
399	}
400	var err error
401	for i := range txt {
402		if len(txt[i]) > len(tmp) {
403			return offset, ErrBuf
404		}
405		offset, err = packTxtString(txt[i], msg, offset, tmp)
406		if err != nil {
407			return offset, err
408		}
409	}
410	return offset, nil
411}
412
413func packTxtString(s string, msg []byte, offset int, tmp []byte) (int, error) {
414	lenByteOffset := offset
415	if offset >= len(msg) || len(s) > len(tmp) {
416		return offset, ErrBuf
417	}
418	offset++
419	bs := tmp[:len(s)]
420	copy(bs, s)
421	for i := 0; i < len(bs); i++ {
422		if len(msg) <= offset {
423			return offset, ErrBuf
424		}
425		if bs[i] == '\\' {
426			i++
427			if i == len(bs) {
428				break
429			}
430			// check for \DDD
431			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
432				msg[offset] = dddToByte(bs[i:])
433				i += 2
434			} else if bs[i] == 't' {
435				msg[offset] = '\t'
436			} else if bs[i] == 'r' {
437				msg[offset] = '\r'
438			} else if bs[i] == 'n' {
439				msg[offset] = '\n'
440			} else {
441				msg[offset] = bs[i]
442			}
443		} else {
444			msg[offset] = bs[i]
445		}
446		offset++
447	}
448	l := offset - lenByteOffset - 1
449	if l > 255 {
450		return offset, &Error{err: "string exceeded 255 bytes in txt"}
451	}
452	msg[lenByteOffset] = byte(l)
453	return offset, nil
454}
455
456func packOctetString(s string, msg []byte, offset int, tmp []byte) (int, error) {
457	if offset >= len(msg) || len(s) > len(tmp) {
458		return offset, ErrBuf
459	}
460	bs := tmp[:len(s)]
461	copy(bs, s)
462	for i := 0; i < len(bs); i++ {
463		if len(msg) <= offset {
464			return offset, ErrBuf
465		}
466		if bs[i] == '\\' {
467			i++
468			if i == len(bs) {
469				break
470			}
471			// check for \DDD
472			if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) {
473				msg[offset] = dddToByte(bs[i:])
474				i += 2
475			} else {
476				msg[offset] = bs[i]
477			}
478		} else {
479			msg[offset] = bs[i]
480		}
481		offset++
482	}
483	return offset, nil
484}
485
486func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
487	off = off0
488	var s string
489	for off < len(msg) && err == nil {
490		s, off, err = unpackTxtString(msg, off)
491		if err == nil {
492			ss = append(ss, s)
493		}
494	}
495	return
496}
497
498func unpackTxtString(msg []byte, offset int) (string, int, error) {
499	if offset+1 > len(msg) {
500		return "", offset, &Error{err: "overflow unpacking txt"}
501	}
502	l := int(msg[offset])
503	if offset+l+1 > len(msg) {
504		return "", offset, &Error{err: "overflow unpacking txt"}
505	}
506	s := make([]byte, 0, l)
507	for _, b := range msg[offset+1 : offset+1+l] {
508		switch b {
509		case '"', '\\':
510			s = append(s, '\\', b)
511		case '\t':
512			s = append(s, `\t`...)
513		case '\r':
514			s = append(s, `\r`...)
515		case '\n':
516			s = append(s, `\n`...)
517		default:
518			if b < 32 || b > 127 { // unprintable
519				var buf [3]byte
520				bufs := strconv.AppendInt(buf[:0], int64(b), 10)
521				s = append(s, '\\')
522				for i := 0; i < 3-len(bufs); i++ {
523					s = append(s, '0')
524				}
525				for _, r := range bufs {
526					s = append(s, r)
527				}
528			} else {
529				s = append(s, b)
530			}
531		}
532	}
533	offset += 1 + l
534	return string(s), offset, nil
535}
536
// Helpers for dealing with escaped bytes.

// isDigit reports whether b is an ASCII decimal digit.
func isDigit(b byte) bool { return '0' <= b && b <= '9' }

// dddToByte converts the three ASCII digits at the start of s (the DDD of a
// \DDD escape) to a byte. The value is reduced modulo 256, so out-of-range
// input such as "999" wraps instead of failing.
func dddToByte(s []byte) byte {
	v := 0
	for i := 0; i < 3; i++ {
		v = v*10 + int(s[i]-'0')
	}
	return byte(v)
}
543
// intToBytes returns the big-endian bytes of i, left-padded with zeros to
// length bytes. If i needs length bytes or more, its minimal encoding is
// returned unchanged (never truncated).
func intToBytes(i *big.Int, length int) []byte {
	raw := i.Bytes()
	pad := length - len(raw)
	if pad <= 0 {
		return raw
	}
	out := make([]byte, length)
	copy(out[pad:], raw)
	return out
}
554
555// PackRR packs a resource record rr into msg[off:].
556// See PackDomainName for documentation about the compression.
557func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
558	if rr == nil {
559		return len(msg), &Error{err: "nil rr"}
560	}
561
562	off1, err = rr.pack(msg, off, compression, compress)
563	if err != nil {
564		return len(msg), err
565	}
566	// TODO(miek): Not sure if this is needed? If removed we can remove rawmsg.go as well.
567	if rawSetRdlength(msg, off, off1) {
568		return off1, nil
569	}
570	return off, ErrRdata
571}
572
573// UnpackRR unpacks msg[off:] into an RR.
574func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
575	h, off, msg, err := unpackHeader(msg, off)
576	if err != nil {
577		return nil, len(msg), err
578	}
579	end := off + int(h.Rdlength)
580
581	if fn, known := typeToUnpack[h.Rrtype]; !known {
582		rr, off, err = unpackRFC3597(h, msg, off)
583	} else {
584		rr, off, err = fn(h, msg, off)
585	}
586	if off != end {
587		return &h, end, &Error{err: "bad rdlength"}
588	}
589	return rr, off, err
590}
591
592// unpackRRslice unpacks msg[off:] into an []RR.
593// If we cannot unpack the whole array, then it will return nil
594func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
595	var r RR
596	// Optimistically make dst be the length that was sent
597	dst := make([]RR, 0, l)
598	for i := 0; i < l; i++ {
599		off1 := off
600		r, off, err = UnpackRR(msg, off)
601		if err != nil {
602			off = len(msg)
603			break
604		}
605		// If offset does not increase anymore, l is a lie
606		if off1 == off {
607			l = i
608			break
609		}
610		dst = append(dst, r)
611	}
612	if err != nil && off == len(msg) {
613		dst = nil
614	}
615	return dst, off, err
616}
617
618// Convert a MsgHdr to a string, with dig-like headers:
619//
620//;; opcode: QUERY, status: NOERROR, id: 48404
621//
622//;; flags: qr aa rd ra;
623func (h *MsgHdr) String() string {
624	if h == nil {
625		return "<nil> MsgHdr"
626	}
627
628	s := ";; opcode: " + OpcodeToString[h.Opcode]
629	s += ", status: " + RcodeToString[h.Rcode]
630	s += ", id: " + strconv.Itoa(int(h.Id)) + "\n"
631
632	s += ";; flags:"
633	if h.Response {
634		s += " qr"
635	}
636	if h.Authoritative {
637		s += " aa"
638	}
639	if h.Truncated {
640		s += " tc"
641	}
642	if h.RecursionDesired {
643		s += " rd"
644	}
645	if h.RecursionAvailable {
646		s += " ra"
647	}
648	if h.Zero { // Hmm
649		s += " z"
650	}
651	if h.AuthenticatedData {
652		s += " ad"
653	}
654	if h.CheckingDisabled {
655		s += " cd"
656	}
657
658	s += ";"
659	return s
660}
661
662// Pack packs a Msg: it is converted to to wire format.
663// If the dns.Compress is true the message will be in compressed wire format.
664func (dns *Msg) Pack() (msg []byte, err error) {
665	return dns.PackBuffer(nil)
666}
667
668// PackBuffer packs a Msg, using the given buffer buf. If buf is too small
669// a new buffer is allocated.
670func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
671	// We use a similar function in tsig.go's stripTsig.
672	var (
673		dh          Header
674		compression map[string]int
675	)
676
677	if dns.Compress {
678		compression = make(map[string]int) // Compression pointer mappings
679	}
680
681	if dns.Rcode < 0 || dns.Rcode > 0xFFF {
682		return nil, ErrRcode
683	}
684	if dns.Rcode > 0xF {
685		// Regular RCODE field is 4 bits
686		opt := dns.IsEdns0()
687		if opt == nil {
688			return nil, ErrExtendedRcode
689		}
690		opt.SetExtendedRcode(uint8(dns.Rcode >> 4))
691		dns.Rcode &= 0xF
692	}
693
694	// Convert convenient Msg into wire-like Header.
695	dh.Id = dns.Id
696	dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode)
697	if dns.Response {
698		dh.Bits |= _QR
699	}
700	if dns.Authoritative {
701		dh.Bits |= _AA
702	}
703	if dns.Truncated {
704		dh.Bits |= _TC
705	}
706	if dns.RecursionDesired {
707		dh.Bits |= _RD
708	}
709	if dns.RecursionAvailable {
710		dh.Bits |= _RA
711	}
712	if dns.Zero {
713		dh.Bits |= _Z
714	}
715	if dns.AuthenticatedData {
716		dh.Bits |= _AD
717	}
718	if dns.CheckingDisabled {
719		dh.Bits |= _CD
720	}
721
722	// Prepare variable sized arrays.
723	question := dns.Question
724	answer := dns.Answer
725	ns := dns.Ns
726	extra := dns.Extra
727
728	dh.Qdcount = uint16(len(question))
729	dh.Ancount = uint16(len(answer))
730	dh.Nscount = uint16(len(ns))
731	dh.Arcount = uint16(len(extra))
732
733	// We need the uncompressed length here, because we first pack it and then compress it.
734	msg = buf
735	compress := dns.Compress
736	dns.Compress = false
737	if packLen := dns.Len() + 1; len(msg) < packLen {
738		msg = make([]byte, packLen)
739	}
740	dns.Compress = compress
741
742	// Pack it in: header and then the pieces.
743	off := 0
744	off, err = dh.pack(msg, off, compression, dns.Compress)
745	if err != nil {
746		return nil, err
747	}
748	for i := 0; i < len(question); i++ {
749		off, err = question[i].pack(msg, off, compression, dns.Compress)
750		if err != nil {
751			return nil, err
752		}
753	}
754	for i := 0; i < len(answer); i++ {
755		off, err = PackRR(answer[i], msg, off, compression, dns.Compress)
756		if err != nil {
757			return nil, err
758		}
759	}
760	for i := 0; i < len(ns); i++ {
761		off, err = PackRR(ns[i], msg, off, compression, dns.Compress)
762		if err != nil {
763			return nil, err
764		}
765	}
766	for i := 0; i < len(extra); i++ {
767		off, err = PackRR(extra[i], msg, off, compression, dns.Compress)
768		if err != nil {
769			return nil, err
770		}
771	}
772	return msg[:off], nil
773}
774
775// Unpack unpacks a binary message to a Msg structure.
776func (dns *Msg) Unpack(msg []byte) (err error) {
777	var (
778		dh  Header
779		off int
780	)
781	if dh, off, err = unpackMsgHdr(msg, off); err != nil {
782		return err
783	}
784	if off == len(msg) {
785		return ErrTruncated
786	}
787
788	dns.Id = dh.Id
789	dns.Response = (dh.Bits & _QR) != 0
790	dns.Opcode = int(dh.Bits>>11) & 0xF
791	dns.Authoritative = (dh.Bits & _AA) != 0
792	dns.Truncated = (dh.Bits & _TC) != 0
793	dns.RecursionDesired = (dh.Bits & _RD) != 0
794	dns.RecursionAvailable = (dh.Bits & _RA) != 0
795	dns.Zero = (dh.Bits & _Z) != 0
796	dns.AuthenticatedData = (dh.Bits & _AD) != 0
797	dns.CheckingDisabled = (dh.Bits & _CD) != 0
798	dns.Rcode = int(dh.Bits & 0xF)
799
800	// Optimistically use the count given to us in the header
801	dns.Question = make([]Question, 0, int(dh.Qdcount))
802
803	for i := 0; i < int(dh.Qdcount); i++ {
804		off1 := off
805		var q Question
806		q, off, err = unpackQuestion(msg, off)
807		if err != nil {
808			// Even if Truncated is set, we only will set ErrTruncated if we
809			// actually got the questions
810			return err
811		}
812		if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
813			dh.Qdcount = uint16(i)
814			break
815		}
816		dns.Question = append(dns.Question, q)
817	}
818
819	dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
820	// The header counts might have been wrong so we need to update it
821	dh.Ancount = uint16(len(dns.Answer))
822	if err == nil {
823		dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
824	}
825	// The header counts might have been wrong so we need to update it
826	dh.Nscount = uint16(len(dns.Ns))
827	if err == nil {
828		dns.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off)
829	}
830	// The header counts might have been wrong so we need to update it
831	dh.Arcount = uint16(len(dns.Extra))
832
833	if off != len(msg) {
834		// TODO(miek) make this an error?
835		// use PackOpt to let people tell how detailed the error reporting should be?
836		// println("dns: extra bytes in dns packet", off, "<", len(msg))
837	} else if dns.Truncated {
838		// Whether we ran into a an error or not, we want to return that it
839		// was truncated
840		err = ErrTruncated
841	}
842	return err
843}
844
845// Convert a complete message to a string with dig-like output.
846func (dns *Msg) String() string {
847	if dns == nil {
848		return "<nil> MsgHdr"
849	}
850	s := dns.MsgHdr.String() + " "
851	s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", "
852	s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
853	s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
854	s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
855	if len(dns.Question) > 0 {
856		s += "\n;; QUESTION SECTION:\n"
857		for i := 0; i < len(dns.Question); i++ {
858			s += dns.Question[i].String() + "\n"
859		}
860	}
861	if len(dns.Answer) > 0 {
862		s += "\n;; ANSWER SECTION:\n"
863		for i := 0; i < len(dns.Answer); i++ {
864			if dns.Answer[i] != nil {
865				s += dns.Answer[i].String() + "\n"
866			}
867		}
868	}
869	if len(dns.Ns) > 0 {
870		s += "\n;; AUTHORITY SECTION:\n"
871		for i := 0; i < len(dns.Ns); i++ {
872			if dns.Ns[i] != nil {
873				s += dns.Ns[i].String() + "\n"
874			}
875		}
876	}
877	if len(dns.Extra) > 0 {
878		s += "\n;; ADDITIONAL SECTION:\n"
879		for i := 0; i < len(dns.Extra); i++ {
880			if dns.Extra[i] != nil {
881				s += dns.Extra[i].String() + "\n"
882			}
883		}
884	}
885	return s
886}
887
888// Len returns the message length when in (un)compressed wire format.
889// If dns.Compress is true compression it is taken into account. Len()
890// is provided to be a faster way to get the size of the resulting packet,
891// than packing it, measuring the size and discarding the buffer.
892func (dns *Msg) Len() int {
893	// We always return one more than needed.
894	l := 12 // Message header is always 12 bytes
895	var compression map[string]int
896	if dns.Compress {
897		compression = make(map[string]int)
898	}
899	for i := 0; i < len(dns.Question); i++ {
900		l += dns.Question[i].len()
901		if dns.Compress {
902			compressionLenHelper(compression, dns.Question[i].Name)
903		}
904	}
905	for i := 0; i < len(dns.Answer); i++ {
906		if dns.Answer[i] == nil {
907			continue
908		}
909		l += dns.Answer[i].len()
910		if dns.Compress {
911			k, ok := compressionLenSearch(compression, dns.Answer[i].Header().Name)
912			if ok {
913				l += 1 - k
914			}
915			compressionLenHelper(compression, dns.Answer[i].Header().Name)
916			k, ok = compressionLenSearchType(compression, dns.Answer[i])
917			if ok {
918				l += 1 - k
919			}
920			compressionLenHelperType(compression, dns.Answer[i])
921		}
922	}
923	for i := 0; i < len(dns.Ns); i++ {
924		if dns.Ns[i] == nil {
925			continue
926		}
927		l += dns.Ns[i].len()
928		if dns.Compress {
929			k, ok := compressionLenSearch(compression, dns.Ns[i].Header().Name)
930			if ok {
931				l += 1 - k
932			}
933			compressionLenHelper(compression, dns.Ns[i].Header().Name)
934			k, ok = compressionLenSearchType(compression, dns.Ns[i])
935			if ok {
936				l += 1 - k
937			}
938			compressionLenHelperType(compression, dns.Ns[i])
939		}
940	}
941	for i := 0; i < len(dns.Extra); i++ {
942		if dns.Extra[i] == nil {
943			continue
944		}
945		l += dns.Extra[i].len()
946		if dns.Compress {
947			k, ok := compressionLenSearch(compression, dns.Extra[i].Header().Name)
948			if ok {
949				l += 1 - k
950			}
951			compressionLenHelper(compression, dns.Extra[i].Header().Name)
952			k, ok = compressionLenSearchType(compression, dns.Extra[i])
953			if ok {
954				l += 1 - k
955			}
956			compressionLenHelperType(compression, dns.Extra[i])
957		}
958	}
959	return l
960}
961
962// Put the parts of the name in the compression map.
963func compressionLenHelper(c map[string]int, s string) {
964	pref := ""
965	lbs := Split(s)
966	for j := len(lbs) - 1; j >= 0; j-- {
967		pref = s[lbs[j]:]
968		if _, ok := c[pref]; !ok {
969			c[pref] = len(pref)
970		}
971	}
972}
973
974// Look for each part in the compression map and returns its length,
975// keep on searching so we get the longest match.
976func compressionLenSearch(c map[string]int, s string) (int, bool) {
977	off := 0
978	end := false
979	if s == "" { // don't bork on bogus data
980		return 0, false
981	}
982	for {
983		if _, ok := c[s[off:]]; ok {
984			return len(s[off:]), true
985		}
986		if end {
987			break
988		}
989		off, end = NextLabel(s, off)
990	}
991	return 0, false
992}
993
994// TODO(miek): should add all types, because the all can be *used* for compression. Autogenerate from msg_generate and put in zmsg.go
995func compressionLenHelperType(c map[string]int, r RR) {
996	switch x := r.(type) {
997	case *NS:
998		compressionLenHelper(c, x.Ns)
999	case *MX:
1000		compressionLenHelper(c, x.Mx)
1001	case *CNAME:
1002		compressionLenHelper(c, x.Target)
1003	case *PTR:
1004		compressionLenHelper(c, x.Ptr)
1005	case *SOA:
1006		compressionLenHelper(c, x.Ns)
1007		compressionLenHelper(c, x.Mbox)
1008	case *MB:
1009		compressionLenHelper(c, x.Mb)
1010	case *MG:
1011		compressionLenHelper(c, x.Mg)
1012	case *MR:
1013		compressionLenHelper(c, x.Mr)
1014	case *MF:
1015		compressionLenHelper(c, x.Mf)
1016	case *MD:
1017		compressionLenHelper(c, x.Md)
1018	case *RT:
1019		compressionLenHelper(c, x.Host)
1020	case *RP:
1021		compressionLenHelper(c, x.Mbox)
1022		compressionLenHelper(c, x.Txt)
1023	case *MINFO:
1024		compressionLenHelper(c, x.Rmail)
1025		compressionLenHelper(c, x.Email)
1026	case *AFSDB:
1027		compressionLenHelper(c, x.Hostname)
1028	case *SRV:
1029		compressionLenHelper(c, x.Target)
1030	case *NAPTR:
1031		compressionLenHelper(c, x.Replacement)
1032	case *RRSIG:
1033		compressionLenHelper(c, x.SignerName)
1034	case *NSEC:
1035		compressionLenHelper(c, x.NextDomain)
1036		// HIP?
1037	}
1038}
1039
1040// Only search on compressing these types.
1041func compressionLenSearchType(c map[string]int, r RR) (int, bool) {
1042	switch x := r.(type) {
1043	case *NS:
1044		return compressionLenSearch(c, x.Ns)
1045	case *MX:
1046		return compressionLenSearch(c, x.Mx)
1047	case *CNAME:
1048		return compressionLenSearch(c, x.Target)
1049	case *DNAME:
1050		return compressionLenSearch(c, x.Target)
1051	case *PTR:
1052		return compressionLenSearch(c, x.Ptr)
1053	case *SOA:
1054		k, ok := compressionLenSearch(c, x.Ns)
1055		k1, ok1 := compressionLenSearch(c, x.Mbox)
1056		if !ok && !ok1 {
1057			return 0, false
1058		}
1059		return k + k1, true
1060	case *MB:
1061		return compressionLenSearch(c, x.Mb)
1062	case *MG:
1063		return compressionLenSearch(c, x.Mg)
1064	case *MR:
1065		return compressionLenSearch(c, x.Mr)
1066	case *MF:
1067		return compressionLenSearch(c, x.Mf)
1068	case *MD:
1069		return compressionLenSearch(c, x.Md)
1070	case *RT:
1071		return compressionLenSearch(c, x.Host)
1072	case *MINFO:
1073		k, ok := compressionLenSearch(c, x.Rmail)
1074		k1, ok1 := compressionLenSearch(c, x.Email)
1075		if !ok && !ok1 {
1076			return 0, false
1077		}
1078		return k + k1, true
1079	case *AFSDB:
1080		return compressionLenSearch(c, x.Hostname)
1081	}
1082	return 0, false
1083}
1084
1085// Copy returns a new RR which is a deep-copy of r.
1086func Copy(r RR) RR { r1 := r.copy(); return r1 }
1087
1088// Len returns the length (in octets) of the uncompressed RR in wire format.
1089func Len(r RR) int { return r.len() }
1090
1091// Copy returns a new *Msg which is a deep-copy of dns.
1092func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
1093
1094// CopyTo copies the contents to the provided message using a deep-copy and returns the copy.
1095func (dns *Msg) CopyTo(r1 *Msg) *Msg {
1096	r1.MsgHdr = dns.MsgHdr
1097	r1.Compress = dns.Compress
1098
1099	if len(dns.Question) > 0 {
1100		r1.Question = make([]Question, len(dns.Question))
1101		copy(r1.Question, dns.Question) // TODO(miek): Question is an immutable value, ok to do a shallow-copy
1102	}
1103
1104	rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
1105	var rri int
1106
1107	if len(dns.Answer) > 0 {
1108		rrbegin := rri
1109		for i := 0; i < len(dns.Answer); i++ {
1110			rrArr[rri] = dns.Answer[i].copy()
1111			rri++
1112		}
1113		r1.Answer = rrArr[rrbegin:rri:rri]
1114	}
1115
1116	if len(dns.Ns) > 0 {
1117		rrbegin := rri
1118		for i := 0; i < len(dns.Ns); i++ {
1119			rrArr[rri] = dns.Ns[i].copy()
1120			rri++
1121		}
1122		r1.Ns = rrArr[rrbegin:rri:rri]
1123	}
1124
1125	if len(dns.Extra) > 0 {
1126		rrbegin := rri
1127		for i := 0; i < len(dns.Extra); i++ {
1128			rrArr[rri] = dns.Extra[i].copy()
1129			rri++
1130		}
1131		r1.Extra = rrArr[rrbegin:rri:rri]
1132	}
1133
1134	return r1
1135}
1136
1137func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
1138	off, err := PackDomainName(q.Name, msg, off, compression, compress)
1139	if err != nil {
1140		return off, err
1141	}
1142	off, err = packUint16(q.Qtype, msg, off)
1143	if err != nil {
1144		return off, err
1145	}
1146	off, err = packUint16(q.Qclass, msg, off)
1147	if err != nil {
1148		return off, err
1149	}
1150	return off, nil
1151}
1152
1153func unpackQuestion(msg []byte, off int) (Question, int, error) {
1154	var (
1155		q   Question
1156		err error
1157	)
1158	q.Name, off, err = UnpackDomainName(msg, off)
1159	if err != nil {
1160		return q, off, err
1161	}
1162	if off == len(msg) {
1163		return q, off, nil
1164	}
1165	q.Qtype, off, err = unpackUint16(msg, off)
1166	if err != nil {
1167		return q, off, err
1168	}
1169	if off == len(msg) {
1170		return q, off, nil
1171	}
1172	q.Qclass, off, err = unpackUint16(msg, off)
1173	if off == len(msg) {
1174		return q, off, nil
1175	}
1176	return q, off, err
1177}
1178
1179func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
1180	off, err := packUint16(dh.Id, msg, off)
1181	if err != nil {
1182		return off, err
1183	}
1184	off, err = packUint16(dh.Bits, msg, off)
1185	if err != nil {
1186		return off, err
1187	}
1188	off, err = packUint16(dh.Qdcount, msg, off)
1189	if err != nil {
1190		return off, err
1191	}
1192	off, err = packUint16(dh.Ancount, msg, off)
1193	if err != nil {
1194		return off, err
1195	}
1196	off, err = packUint16(dh.Nscount, msg, off)
1197	if err != nil {
1198		return off, err
1199	}
1200	off, err = packUint16(dh.Arcount, msg, off)
1201	return off, err
1202}
1203
1204func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
1205	var (
1206		dh  Header
1207		err error
1208	)
1209	dh.Id, off, err = unpackUint16(msg, off)
1210	if err != nil {
1211		return dh, off, err
1212	}
1213	dh.Bits, off, err = unpackUint16(msg, off)
1214	if err != nil {
1215		return dh, off, err
1216	}
1217	dh.Qdcount, off, err = unpackUint16(msg, off)
1218	if err != nil {
1219		return dh, off, err
1220	}
1221	dh.Ancount, off, err = unpackUint16(msg, off)
1222	if err != nil {
1223		return dh, off, err
1224	}
1225	dh.Nscount, off, err = unpackUint16(msg, off)
1226	if err != nil {
1227		return dh, off, err
1228	}
1229	dh.Arcount, off, err = unpackUint16(msg, off)
1230	return dh, off, err
1231}
1232