1package dns
2
3import (
4	"encoding/base32"
5	"encoding/base64"
6	"encoding/binary"
7	"encoding/hex"
8	"net"
9	"strconv"
10)
11
12// helper functions called from the generated zmsg.go
13
// These functions are named after the struct tag that selects them or, if there is no tag, after the
// type they pack/unpack (string, int, etc.). All of them are prefixed with pack or unpack, for example
// packDataA or packDataDomainNames.
17
18func unpackDataA(msg []byte, off int) (net.IP, int, error) {
19	if off+net.IPv4len > len(msg) {
20		return nil, len(msg), &Error{err: "overflow unpacking a"}
21	}
22	a := append(make(net.IP, 0, net.IPv4len), msg[off:off+net.IPv4len]...)
23	off += net.IPv4len
24	return a, off, nil
25}
26
27func packDataA(a net.IP, msg []byte, off int) (int, error) {
28	// It must be a slice of 4, even if it is 16, we encode only the first 4
29	if off+net.IPv4len > len(msg) {
30		return len(msg), &Error{err: "overflow packing a"}
31	}
32	switch len(a) {
33	case net.IPv4len, net.IPv6len:
34		copy(msg[off:], a.To4())
35		off += net.IPv4len
36	case 0:
37		// Allowed, for dynamic updates.
38	default:
39		return len(msg), &Error{err: "overflow packing a"}
40	}
41	return off, nil
42}
43
44func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) {
45	if off+net.IPv6len > len(msg) {
46		return nil, len(msg), &Error{err: "overflow unpacking aaaa"}
47	}
48	aaaa := append(make(net.IP, 0, net.IPv6len), msg[off:off+net.IPv6len]...)
49	off += net.IPv6len
50	return aaaa, off, nil
51}
52
53func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) {
54	if off+net.IPv6len > len(msg) {
55		return len(msg), &Error{err: "overflow packing aaaa"}
56	}
57
58	switch len(aaaa) {
59	case net.IPv6len:
60		copy(msg[off:], aaaa)
61		off += net.IPv6len
62	case 0:
63		// Allowed, dynamic updates.
64	default:
65		return len(msg), &Error{err: "overflow packing aaaa"}
66	}
67	return off, nil
68}
69
70// unpackHeader unpacks an RR header, returning the offset to the end of the header and a
71// re-sliced msg according to the expected length of the RR.
72func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, err error) {
73	hdr := RR_Header{}
74	if off == len(msg) {
75		return hdr, off, msg, nil
76	}
77
78	hdr.Name, off, err = UnpackDomainName(msg, off)
79	if err != nil {
80		return hdr, len(msg), msg, err
81	}
82	hdr.Rrtype, off, err = unpackUint16(msg, off)
83	if err != nil {
84		return hdr, len(msg), msg, err
85	}
86	hdr.Class, off, err = unpackUint16(msg, off)
87	if err != nil {
88		return hdr, len(msg), msg, err
89	}
90	hdr.Ttl, off, err = unpackUint32(msg, off)
91	if err != nil {
92		return hdr, len(msg), msg, err
93	}
94	hdr.Rdlength, off, err = unpackUint16(msg, off)
95	if err != nil {
96		return hdr, len(msg), msg, err
97	}
98	msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength)
99	return hdr, off, msg, nil
100}
101
102// pack packs an RR header, returning the offset to the end of the header.
103// See PackDomainName for documentation about the compression.
104func (hdr RR_Header) pack(msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
105	if off == len(msg) {
106		return off, nil
107	}
108
109	off, err = PackDomainName(hdr.Name, msg, off, compression, compress)
110	if err != nil {
111		return len(msg), err
112	}
113	off, err = packUint16(hdr.Rrtype, msg, off)
114	if err != nil {
115		return len(msg), err
116	}
117	off, err = packUint16(hdr.Class, msg, off)
118	if err != nil {
119		return len(msg), err
120	}
121	off, err = packUint32(hdr.Ttl, msg, off)
122	if err != nil {
123		return len(msg), err
124	}
125	off, err = packUint16(hdr.Rdlength, msg, off)
126	if err != nil {
127		return len(msg), err
128	}
129	return off, nil
130}
131
// Lower-level utility functions used by the pack/unpack helpers in this file.
133
134// truncateMsgFromRdLength truncates msg to match the expected length of the RR.
135// Returns an error if msg is smaller than the expected size.
136func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []byte, err error) {
137	lenrd := off + int(rdlength)
138	if lenrd > len(msg) {
139		return msg, &Error{err: "overflowing header size"}
140	}
141	return msg[:lenrd], nil
142}
143
// fromBase32 decodes s using base32.HexEncoding (the RFC 4648 "extended hex"
// alphabet). If decoding fails, buf holds the bytes decoded before the
// failure and err describes it.
func fromBase32(s []byte) (buf []byte, err error) {
	enc := base32.HexEncoding
	out := make([]byte, enc.DecodedLen(len(s)))
	n, err := enc.Decode(out, s)
	return out[:n], err
}
151
// toBase32 encodes b with base32.HexEncoding (the RFC 4648 "extended hex"
// alphabet, with padding).
func toBase32(b []byte) string {
	return base32.HexEncoding.EncodeToString(b)
}
153
// fromBase64 decodes s using standard, padded base64 (base64.StdEncoding).
// If decoding fails, buf holds the bytes decoded before the failure and err
// describes it.
func fromBase64(s []byte) (buf []byte, err error) {
	enc := base64.StdEncoding
	out := make([]byte, enc.DecodedLen(len(s)))
	n, err := enc.Decode(out, s)
	return out[:n], err
}
161
// toBase64 encodes b as standard, padded base64 (base64.StdEncoding).
func toBase64(b []byte) string {
	return base64.StdEncoding.EncodeToString(b)
}
163
164// dynamicUpdate returns true if the Rdlength is zero.
165func noRdata(h RR_Header) bool { return h.Rdlength == 0 }
166
167func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) {
168	if off+1 > len(msg) {
169		return 0, len(msg), &Error{err: "overflow unpacking uint8"}
170	}
171	return uint8(msg[off]), off + 1, nil
172}
173
174func packUint8(i uint8, msg []byte, off int) (off1 int, err error) {
175	if off+1 > len(msg) {
176		return len(msg), &Error{err: "overflow packing uint8"}
177	}
178	msg[off] = byte(i)
179	return off + 1, nil
180}
181
182func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) {
183	if off+2 > len(msg) {
184		return 0, len(msg), &Error{err: "overflow unpacking uint16"}
185	}
186	return binary.BigEndian.Uint16(msg[off:]), off + 2, nil
187}
188
189func packUint16(i uint16, msg []byte, off int) (off1 int, err error) {
190	if off+2 > len(msg) {
191		return len(msg), &Error{err: "overflow packing uint16"}
192	}
193	binary.BigEndian.PutUint16(msg[off:], i)
194	return off + 2, nil
195}
196
197func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) {
198	if off+4 > len(msg) {
199		return 0, len(msg), &Error{err: "overflow unpacking uint32"}
200	}
201	return binary.BigEndian.Uint32(msg[off:]), off + 4, nil
202}
203
204func packUint32(i uint32, msg []byte, off int) (off1 int, err error) {
205	if off+4 > len(msg) {
206		return len(msg), &Error{err: "overflow packing uint32"}
207	}
208	binary.BigEndian.PutUint32(msg[off:], i)
209	return off + 4, nil
210}
211
212func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) {
213	if off+6 > len(msg) {
214		return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"}
215	}
216	// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
217	i = (uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
218		uint64(msg[off+4])<<8 | uint64(msg[off+5])))
219	off += 6
220	return i, off, nil
221}
222
223func packUint48(i uint64, msg []byte, off int) (off1 int, err error) {
224	if off+6 > len(msg) {
225		return len(msg), &Error{err: "overflow packing uint64 as uint48"}
226	}
227	msg[off] = byte(i >> 40)
228	msg[off+1] = byte(i >> 32)
229	msg[off+2] = byte(i >> 24)
230	msg[off+3] = byte(i >> 16)
231	msg[off+4] = byte(i >> 8)
232	msg[off+5] = byte(i)
233	off += 6
234	return off, nil
235}
236
237func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) {
238	if off+8 > len(msg) {
239		return 0, len(msg), &Error{err: "overflow unpacking uint64"}
240	}
241	return binary.BigEndian.Uint64(msg[off:]), off + 8, nil
242}
243
244func packUint64(i uint64, msg []byte, off int) (off1 int, err error) {
245	if off+8 > len(msg) {
246		return len(msg), &Error{err: "overflow packing uint64"}
247	}
248	binary.BigEndian.PutUint64(msg[off:], i)
249	off += 8
250	return off, nil
251}
252
253func unpackString(msg []byte, off int) (string, int, error) {
254	if off+1 > len(msg) {
255		return "", off, &Error{err: "overflow unpacking txt"}
256	}
257	l := int(msg[off])
258	if off+l+1 > len(msg) {
259		return "", off, &Error{err: "overflow unpacking txt"}
260	}
261	s := make([]byte, 0, l)
262	for _, b := range msg[off+1 : off+1+l] {
263		switch b {
264		case '"', '\\':
265			s = append(s, '\\', b)
266		case '\t', '\r', '\n':
267			s = append(s, b)
268		default:
269			if b < 32 || b > 127 { // unprintable
270				var buf [3]byte
271				bufs := strconv.AppendInt(buf[:0], int64(b), 10)
272				s = append(s, '\\')
273				for i := 0; i < 3-len(bufs); i++ {
274					s = append(s, '0')
275				}
276				for _, r := range bufs {
277					s = append(s, r)
278				}
279			} else {
280				s = append(s, b)
281			}
282		}
283	}
284	off += 1 + l
285	return string(s), off, nil
286}
287
288func packString(s string, msg []byte, off int) (int, error) {
289	txtTmp := make([]byte, 256*4+1)
290	off, err := packTxtString(s, msg, off, txtTmp)
291	if err != nil {
292		return len(msg), err
293	}
294	return off, nil
295}
296
297func unpackStringBase32(msg []byte, off, end int) (string, int, error) {
298	if end > len(msg) {
299		return "", len(msg), &Error{err: "overflow unpacking base32"}
300	}
301	s := toBase32(msg[off:end])
302	return s, end, nil
303}
304
305func packStringBase32(s string, msg []byte, off int) (int, error) {
306	b32, err := fromBase32([]byte(s))
307	if err != nil {
308		return len(msg), err
309	}
310	if off+len(b32) > len(msg) {
311		return len(msg), &Error{err: "overflow packing base32"}
312	}
313	copy(msg[off:off+len(b32)], b32)
314	off += len(b32)
315	return off, nil
316}
317
318func unpackStringBase64(msg []byte, off, end int) (string, int, error) {
319	// Rest of the RR is base64 encoded value, so we don't need an explicit length
320	// to be set. Thus far all RR's that have base64 encoded fields have those as their
321	// last one. What we do need is the end of the RR!
322	if end > len(msg) {
323		return "", len(msg), &Error{err: "overflow unpacking base64"}
324	}
325	s := toBase64(msg[off:end])
326	return s, end, nil
327}
328
329func packStringBase64(s string, msg []byte, off int) (int, error) {
330	b64, err := fromBase64([]byte(s))
331	if err != nil {
332		return len(msg), err
333	}
334	if off+len(b64) > len(msg) {
335		return len(msg), &Error{err: "overflow packing base64"}
336	}
337	copy(msg[off:off+len(b64)], b64)
338	off += len(b64)
339	return off, nil
340}
341
342func unpackStringHex(msg []byte, off, end int) (string, int, error) {
343	// Rest of the RR is hex encoded value, so we don't need an explicit length
344	// to be set. NSEC and TSIG have hex fields with a length field.
345	// What we do need is the end of the RR!
346	if end > len(msg) {
347		return "", len(msg), &Error{err: "overflow unpacking hex"}
348	}
349
350	s := hex.EncodeToString(msg[off:end])
351	return s, end, nil
352}
353
354func packStringHex(s string, msg []byte, off int) (int, error) {
355	h, err := hex.DecodeString(s)
356	if err != nil {
357		return len(msg), err
358	}
359	if off+(len(h)) > len(msg) {
360		return len(msg), &Error{err: "overflow packing hex"}
361	}
362	copy(msg[off:off+len(h)], h)
363	off += len(h)
364	return off, nil
365}
366
367func unpackStringTxt(msg []byte, off int) ([]string, int, error) {
368	txt, off, err := unpackTxt(msg, off)
369	if err != nil {
370		return nil, len(msg), err
371	}
372	return txt, off, nil
373}
374
375func packStringTxt(s []string, msg []byte, off int) (int, error) {
376	txtTmp := make([]byte, 256*4+1) // If the whole string consists out of \DDD we need this many.
377	off, err := packTxt(s, msg, off, txtTmp)
378	if err != nil {
379		return len(msg), err
380	}
381	return off, nil
382}
383
384func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) {
385	var edns []EDNS0
386Option:
387	code := uint16(0)
388	if off+4 > len(msg) {
389		return nil, len(msg), &Error{err: "overflow unpacking opt"}
390	}
391	code = binary.BigEndian.Uint16(msg[off:])
392	off += 2
393	optlen := binary.BigEndian.Uint16(msg[off:])
394	off += 2
395	if off+int(optlen) > len(msg) {
396		return nil, len(msg), &Error{err: "overflow unpacking opt"}
397	}
398	switch code {
399	case EDNS0NSID:
400		e := new(EDNS0_NSID)
401		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
402			return nil, len(msg), err
403		}
404		edns = append(edns, e)
405		off += int(optlen)
406	case EDNS0SUBNET, EDNS0SUBNETDRAFT:
407		e := new(EDNS0_SUBNET)
408		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
409			return nil, len(msg), err
410		}
411		edns = append(edns, e)
412		off += int(optlen)
413		if code == EDNS0SUBNETDRAFT {
414			e.DraftOption = true
415		}
416	case EDNS0COOKIE:
417		e := new(EDNS0_COOKIE)
418		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
419			return nil, len(msg), err
420		}
421		edns = append(edns, e)
422		off += int(optlen)
423	case EDNS0UL:
424		e := new(EDNS0_UL)
425		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
426			return nil, len(msg), err
427		}
428		edns = append(edns, e)
429		off += int(optlen)
430	case EDNS0LLQ:
431		e := new(EDNS0_LLQ)
432		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
433			return nil, len(msg), err
434		}
435		edns = append(edns, e)
436		off += int(optlen)
437	case EDNS0DAU:
438		e := new(EDNS0_DAU)
439		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
440			return nil, len(msg), err
441		}
442		edns = append(edns, e)
443		off += int(optlen)
444	case EDNS0DHU:
445		e := new(EDNS0_DHU)
446		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
447			return nil, len(msg), err
448		}
449		edns = append(edns, e)
450		off += int(optlen)
451	case EDNS0N3U:
452		e := new(EDNS0_N3U)
453		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
454			return nil, len(msg), err
455		}
456		edns = append(edns, e)
457		off += int(optlen)
458	default:
459		e := new(EDNS0_LOCAL)
460		e.Code = code
461		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
462			return nil, len(msg), err
463		}
464		edns = append(edns, e)
465		off += int(optlen)
466	}
467
468	if off < len(msg) {
469		goto Option
470	}
471
472	return edns, off, nil
473}
474
475func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
476	for _, el := range options {
477		b, err := el.pack()
478		if err != nil || off+3 > len(msg) {
479			return len(msg), &Error{err: "overflow packing opt"}
480		}
481		binary.BigEndian.PutUint16(msg[off:], el.Option())      // Option code
482		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
483		off += 4
484		if off+len(b) > len(msg) {
485			copy(msg[off:], b)
486			off = len(msg)
487			continue
488		}
489		// Actual data
490		copy(msg[off:off+len(b)], b)
491		off += len(b)
492	}
493	return off, nil
494}
495
// unpackStringOctet returns everything in msg from off to the end as a
// string. The returned offset is always len(msg) and the error is always
// nil.
func unpackStringOctet(msg []byte, off int) (string, int, error) {
	rest := msg[off:]
	return string(rest), len(msg), nil
}
500
501func packStringOctet(s string, msg []byte, off int) (int, error) {
502	txtTmp := make([]byte, 256*4+1)
503	off, err := packOctetString(s, msg, off, txtTmp)
504	if err != nil {
505		return len(msg), err
506	}
507	return off, nil
508}
509
510func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
511	var nsec []uint16
512	length, window, lastwindow := 0, 0, -1
513	for off < len(msg) {
514		if off+2 > len(msg) {
515			return nsec, len(msg), &Error{err: "overflow unpacking nsecx"}
516		}
517		window = int(msg[off])
518		length = int(msg[off+1])
519		off += 2
520		if window <= lastwindow {
521			// RFC 4034: Blocks are present in the NSEC RR RDATA in
522			// increasing numerical order.
523			return nsec, len(msg), &Error{err: "out of order NSEC block"}
524		}
525		if length == 0 {
526			// RFC 4034: Blocks with no types present MUST NOT be included.
527			return nsec, len(msg), &Error{err: "empty NSEC block"}
528		}
529		if length > 32 {
530			return nsec, len(msg), &Error{err: "NSEC block too long"}
531		}
532		if off+length > len(msg) {
533			return nsec, len(msg), &Error{err: "overflowing NSEC block"}
534		}
535
536		// Walk the bytes in the window and extract the type bits
537		for j := 0; j < length; j++ {
538			b := msg[off+j]
539			// Check the bits one by one, and set the type
540			if b&0x80 == 0x80 {
541				nsec = append(nsec, uint16(window*256+j*8+0))
542			}
543			if b&0x40 == 0x40 {
544				nsec = append(nsec, uint16(window*256+j*8+1))
545			}
546			if b&0x20 == 0x20 {
547				nsec = append(nsec, uint16(window*256+j*8+2))
548			}
549			if b&0x10 == 0x10 {
550				nsec = append(nsec, uint16(window*256+j*8+3))
551			}
552			if b&0x8 == 0x8 {
553				nsec = append(nsec, uint16(window*256+j*8+4))
554			}
555			if b&0x4 == 0x4 {
556				nsec = append(nsec, uint16(window*256+j*8+5))
557			}
558			if b&0x2 == 0x2 {
559				nsec = append(nsec, uint16(window*256+j*8+6))
560			}
561			if b&0x1 == 0x1 {
562				nsec = append(nsec, uint16(window*256+j*8+7))
563			}
564		}
565		off += length
566		lastwindow = window
567	}
568	return nsec, off, nil
569}
570
571func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
572	if len(bitmap) == 0 {
573		return off, nil
574	}
575	var lastwindow, lastlength uint16
576	for j := 0; j < len(bitmap); j++ {
577		t := bitmap[j]
578		window := t / 256
579		length := (t-window*256)/8 + 1
580		if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
581			off += int(lastlength) + 2
582			lastlength = 0
583		}
584		if window < lastwindow || length < lastlength {
585			return len(msg), &Error{err: "nsec bits out of order"}
586		}
587		if off+2+int(length) > len(msg) {
588			return len(msg), &Error{err: "overflow packing nsec"}
589		}
590		// Setting the window #
591		msg[off] = byte(window)
592		// Setting the octets length
593		msg[off+1] = byte(length)
594		// Setting the bit value for the type in the right octet
595		msg[off+1+int(length)] |= byte(1 << (7 - (t % 8)))
596		lastwindow, lastlength = window, length
597	}
598	off += int(lastlength) + 2
599	return off, nil
600}
601
602func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
603	var (
604		servers []string
605		s       string
606		err     error
607	)
608	if end > len(msg) {
609		return nil, len(msg), &Error{err: "overflow unpacking domain names"}
610	}
611	for off < end {
612		s, off, err = UnpackDomainName(msg, off)
613		if err != nil {
614			return servers, len(msg), err
615		}
616		servers = append(servers, s)
617	}
618	return servers, off, nil
619}
620
621func packDataDomainNames(names []string, msg []byte, off int, compression map[string]int, compress bool) (int, error) {
622	var err error
623	for j := 0; j < len(names); j++ {
624		off, err = PackDomainName(names[j], msg, off, compression, false && compress)
625		if err != nil {
626			return len(msg), err
627		}
628	}
629	return off, nil
630}
631