1package dns
2
3import (
4	"encoding/base32"
5	"encoding/base64"
6	"encoding/binary"
7	"encoding/hex"
8	"net"
9	"strings"
10)
11
12// helper functions called from the generated zmsg.go
13
// These functions are named after the tag they help pack/unpack; if there is no tag, the name
// is that of the type they pack/unpack (string, int, etc). The tag or type name is prefixed with
// pack/unpack (packData/unpackData for tags), e.g. packDataA or packDataDomainName.
17
18func unpackDataA(msg []byte, off int) (net.IP, int, error) {
19	if off+net.IPv4len > len(msg) {
20		return nil, len(msg), &Error{err: "overflow unpacking a"}
21	}
22	a := append(make(net.IP, 0, net.IPv4len), msg[off:off+net.IPv4len]...)
23	off += net.IPv4len
24	return a, off, nil
25}
26
27func packDataA(a net.IP, msg []byte, off int) (int, error) {
28	switch len(a) {
29	case net.IPv4len, net.IPv6len:
30		// It must be a slice of 4, even if it is 16, we encode only the first 4
31		if off+net.IPv4len > len(msg) {
32			return len(msg), &Error{err: "overflow packing a"}
33		}
34
35		copy(msg[off:], a.To4())
36		off += net.IPv4len
37	case 0:
38		// Allowed, for dynamic updates.
39	default:
40		return len(msg), &Error{err: "overflow packing a"}
41	}
42	return off, nil
43}
44
45func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) {
46	if off+net.IPv6len > len(msg) {
47		return nil, len(msg), &Error{err: "overflow unpacking aaaa"}
48	}
49	aaaa := append(make(net.IP, 0, net.IPv6len), msg[off:off+net.IPv6len]...)
50	off += net.IPv6len
51	return aaaa, off, nil
52}
53
54func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) {
55	switch len(aaaa) {
56	case net.IPv6len:
57		if off+net.IPv6len > len(msg) {
58			return len(msg), &Error{err: "overflow packing aaaa"}
59		}
60
61		copy(msg[off:], aaaa)
62		off += net.IPv6len
63	case 0:
64		// Allowed, dynamic updates.
65	default:
66		return len(msg), &Error{err: "overflow packing aaaa"}
67	}
68	return off, nil
69}
70
71// unpackHeader unpacks an RR header, returning the offset to the end of the header and a
72// re-sliced msg according to the expected length of the RR.
73func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, err error) {
74	hdr := RR_Header{}
75	if off == len(msg) {
76		return hdr, off, msg, nil
77	}
78
79	hdr.Name, off, err = UnpackDomainName(msg, off)
80	if err != nil {
81		return hdr, len(msg), msg, err
82	}
83	hdr.Rrtype, off, err = unpackUint16(msg, off)
84	if err != nil {
85		return hdr, len(msg), msg, err
86	}
87	hdr.Class, off, err = unpackUint16(msg, off)
88	if err != nil {
89		return hdr, len(msg), msg, err
90	}
91	hdr.Ttl, off, err = unpackUint32(msg, off)
92	if err != nil {
93		return hdr, len(msg), msg, err
94	}
95	hdr.Rdlength, off, err = unpackUint16(msg, off)
96	if err != nil {
97		return hdr, len(msg), msg, err
98	}
99	msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength)
100	return hdr, off, msg, err
101}
102
103// packHeader packs an RR header, returning the offset to the end of the header.
104// See PackDomainName for documentation about the compression.
105func (hdr RR_Header) packHeader(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
106	if off == len(msg) {
107		return off, nil
108	}
109
110	off, err := packDomainName(hdr.Name, msg, off, compression, compress)
111	if err != nil {
112		return len(msg), err
113	}
114	off, err = packUint16(hdr.Rrtype, msg, off)
115	if err != nil {
116		return len(msg), err
117	}
118	off, err = packUint16(hdr.Class, msg, off)
119	if err != nil {
120		return len(msg), err
121	}
122	off, err = packUint32(hdr.Ttl, msg, off)
123	if err != nil {
124		return len(msg), err
125	}
126	off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR.
127	if err != nil {
128		return len(msg), err
129	}
130	return off, nil
131}
132
// Lower-level helpers used by the pack/unpack helpers in this file.
134
135// truncateMsgFromRdLength truncates msg to match the expected length of the RR.
136// Returns an error if msg is smaller than the expected size.
137func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []byte, err error) {
138	lenrd := off + int(rdlength)
139	if lenrd > len(msg) {
140		return msg, &Error{err: "overflowing header size"}
141	}
142	return msg[:lenrd], nil
143}
144
var (
	// base32HexNoPadEncoding is base32 using the RFC 4648 "Extended Hex"
	// alphabet (0-9, A-V) without '=' padding. toBase32 and fromBase32 use it.
	base32HexNoPadEncoding = base32.HexEncoding.WithPadding(base32.NoPadding)
)
146
147func fromBase32(s []byte) (buf []byte, err error) {
148	for i, b := range s {
149		if b >= 'a' && b <= 'z' {
150			s[i] = b - 32
151		}
152	}
153	buflen := base32HexNoPadEncoding.DecodedLen(len(s))
154	buf = make([]byte, buflen)
155	n, err := base32HexNoPadEncoding.Decode(buf, s)
156	buf = buf[:n]
157	return
158}
159
160func toBase32(b []byte) string {
161	return base32HexNoPadEncoding.EncodeToString(b)
162}
163
// fromBase64 decodes s as standard, padded base64 (RFC 4648 section 4). It
// returns the decoded bytes. On a decode error it returns the bytes decoded
// so far together with the error.
func fromBase64(s []byte) (buf []byte, err error) {
	enc := base64.StdEncoding
	out := make([]byte, enc.DecodedLen(len(s)))
	n, err := enc.Decode(out, s)
	return out[:n], err
}
171
// toBase64 encodes b as standard, padded base64 (RFC 4648 section 4).
func toBase64(b []byte) string {
	enc := base64.StdEncoding
	out := make([]byte, enc.EncodedLen(len(b)))
	enc.Encode(out, b)
	return string(out)
}
173
174// dynamicUpdate returns true if the Rdlength is zero.
175func noRdata(h RR_Header) bool { return h.Rdlength == 0 }
176
177func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) {
178	if off+1 > len(msg) {
179		return 0, len(msg), &Error{err: "overflow unpacking uint8"}
180	}
181	return msg[off], off + 1, nil
182}
183
184func packUint8(i uint8, msg []byte, off int) (off1 int, err error) {
185	if off+1 > len(msg) {
186		return len(msg), &Error{err: "overflow packing uint8"}
187	}
188	msg[off] = i
189	return off + 1, nil
190}
191
192func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) {
193	if off+2 > len(msg) {
194		return 0, len(msg), &Error{err: "overflow unpacking uint16"}
195	}
196	return binary.BigEndian.Uint16(msg[off:]), off + 2, nil
197}
198
199func packUint16(i uint16, msg []byte, off int) (off1 int, err error) {
200	if off+2 > len(msg) {
201		return len(msg), &Error{err: "overflow packing uint16"}
202	}
203	binary.BigEndian.PutUint16(msg[off:], i)
204	return off + 2, nil
205}
206
207func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) {
208	if off+4 > len(msg) {
209		return 0, len(msg), &Error{err: "overflow unpacking uint32"}
210	}
211	return binary.BigEndian.Uint32(msg[off:]), off + 4, nil
212}
213
214func packUint32(i uint32, msg []byte, off int) (off1 int, err error) {
215	if off+4 > len(msg) {
216		return len(msg), &Error{err: "overflow packing uint32"}
217	}
218	binary.BigEndian.PutUint32(msg[off:], i)
219	return off + 4, nil
220}
221
222func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) {
223	if off+6 > len(msg) {
224		return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"}
225	}
226	// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
227	i = uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
228		uint64(msg[off+4])<<8 | uint64(msg[off+5])
229	off += 6
230	return i, off, nil
231}
232
233func packUint48(i uint64, msg []byte, off int) (off1 int, err error) {
234	if off+6 > len(msg) {
235		return len(msg), &Error{err: "overflow packing uint64 as uint48"}
236	}
237	msg[off] = byte(i >> 40)
238	msg[off+1] = byte(i >> 32)
239	msg[off+2] = byte(i >> 24)
240	msg[off+3] = byte(i >> 16)
241	msg[off+4] = byte(i >> 8)
242	msg[off+5] = byte(i)
243	off += 6
244	return off, nil
245}
246
247func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) {
248	if off+8 > len(msg) {
249		return 0, len(msg), &Error{err: "overflow unpacking uint64"}
250	}
251	return binary.BigEndian.Uint64(msg[off:]), off + 8, nil
252}
253
254func packUint64(i uint64, msg []byte, off int) (off1 int, err error) {
255	if off+8 > len(msg) {
256		return len(msg), &Error{err: "overflow packing uint64"}
257	}
258	binary.BigEndian.PutUint64(msg[off:], i)
259	off += 8
260	return off, nil
261}
262
263func unpackString(msg []byte, off int) (string, int, error) {
264	if off+1 > len(msg) {
265		return "", off, &Error{err: "overflow unpacking txt"}
266	}
267	l := int(msg[off])
268	off++
269	if off+l > len(msg) {
270		return "", off, &Error{err: "overflow unpacking txt"}
271	}
272	var s strings.Builder
273	consumed := 0
274	for i, b := range msg[off : off+l] {
275		switch {
276		case b == '"' || b == '\\':
277			if consumed == 0 {
278				s.Grow(l * 2)
279			}
280			s.Write(msg[off+consumed : off+i])
281			s.WriteByte('\\')
282			s.WriteByte(b)
283			consumed = i + 1
284		case b < ' ' || b > '~': // unprintable
285			if consumed == 0 {
286				s.Grow(l * 2)
287			}
288			s.Write(msg[off+consumed : off+i])
289			s.WriteString(escapeByte(b))
290			consumed = i + 1
291		}
292	}
293	if consumed == 0 { // no escaping needed
294		return string(msg[off : off+l]), off + l, nil
295	}
296	s.Write(msg[off+consumed : off+l])
297	return s.String(), off + l, nil
298}
299
300func packString(s string, msg []byte, off int) (int, error) {
301	txtTmp := make([]byte, 256*4+1)
302	off, err := packTxtString(s, msg, off, txtTmp)
303	if err != nil {
304		return len(msg), err
305	}
306	return off, nil
307}
308
309func unpackStringBase32(msg []byte, off, end int) (string, int, error) {
310	if end > len(msg) {
311		return "", len(msg), &Error{err: "overflow unpacking base32"}
312	}
313	s := toBase32(msg[off:end])
314	return s, end, nil
315}
316
317func packStringBase32(s string, msg []byte, off int) (int, error) {
318	b32, err := fromBase32([]byte(s))
319	if err != nil {
320		return len(msg), err
321	}
322	if off+len(b32) > len(msg) {
323		return len(msg), &Error{err: "overflow packing base32"}
324	}
325	copy(msg[off:off+len(b32)], b32)
326	off += len(b32)
327	return off, nil
328}
329
330func unpackStringBase64(msg []byte, off, end int) (string, int, error) {
331	// Rest of the RR is base64 encoded value, so we don't need an explicit length
332	// to be set. Thus far all RR's that have base64 encoded fields have those as their
333	// last one. What we do need is the end of the RR!
334	if end > len(msg) {
335		return "", len(msg), &Error{err: "overflow unpacking base64"}
336	}
337	s := toBase64(msg[off:end])
338	return s, end, nil
339}
340
341func packStringBase64(s string, msg []byte, off int) (int, error) {
342	b64, err := fromBase64([]byte(s))
343	if err != nil {
344		return len(msg), err
345	}
346	if off+len(b64) > len(msg) {
347		return len(msg), &Error{err: "overflow packing base64"}
348	}
349	copy(msg[off:off+len(b64)], b64)
350	off += len(b64)
351	return off, nil
352}
353
354func unpackStringHex(msg []byte, off, end int) (string, int, error) {
355	// Rest of the RR is hex encoded value, so we don't need an explicit length
356	// to be set. NSEC and TSIG have hex fields with a length field.
357	// What we do need is the end of the RR!
358	if end > len(msg) {
359		return "", len(msg), &Error{err: "overflow unpacking hex"}
360	}
361
362	s := hex.EncodeToString(msg[off:end])
363	return s, end, nil
364}
365
366func packStringHex(s string, msg []byte, off int) (int, error) {
367	h, err := hex.DecodeString(s)
368	if err != nil {
369		return len(msg), err
370	}
371	if off+len(h) > len(msg) {
372		return len(msg), &Error{err: "overflow packing hex"}
373	}
374	copy(msg[off:off+len(h)], h)
375	off += len(h)
376	return off, nil
377}
378
379func unpackStringAny(msg []byte, off, end int) (string, int, error) {
380	if end > len(msg) {
381		return "", len(msg), &Error{err: "overflow unpacking anything"}
382	}
383	return string(msg[off:end]), end, nil
384}
385
386func packStringAny(s string, msg []byte, off int) (int, error) {
387	if off+len(s) > len(msg) {
388		return len(msg), &Error{err: "overflow packing anything"}
389	}
390	copy(msg[off:off+len(s)], s)
391	off += len(s)
392	return off, nil
393}
394
395func unpackStringTxt(msg []byte, off int) ([]string, int, error) {
396	txt, off, err := unpackTxt(msg, off)
397	if err != nil {
398		return nil, len(msg), err
399	}
400	return txt, off, nil
401}
402
403func packStringTxt(s []string, msg []byte, off int) (int, error) {
404	txtTmp := make([]byte, 256*4+1) // If the whole string consists out of \DDD we need this many.
405	off, err := packTxt(s, msg, off, txtTmp)
406	if err != nil {
407		return len(msg), err
408	}
409	return off, nil
410}
411
412func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) {
413	var edns []EDNS0
414Option:
415	var code uint16
416	if off+4 > len(msg) {
417		return nil, len(msg), &Error{err: "overflow unpacking opt"}
418	}
419	code = binary.BigEndian.Uint16(msg[off:])
420	off += 2
421	optlen := binary.BigEndian.Uint16(msg[off:])
422	off += 2
423	if off+int(optlen) > len(msg) {
424		return nil, len(msg), &Error{err: "overflow unpacking opt"}
425	}
426	switch code {
427	case EDNS0NSID:
428		e := new(EDNS0_NSID)
429		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
430			return nil, len(msg), err
431		}
432		edns = append(edns, e)
433		off += int(optlen)
434	case EDNS0SUBNET:
435		e := new(EDNS0_SUBNET)
436		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
437			return nil, len(msg), err
438		}
439		edns = append(edns, e)
440		off += int(optlen)
441	case EDNS0COOKIE:
442		e := new(EDNS0_COOKIE)
443		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
444			return nil, len(msg), err
445		}
446		edns = append(edns, e)
447		off += int(optlen)
448	case EDNS0EXPIRE:
449		e := new(EDNS0_EXPIRE)
450		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
451			return nil, len(msg), err
452		}
453		edns = append(edns, e)
454		off += int(optlen)
455	case EDNS0UL:
456		e := new(EDNS0_UL)
457		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
458			return nil, len(msg), err
459		}
460		edns = append(edns, e)
461		off += int(optlen)
462	case EDNS0LLQ:
463		e := new(EDNS0_LLQ)
464		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
465			return nil, len(msg), err
466		}
467		edns = append(edns, e)
468		off += int(optlen)
469	case EDNS0DAU:
470		e := new(EDNS0_DAU)
471		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
472			return nil, len(msg), err
473		}
474		edns = append(edns, e)
475		off += int(optlen)
476	case EDNS0DHU:
477		e := new(EDNS0_DHU)
478		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
479			return nil, len(msg), err
480		}
481		edns = append(edns, e)
482		off += int(optlen)
483	case EDNS0N3U:
484		e := new(EDNS0_N3U)
485		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
486			return nil, len(msg), err
487		}
488		edns = append(edns, e)
489		off += int(optlen)
490	case EDNS0PADDING:
491		e := new(EDNS0_PADDING)
492		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
493			return nil, len(msg), err
494		}
495		edns = append(edns, e)
496		off += int(optlen)
497	default:
498		e := new(EDNS0_LOCAL)
499		e.Code = code
500		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
501			return nil, len(msg), err
502		}
503		edns = append(edns, e)
504		off += int(optlen)
505	}
506
507	if off < len(msg) {
508		goto Option
509	}
510
511	return edns, off, nil
512}
513
514func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
515	for _, el := range options {
516		b, err := el.pack()
517		if err != nil || off+4 > len(msg) {
518			return len(msg), &Error{err: "overflow packing opt"}
519		}
520		binary.BigEndian.PutUint16(msg[off:], el.Option())      // Option code
521		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
522		off += 4
523		if off+len(b) > len(msg) {
524			copy(msg[off:], b)
525			off = len(msg)
526			continue
527		}
528		// Actual data
529		copy(msg[off:off+len(b)], b)
530		off += len(b)
531	}
532	return off, nil
533}
534
535func unpackStringOctet(msg []byte, off int) (string, int, error) {
536	s := string(msg[off:])
537	return s, len(msg), nil
538}
539
540func packStringOctet(s string, msg []byte, off int) (int, error) {
541	txtTmp := make([]byte, 256*4+1)
542	off, err := packOctetString(s, msg, off, txtTmp)
543	if err != nil {
544		return len(msg), err
545	}
546	return off, nil
547}
548
549func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
550	var nsec []uint16
551	length, window, lastwindow := 0, 0, -1
552	for off < len(msg) {
553		if off+2 > len(msg) {
554			return nsec, len(msg), &Error{err: "overflow unpacking nsecx"}
555		}
556		window = int(msg[off])
557		length = int(msg[off+1])
558		off += 2
559		if window <= lastwindow {
560			// RFC 4034: Blocks are present in the NSEC RR RDATA in
561			// increasing numerical order.
562			return nsec, len(msg), &Error{err: "out of order NSEC block"}
563		}
564		if length == 0 {
565			// RFC 4034: Blocks with no types present MUST NOT be included.
566			return nsec, len(msg), &Error{err: "empty NSEC block"}
567		}
568		if length > 32 {
569			return nsec, len(msg), &Error{err: "NSEC block too long"}
570		}
571		if off+length > len(msg) {
572			return nsec, len(msg), &Error{err: "overflowing NSEC block"}
573		}
574
575		// Walk the bytes in the window and extract the type bits
576		for j, b := range msg[off : off+length] {
577			// Check the bits one by one, and set the type
578			if b&0x80 == 0x80 {
579				nsec = append(nsec, uint16(window*256+j*8+0))
580			}
581			if b&0x40 == 0x40 {
582				nsec = append(nsec, uint16(window*256+j*8+1))
583			}
584			if b&0x20 == 0x20 {
585				nsec = append(nsec, uint16(window*256+j*8+2))
586			}
587			if b&0x10 == 0x10 {
588				nsec = append(nsec, uint16(window*256+j*8+3))
589			}
590			if b&0x8 == 0x8 {
591				nsec = append(nsec, uint16(window*256+j*8+4))
592			}
593			if b&0x4 == 0x4 {
594				nsec = append(nsec, uint16(window*256+j*8+5))
595			}
596			if b&0x2 == 0x2 {
597				nsec = append(nsec, uint16(window*256+j*8+6))
598			}
599			if b&0x1 == 0x1 {
600				nsec = append(nsec, uint16(window*256+j*8+7))
601			}
602		}
603		off += length
604		lastwindow = window
605	}
606	return nsec, off, nil
607}
608
// typeBitMapLen computes an upper bound on the number of octets packDataNsec
// needs to encode bitmap. It mirrors packDataNsec's window/length walk, but
// skips out-of-order types instead of failing, because the aim here is only to
// be liberal when sizing. An empty bitmap still yields 2 (one window header).
func typeBitMapLen(bitmap []uint16) int {
	total := 0
	var curWindow, curOctets uint16
	for _, rrtype := range bitmap {
		w := rrtype >> 8
		octets := (rrtype&0xff)/8 + 1
		if w > curWindow && curOctets != 0 {
			// Moving to a new window: account for the finished one.
			total += int(curOctets) + 2
			curOctets = 0
		}
		if w < curWindow || octets < curOctets {
			continue
		}
		curWindow, curOctets = w, octets
	}
	return total + int(curOctets) + 2
}
631
632func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
633	if len(bitmap) == 0 {
634		return off, nil
635	}
636	var lastwindow, lastlength uint16
637	for _, t := range bitmap {
638		window := t / 256
639		length := (t-window*256)/8 + 1
640		if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
641			off += int(lastlength) + 2
642			lastlength = 0
643		}
644		if window < lastwindow || length < lastlength {
645			return len(msg), &Error{err: "nsec bits out of order"}
646		}
647		if off+2+int(length) > len(msg) {
648			return len(msg), &Error{err: "overflow packing nsec"}
649		}
650		// Setting the window #
651		msg[off] = byte(window)
652		// Setting the octets length
653		msg[off+1] = byte(length)
654		// Setting the bit value for the type in the right octet
655		msg[off+1+int(length)] |= byte(1 << (7 - t%8))
656		lastwindow, lastlength = window, length
657	}
658	off += int(lastlength) + 2
659	return off, nil
660}
661
662func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
663	var (
664		servers []string
665		s       string
666		err     error
667	)
668	if end > len(msg) {
669		return nil, len(msg), &Error{err: "overflow unpacking domain names"}
670	}
671	for off < end {
672		s, off, err = UnpackDomainName(msg, off)
673		if err != nil {
674			return servers, len(msg), err
675		}
676		servers = append(servers, s)
677	}
678	return servers, off, nil
679}
680
681func packDataDomainNames(names []string, msg []byte, off int, compression compressionMap, compress bool) (int, error) {
682	var err error
683	for _, name := range names {
684		off, err = packDomainName(name, msg, off, compression, compress)
685		if err != nil {
686			return len(msg), err
687		}
688	}
689	return off, nil
690}
691