1package dns
2
3import (
4	"encoding/base32"
5	"encoding/base64"
6	"encoding/binary"
7	"encoding/hex"
8	"net"
9	"strconv"
10)
11
// Helper functions called from the generated zmsg.go.

// These functions are named after the tag of the field they help pack/unpack;
// if there is no tag, they are named after the type they pack/unpack (string,
// int, etc). Tag-based helpers are prefixed with packData or unpackData, e.g.
// packDataA or packDataDomainNames.
17
18func unpackDataA(msg []byte, off int) (net.IP, int, error) {
19	if off+net.IPv4len > len(msg) {
20		return nil, len(msg), &Error{err: "overflow unpacking a"}
21	}
22	a := append(make(net.IP, 0, net.IPv4len), msg[off:off+net.IPv4len]...)
23	off += net.IPv4len
24	return a, off, nil
25}
26
27func packDataA(a net.IP, msg []byte, off int) (int, error) {
28	// It must be a slice of 4, even if it is 16, we encode only the first 4
29	if off+net.IPv4len > len(msg) {
30		return len(msg), &Error{err: "overflow packing a"}
31	}
32	switch len(a) {
33	case net.IPv4len, net.IPv6len:
34		copy(msg[off:], a.To4())
35		off += net.IPv4len
36	case 0:
37		// Allowed, for dynamic updates.
38	default:
39		return len(msg), &Error{err: "overflow packing a"}
40	}
41	return off, nil
42}
43
44func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) {
45	if off+net.IPv6len > len(msg) {
46		return nil, len(msg), &Error{err: "overflow unpacking aaaa"}
47	}
48	aaaa := append(make(net.IP, 0, net.IPv6len), msg[off:off+net.IPv6len]...)
49	off += net.IPv6len
50	return aaaa, off, nil
51}
52
53func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) {
54	if off+net.IPv6len > len(msg) {
55		return len(msg), &Error{err: "overflow packing aaaa"}
56	}
57
58	switch len(aaaa) {
59	case net.IPv6len:
60		copy(msg[off:], aaaa)
61		off += net.IPv6len
62	case 0:
63		// Allowed, dynamic updates.
64	default:
65		return len(msg), &Error{err: "overflow packing aaaa"}
66	}
67	return off, nil
68}
69
70// unpackHeader unpacks an RR header, returning the offset to the end of the header and a
71// re-sliced msg according to the expected length of the RR.
72func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, err error) {
73	hdr := RR_Header{}
74	if off == len(msg) {
75		return hdr, off, msg, nil
76	}
77
78	hdr.Name, off, err = UnpackDomainName(msg, off)
79	if err != nil {
80		return hdr, len(msg), msg, err
81	}
82	hdr.Rrtype, off, err = unpackUint16(msg, off)
83	if err != nil {
84		return hdr, len(msg), msg, err
85	}
86	hdr.Class, off, err = unpackUint16(msg, off)
87	if err != nil {
88		return hdr, len(msg), msg, err
89	}
90	hdr.Ttl, off, err = unpackUint32(msg, off)
91	if err != nil {
92		return hdr, len(msg), msg, err
93	}
94	hdr.Rdlength, off, err = unpackUint16(msg, off)
95	if err != nil {
96		return hdr, len(msg), msg, err
97	}
98	msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength)
99	return hdr, off, msg, nil
100}
101
102// pack packs an RR header, returning the offset to the end of the header.
103// See PackDomainName for documentation about the compression.
104func (hdr RR_Header) pack(msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
105	if off == len(msg) {
106		return off, nil
107	}
108
109	off, err = PackDomainName(hdr.Name, msg, off, compression, compress)
110	if err != nil {
111		return len(msg), err
112	}
113	off, err = packUint16(hdr.Rrtype, msg, off)
114	if err != nil {
115		return len(msg), err
116	}
117	off, err = packUint16(hdr.Class, msg, off)
118	if err != nil {
119		return len(msg), err
120	}
121	off, err = packUint32(hdr.Ttl, msg, off)
122	if err != nil {
123		return len(msg), err
124	}
125	off, err = packUint16(hdr.Rdlength, msg, off)
126	if err != nil {
127		return len(msg), err
128	}
129	return off, nil
130}
131
// Lower-level helpers used by the pack/unpack helpers.
133
134// truncateMsgFromRdLength truncates msg to match the expected length of the RR.
135// Returns an error if msg is smaller than the expected size.
136func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []byte, err error) {
137	lenrd := off + int(rdlength)
138	if lenrd > len(msg) {
139		return msg, &Error{err: "overflowing header size"}
140	}
141	return msg[:lenrd], nil
142}
143
// fromBase32 decodes s as base32 with the "extended hex" alphabet
// (base32.HexEncoding). Lower-case letters are accepted and treated as their
// upper-case equivalents. The caller's slice s is not modified.
//
// On error, buf holds the bytes Decode reported as written before it failed.
func fromBase32(s []byte) (buf []byte, err error) {
	// Fold to upper case in a private copy, because HexEncoding's alphabet is
	// upper case. Working on a copy avoids mutating the caller's buffer.
	upper := make([]byte, len(s))
	for i, b := range s {
		if b >= 'a' && b <= 'z' {
			b -= 'a' - 'A'
		}
		upper[i] = b
	}
	buf = make([]byte, base32.HexEncoding.DecodedLen(len(upper)))
	n, err := base32.HexEncoding.Decode(buf, upper)
	return buf[:n], err
}
156
// toBase32 encodes b as base32 with the "extended hex" alphabet
// (base32.HexEncoding), producing upper-case output with '=' padding.
func toBase32(b []byte) string {
	return base32.HexEncoding.EncodeToString(b)
}
158
// fromBase64 decodes s as standard, padded base64 (base64.StdEncoding).
// On error, buf holds the bytes Decode reported as written before it failed.
func fromBase64(s []byte) (buf []byte, err error) {
	decoded := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
	n, err := base64.StdEncoding.Decode(decoded, s)
	return decoded[:n], err
}
166
// toBase64 encodes b as standard, padded base64 (base64.StdEncoding).
func toBase64(b []byte) string {
	return base64.StdEncoding.EncodeToString(b)
}
168
169// dynamicUpdate returns true if the Rdlength is zero.
170func noRdata(h RR_Header) bool { return h.Rdlength == 0 }
171
172func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) {
173	if off+1 > len(msg) {
174		return 0, len(msg), &Error{err: "overflow unpacking uint8"}
175	}
176	return uint8(msg[off]), off + 1, nil
177}
178
179func packUint8(i uint8, msg []byte, off int) (off1 int, err error) {
180	if off+1 > len(msg) {
181		return len(msg), &Error{err: "overflow packing uint8"}
182	}
183	msg[off] = byte(i)
184	return off + 1, nil
185}
186
187func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) {
188	if off+2 > len(msg) {
189		return 0, len(msg), &Error{err: "overflow unpacking uint16"}
190	}
191	return binary.BigEndian.Uint16(msg[off:]), off + 2, nil
192}
193
194func packUint16(i uint16, msg []byte, off int) (off1 int, err error) {
195	if off+2 > len(msg) {
196		return len(msg), &Error{err: "overflow packing uint16"}
197	}
198	binary.BigEndian.PutUint16(msg[off:], i)
199	return off + 2, nil
200}
201
202func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) {
203	if off+4 > len(msg) {
204		return 0, len(msg), &Error{err: "overflow unpacking uint32"}
205	}
206	return binary.BigEndian.Uint32(msg[off:]), off + 4, nil
207}
208
209func packUint32(i uint32, msg []byte, off int) (off1 int, err error) {
210	if off+4 > len(msg) {
211		return len(msg), &Error{err: "overflow packing uint32"}
212	}
213	binary.BigEndian.PutUint32(msg[off:], i)
214	return off + 4, nil
215}
216
217func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) {
218	if off+6 > len(msg) {
219		return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"}
220	}
221	// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
222	i = (uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
223		uint64(msg[off+4])<<8 | uint64(msg[off+5])))
224	off += 6
225	return i, off, nil
226}
227
228func packUint48(i uint64, msg []byte, off int) (off1 int, err error) {
229	if off+6 > len(msg) {
230		return len(msg), &Error{err: "overflow packing uint64 as uint48"}
231	}
232	msg[off] = byte(i >> 40)
233	msg[off+1] = byte(i >> 32)
234	msg[off+2] = byte(i >> 24)
235	msg[off+3] = byte(i >> 16)
236	msg[off+4] = byte(i >> 8)
237	msg[off+5] = byte(i)
238	off += 6
239	return off, nil
240}
241
242func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) {
243	if off+8 > len(msg) {
244		return 0, len(msg), &Error{err: "overflow unpacking uint64"}
245	}
246	return binary.BigEndian.Uint64(msg[off:]), off + 8, nil
247}
248
249func packUint64(i uint64, msg []byte, off int) (off1 int, err error) {
250	if off+8 > len(msg) {
251		return len(msg), &Error{err: "overflow packing uint64"}
252	}
253	binary.BigEndian.PutUint64(msg[off:], i)
254	off += 8
255	return off, nil
256}
257
258func unpackString(msg []byte, off int) (string, int, error) {
259	if off+1 > len(msg) {
260		return "", off, &Error{err: "overflow unpacking txt"}
261	}
262	l := int(msg[off])
263	if off+l+1 > len(msg) {
264		return "", off, &Error{err: "overflow unpacking txt"}
265	}
266	s := make([]byte, 0, l)
267	for _, b := range msg[off+1 : off+1+l] {
268		switch b {
269		case '"', '\\':
270			s = append(s, '\\', b)
271		default:
272			if b < 32 || b > 127 { // unprintable
273				var buf [3]byte
274				bufs := strconv.AppendInt(buf[:0], int64(b), 10)
275				s = append(s, '\\')
276				for i := 0; i < 3-len(bufs); i++ {
277					s = append(s, '0')
278				}
279				for _, r := range bufs {
280					s = append(s, r)
281				}
282			} else {
283				s = append(s, b)
284			}
285		}
286	}
287	off += 1 + l
288	return string(s), off, nil
289}
290
291func packString(s string, msg []byte, off int) (int, error) {
292	txtTmp := make([]byte, 256*4+1)
293	off, err := packTxtString(s, msg, off, txtTmp)
294	if err != nil {
295		return len(msg), err
296	}
297	return off, nil
298}
299
300func unpackStringBase32(msg []byte, off, end int) (string, int, error) {
301	if end > len(msg) {
302		return "", len(msg), &Error{err: "overflow unpacking base32"}
303	}
304	s := toBase32(msg[off:end])
305	return s, end, nil
306}
307
308func packStringBase32(s string, msg []byte, off int) (int, error) {
309	b32, err := fromBase32([]byte(s))
310	if err != nil {
311		return len(msg), err
312	}
313	if off+len(b32) > len(msg) {
314		return len(msg), &Error{err: "overflow packing base32"}
315	}
316	copy(msg[off:off+len(b32)], b32)
317	off += len(b32)
318	return off, nil
319}
320
321func unpackStringBase64(msg []byte, off, end int) (string, int, error) {
322	// Rest of the RR is base64 encoded value, so we don't need an explicit length
323	// to be set. Thus far all RR's that have base64 encoded fields have those as their
324	// last one. What we do need is the end of the RR!
325	if end > len(msg) {
326		return "", len(msg), &Error{err: "overflow unpacking base64"}
327	}
328	s := toBase64(msg[off:end])
329	return s, end, nil
330}
331
332func packStringBase64(s string, msg []byte, off int) (int, error) {
333	b64, err := fromBase64([]byte(s))
334	if err != nil {
335		return len(msg), err
336	}
337	if off+len(b64) > len(msg) {
338		return len(msg), &Error{err: "overflow packing base64"}
339	}
340	copy(msg[off:off+len(b64)], b64)
341	off += len(b64)
342	return off, nil
343}
344
345func unpackStringHex(msg []byte, off, end int) (string, int, error) {
346	// Rest of the RR is hex encoded value, so we don't need an explicit length
347	// to be set. NSEC and TSIG have hex fields with a length field.
348	// What we do need is the end of the RR!
349	if end > len(msg) {
350		return "", len(msg), &Error{err: "overflow unpacking hex"}
351	}
352
353	s := hex.EncodeToString(msg[off:end])
354	return s, end, nil
355}
356
357func packStringHex(s string, msg []byte, off int) (int, error) {
358	h, err := hex.DecodeString(s)
359	if err != nil {
360		return len(msg), err
361	}
362	if off+(len(h)) > len(msg) {
363		return len(msg), &Error{err: "overflow packing hex"}
364	}
365	copy(msg[off:off+len(h)], h)
366	off += len(h)
367	return off, nil
368}
369
370func unpackStringTxt(msg []byte, off int) ([]string, int, error) {
371	txt, off, err := unpackTxt(msg, off)
372	if err != nil {
373		return nil, len(msg), err
374	}
375	return txt, off, nil
376}
377
378func packStringTxt(s []string, msg []byte, off int) (int, error) {
379	txtTmp := make([]byte, 256*4+1) // If the whole string consists out of \DDD we need this many.
380	off, err := packTxt(s, msg, off, txtTmp)
381	if err != nil {
382		return len(msg), err
383	}
384	return off, nil
385}
386
387func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) {
388	var edns []EDNS0
389Option:
390	code := uint16(0)
391	if off+4 > len(msg) {
392		return nil, len(msg), &Error{err: "overflow unpacking opt"}
393	}
394	code = binary.BigEndian.Uint16(msg[off:])
395	off += 2
396	optlen := binary.BigEndian.Uint16(msg[off:])
397	off += 2
398	if off+int(optlen) > len(msg) {
399		return nil, len(msg), &Error{err: "overflow unpacking opt"}
400	}
401	switch code {
402	case EDNS0NSID:
403		e := new(EDNS0_NSID)
404		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
405			return nil, len(msg), err
406		}
407		edns = append(edns, e)
408		off += int(optlen)
409	case EDNS0SUBNET, EDNS0SUBNETDRAFT:
410		e := new(EDNS0_SUBNET)
411		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
412			return nil, len(msg), err
413		}
414		edns = append(edns, e)
415		off += int(optlen)
416		if code == EDNS0SUBNETDRAFT {
417			e.DraftOption = true
418		}
419	case EDNS0COOKIE:
420		e := new(EDNS0_COOKIE)
421		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
422			return nil, len(msg), err
423		}
424		edns = append(edns, e)
425		off += int(optlen)
426	case EDNS0UL:
427		e := new(EDNS0_UL)
428		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
429			return nil, len(msg), err
430		}
431		edns = append(edns, e)
432		off += int(optlen)
433	case EDNS0LLQ:
434		e := new(EDNS0_LLQ)
435		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
436			return nil, len(msg), err
437		}
438		edns = append(edns, e)
439		off += int(optlen)
440	case EDNS0DAU:
441		e := new(EDNS0_DAU)
442		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
443			return nil, len(msg), err
444		}
445		edns = append(edns, e)
446		off += int(optlen)
447	case EDNS0DHU:
448		e := new(EDNS0_DHU)
449		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
450			return nil, len(msg), err
451		}
452		edns = append(edns, e)
453		off += int(optlen)
454	case EDNS0N3U:
455		e := new(EDNS0_N3U)
456		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
457			return nil, len(msg), err
458		}
459		edns = append(edns, e)
460		off += int(optlen)
461	default:
462		e := new(EDNS0_LOCAL)
463		e.Code = code
464		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
465			return nil, len(msg), err
466		}
467		edns = append(edns, e)
468		off += int(optlen)
469	}
470
471	if off < len(msg) {
472		goto Option
473	}
474
475	return edns, off, nil
476}
477
478func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
479	for _, el := range options {
480		b, err := el.pack()
481		if err != nil || off+3 > len(msg) {
482			return len(msg), &Error{err: "overflow packing opt"}
483		}
484		binary.BigEndian.PutUint16(msg[off:], el.Option())      // Option code
485		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
486		off += 4
487		if off+len(b) > len(msg) {
488			copy(msg[off:], b)
489			off = len(msg)
490			continue
491		}
492		// Actual data
493		copy(msg[off:off+len(b)], b)
494		off += len(b)
495	}
496	return off, nil
497}
498
// unpackStringOctet returns the remainder of msg from off onwards as a raw
// string, together with len(msg) as the new offset. It never returns an
// error.
func unpackStringOctet(msg []byte, off int) (string, int, error) {
	rest := msg[off:]
	return string(rest), len(msg), nil
}
503
504func packStringOctet(s string, msg []byte, off int) (int, error) {
505	txtTmp := make([]byte, 256*4+1)
506	off, err := packOctetString(s, msg, off, txtTmp)
507	if err != nil {
508		return len(msg), err
509	}
510	return off, nil
511}
512
513func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
514	var nsec []uint16
515	length, window, lastwindow := 0, 0, -1
516	for off < len(msg) {
517		if off+2 > len(msg) {
518			return nsec, len(msg), &Error{err: "overflow unpacking nsecx"}
519		}
520		window = int(msg[off])
521		length = int(msg[off+1])
522		off += 2
523		if window <= lastwindow {
524			// RFC 4034: Blocks are present in the NSEC RR RDATA in
525			// increasing numerical order.
526			return nsec, len(msg), &Error{err: "out of order NSEC block"}
527		}
528		if length == 0 {
529			// RFC 4034: Blocks with no types present MUST NOT be included.
530			return nsec, len(msg), &Error{err: "empty NSEC block"}
531		}
532		if length > 32 {
533			return nsec, len(msg), &Error{err: "NSEC block too long"}
534		}
535		if off+length > len(msg) {
536			return nsec, len(msg), &Error{err: "overflowing NSEC block"}
537		}
538
539		// Walk the bytes in the window and extract the type bits
540		for j := 0; j < length; j++ {
541			b := msg[off+j]
542			// Check the bits one by one, and set the type
543			if b&0x80 == 0x80 {
544				nsec = append(nsec, uint16(window*256+j*8+0))
545			}
546			if b&0x40 == 0x40 {
547				nsec = append(nsec, uint16(window*256+j*8+1))
548			}
549			if b&0x20 == 0x20 {
550				nsec = append(nsec, uint16(window*256+j*8+2))
551			}
552			if b&0x10 == 0x10 {
553				nsec = append(nsec, uint16(window*256+j*8+3))
554			}
555			if b&0x8 == 0x8 {
556				nsec = append(nsec, uint16(window*256+j*8+4))
557			}
558			if b&0x4 == 0x4 {
559				nsec = append(nsec, uint16(window*256+j*8+5))
560			}
561			if b&0x2 == 0x2 {
562				nsec = append(nsec, uint16(window*256+j*8+6))
563			}
564			if b&0x1 == 0x1 {
565				nsec = append(nsec, uint16(window*256+j*8+7))
566			}
567		}
568		off += length
569		lastwindow = window
570	}
571	return nsec, off, nil
572}
573
574func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
575	if len(bitmap) == 0 {
576		return off, nil
577	}
578	var lastwindow, lastlength uint16
579	for j := 0; j < len(bitmap); j++ {
580		t := bitmap[j]
581		window := t / 256
582		length := (t-window*256)/8 + 1
583		if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
584			off += int(lastlength) + 2
585			lastlength = 0
586		}
587		if window < lastwindow || length < lastlength {
588			return len(msg), &Error{err: "nsec bits out of order"}
589		}
590		if off+2+int(length) > len(msg) {
591			return len(msg), &Error{err: "overflow packing nsec"}
592		}
593		// Setting the window #
594		msg[off] = byte(window)
595		// Setting the octets length
596		msg[off+1] = byte(length)
597		// Setting the bit value for the type in the right octet
598		msg[off+1+int(length)] |= byte(1 << (7 - (t % 8)))
599		lastwindow, lastlength = window, length
600	}
601	off += int(lastlength) + 2
602	return off, nil
603}
604
605func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
606	var (
607		servers []string
608		s       string
609		err     error
610	)
611	if end > len(msg) {
612		return nil, len(msg), &Error{err: "overflow unpacking domain names"}
613	}
614	for off < end {
615		s, off, err = UnpackDomainName(msg, off)
616		if err != nil {
617			return servers, len(msg), err
618		}
619		servers = append(servers, s)
620	}
621	return servers, off, nil
622}
623
624func packDataDomainNames(names []string, msg []byte, off int, compression map[string]int, compress bool) (int, error) {
625	var err error
626	for j := 0; j < len(names); j++ {
627		off, err = PackDomainName(names[j], msg, off, compression, false && compress)
628		if err != nil {
629			return len(msg), err
630		}
631	}
632	return off, nil
633}
634