1package dns
2
3import (
4	"encoding/base32"
5	"encoding/base64"
6	"encoding/binary"
7	"encoding/hex"
8	"net"
9	"strings"
10)
11
// Helper functions called from the generated zmsg.go.

// These functions are named after the struct tag they help pack/unpack; if there is no tag, they are
// named after the type they pack/unpack (string, int, etc). All start with pack or unpack, and several
// use the packData/unpackData prefix, e.g. packDataA or packDataDomainNames.
17
18func unpackDataA(msg []byte, off int) (net.IP, int, error) {
19	if off+net.IPv4len > len(msg) {
20		return nil, len(msg), &Error{err: "overflow unpacking a"}
21	}
22	a := append(make(net.IP, 0, net.IPv4len), msg[off:off+net.IPv4len]...)
23	off += net.IPv4len
24	return a, off, nil
25}
26
27func packDataA(a net.IP, msg []byte, off int) (int, error) {
28	switch len(a) {
29	case net.IPv4len, net.IPv6len:
30		// It must be a slice of 4, even if it is 16, we encode only the first 4
31		if off+net.IPv4len > len(msg) {
32			return len(msg), &Error{err: "overflow packing a"}
33		}
34
35		copy(msg[off:], a.To4())
36		off += net.IPv4len
37	case 0:
38		// Allowed, for dynamic updates.
39	default:
40		return len(msg), &Error{err: "overflow packing a"}
41	}
42	return off, nil
43}
44
45func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) {
46	if off+net.IPv6len > len(msg) {
47		return nil, len(msg), &Error{err: "overflow unpacking aaaa"}
48	}
49	aaaa := append(make(net.IP, 0, net.IPv6len), msg[off:off+net.IPv6len]...)
50	off += net.IPv6len
51	return aaaa, off, nil
52}
53
54func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) {
55	switch len(aaaa) {
56	case net.IPv6len:
57		if off+net.IPv6len > len(msg) {
58			return len(msg), &Error{err: "overflow packing aaaa"}
59		}
60
61		copy(msg[off:], aaaa)
62		off += net.IPv6len
63	case 0:
64		// Allowed, dynamic updates.
65	default:
66		return len(msg), &Error{err: "overflow packing aaaa"}
67	}
68	return off, nil
69}
70
71// unpackHeader unpacks an RR header, returning the offset to the end of the header and a
72// re-sliced msg according to the expected length of the RR.
73func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, err error) {
74	hdr := RR_Header{}
75	if off == len(msg) {
76		return hdr, off, msg, nil
77	}
78
79	hdr.Name, off, err = UnpackDomainName(msg, off)
80	if err != nil {
81		return hdr, len(msg), msg, err
82	}
83	hdr.Rrtype, off, err = unpackUint16(msg, off)
84	if err != nil {
85		return hdr, len(msg), msg, err
86	}
87	hdr.Class, off, err = unpackUint16(msg, off)
88	if err != nil {
89		return hdr, len(msg), msg, err
90	}
91	hdr.Ttl, off, err = unpackUint32(msg, off)
92	if err != nil {
93		return hdr, len(msg), msg, err
94	}
95	hdr.Rdlength, off, err = unpackUint16(msg, off)
96	if err != nil {
97		return hdr, len(msg), msg, err
98	}
99	msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength)
100	return hdr, off, msg, err
101}
102
103// packHeader packs an RR header, returning the offset to the end of the header.
104// See PackDomainName for documentation about the compression.
105func (hdr RR_Header) packHeader(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
106	if off == len(msg) {
107		return off, nil
108	}
109
110	off, err := packDomainName(hdr.Name, msg, off, compression, compress)
111	if err != nil {
112		return len(msg), err
113	}
114	off, err = packUint16(hdr.Rrtype, msg, off)
115	if err != nil {
116		return len(msg), err
117	}
118	off, err = packUint16(hdr.Class, msg, off)
119	if err != nil {
120		return len(msg), err
121	}
122	off, err = packUint32(hdr.Ttl, msg, off)
123	if err != nil {
124		return len(msg), err
125	}
126	off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR.
127	if err != nil {
128		return len(msg), err
129	}
130	return off, nil
131}
132
// Lower-level helper functions.
134
135// truncateMsgFromRdLength truncates msg to match the expected length of the RR.
136// Returns an error if msg is smaller than the expected size.
137func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []byte, err error) {
138	lenrd := off + int(rdlength)
139	if lenrd > len(msg) {
140		return msg, &Error{err: "overflowing header size"}
141	}
142	return msg[:lenrd], nil
143}
144
var (
	// base32HexNoPadEncoding is base32 with the "Extended Hex Alphabet"
	// (0-9, A-V) and no '=' padding; fromBase32 and toBase32 use it.
	base32HexNoPadEncoding = base32.HexEncoding.WithPadding(base32.NoPadding)
)
146
147func fromBase32(s []byte) (buf []byte, err error) {
148	for i, b := range s {
149		if b >= 'a' && b <= 'z' {
150			s[i] = b - 32
151		}
152	}
153	buflen := base32HexNoPadEncoding.DecodedLen(len(s))
154	buf = make([]byte, buflen)
155	n, err := base32HexNoPadEncoding.Decode(buf, s)
156	buf = buf[:n]
157	return
158}
159
160func toBase32(b []byte) string {
161	return base32HexNoPadEncoding.EncodeToString(b)
162}
163
// fromBase64 decodes s as standard, padded base64. On a decoding error, the
// bytes decoded so far are returned together with the error.
func fromBase64(s []byte) (buf []byte, err error) {
	enc := base64.StdEncoding
	buf = make([]byte, enc.DecodedLen(len(s)))
	var n int
	n, err = enc.Decode(buf, s)
	return buf[:n], err
}
171
// toBase64 encodes b as standard, padded base64.
func toBase64(b []byte) string {
	dst := make([]byte, base64.StdEncoding.EncodedLen(len(b)))
	base64.StdEncoding.Encode(dst, b)
	return string(dst)
}
173
174// dynamicUpdate returns true if the Rdlength is zero.
175func noRdata(h RR_Header) bool { return h.Rdlength == 0 }
176
177func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) {
178	if off+1 > len(msg) {
179		return 0, len(msg), &Error{err: "overflow unpacking uint8"}
180	}
181	return msg[off], off + 1, nil
182}
183
184func packUint8(i uint8, msg []byte, off int) (off1 int, err error) {
185	if off+1 > len(msg) {
186		return len(msg), &Error{err: "overflow packing uint8"}
187	}
188	msg[off] = i
189	return off + 1, nil
190}
191
192func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) {
193	if off+2 > len(msg) {
194		return 0, len(msg), &Error{err: "overflow unpacking uint16"}
195	}
196	return binary.BigEndian.Uint16(msg[off:]), off + 2, nil
197}
198
199func packUint16(i uint16, msg []byte, off int) (off1 int, err error) {
200	if off+2 > len(msg) {
201		return len(msg), &Error{err: "overflow packing uint16"}
202	}
203	binary.BigEndian.PutUint16(msg[off:], i)
204	return off + 2, nil
205}
206
207func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) {
208	if off+4 > len(msg) {
209		return 0, len(msg), &Error{err: "overflow unpacking uint32"}
210	}
211	return binary.BigEndian.Uint32(msg[off:]), off + 4, nil
212}
213
214func packUint32(i uint32, msg []byte, off int) (off1 int, err error) {
215	if off+4 > len(msg) {
216		return len(msg), &Error{err: "overflow packing uint32"}
217	}
218	binary.BigEndian.PutUint32(msg[off:], i)
219	return off + 4, nil
220}
221
222func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) {
223	if off+6 > len(msg) {
224		return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"}
225	}
226	// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
227	i = uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
228		uint64(msg[off+4])<<8 | uint64(msg[off+5])
229	off += 6
230	return i, off, nil
231}
232
233func packUint48(i uint64, msg []byte, off int) (off1 int, err error) {
234	if off+6 > len(msg) {
235		return len(msg), &Error{err: "overflow packing uint64 as uint48"}
236	}
237	msg[off] = byte(i >> 40)
238	msg[off+1] = byte(i >> 32)
239	msg[off+2] = byte(i >> 24)
240	msg[off+3] = byte(i >> 16)
241	msg[off+4] = byte(i >> 8)
242	msg[off+5] = byte(i)
243	off += 6
244	return off, nil
245}
246
247func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) {
248	if off+8 > len(msg) {
249		return 0, len(msg), &Error{err: "overflow unpacking uint64"}
250	}
251	return binary.BigEndian.Uint64(msg[off:]), off + 8, nil
252}
253
254func packUint64(i uint64, msg []byte, off int) (off1 int, err error) {
255	if off+8 > len(msg) {
256		return len(msg), &Error{err: "overflow packing uint64"}
257	}
258	binary.BigEndian.PutUint64(msg[off:], i)
259	off += 8
260	return off, nil
261}
262
// unpackString reads one length-prefixed character-string at msg[off:]: a
// single length octet followed by that many bytes. The bytes are returned in
// escaped form. '"' and '\\' are prefixed with a backslash, and bytes outside
// the printable ASCII range ' '..'~' are replaced by escapeByte(b). It returns
// the string and the offset just past the data.
//
// NOTE(review): unlike most helpers in this file, the overflow paths return the
// current offset rather than len(msg); confirm callers do not rely on either.
func unpackString(msg []byte, off int) (string, int, error) {
	if off+1 > len(msg) {
		return "", off, &Error{err: "overflow unpacking txt"}
	}
	l := int(msg[off])
	off++
	if off+l > len(msg) {
		return "", off, &Error{err: "overflow unpacking txt"}
	}
	var s strings.Builder
	// consumed is the index (relative to off) of the first data byte not yet
	// copied into s. It stays 0 until the first byte that needs escaping, so
	// the builder is only used (and grown) when escaping is required.
	consumed := 0
	for i, b := range msg[off : off+l] {
		switch {
		case b == '"' || b == '\\':
			if consumed == 0 {
				s.Grow(l * 2)
			}
			// Flush the unescaped run before this byte, then the escaped byte.
			s.Write(msg[off+consumed : off+i])
			s.WriteByte('\\')
			s.WriteByte(b)
			consumed = i + 1
		case b < ' ' || b > '~': // unprintable
			if consumed == 0 {
				s.Grow(l * 2)
			}
			s.Write(msg[off+consumed : off+i])
			s.WriteString(escapeByte(b))
			consumed = i + 1
		}
	}
	if consumed == 0 { // no escaping needed
		return string(msg[off : off+l]), off + l, nil
	}
	// Append the unescaped tail after the last escaped byte.
	s.Write(msg[off+consumed : off+l])
	return s.String(), off + l, nil
}
299
300func packString(s string, msg []byte, off int) (int, error) {
301	txtTmp := make([]byte, 256*4+1)
302	off, err := packTxtString(s, msg, off, txtTmp)
303	if err != nil {
304		return len(msg), err
305	}
306	return off, nil
307}
308
309func unpackStringBase32(msg []byte, off, end int) (string, int, error) {
310	if end > len(msg) {
311		return "", len(msg), &Error{err: "overflow unpacking base32"}
312	}
313	s := toBase32(msg[off:end])
314	return s, end, nil
315}
316
317func packStringBase32(s string, msg []byte, off int) (int, error) {
318	b32, err := fromBase32([]byte(s))
319	if err != nil {
320		return len(msg), err
321	}
322	if off+len(b32) > len(msg) {
323		return len(msg), &Error{err: "overflow packing base32"}
324	}
325	copy(msg[off:off+len(b32)], b32)
326	off += len(b32)
327	return off, nil
328}
329
330func unpackStringBase64(msg []byte, off, end int) (string, int, error) {
331	// Rest of the RR is base64 encoded value, so we don't need an explicit length
332	// to be set. Thus far all RR's that have base64 encoded fields have those as their
333	// last one. What we do need is the end of the RR!
334	if end > len(msg) {
335		return "", len(msg), &Error{err: "overflow unpacking base64"}
336	}
337	s := toBase64(msg[off:end])
338	return s, end, nil
339}
340
341func packStringBase64(s string, msg []byte, off int) (int, error) {
342	b64, err := fromBase64([]byte(s))
343	if err != nil {
344		return len(msg), err
345	}
346	if off+len(b64) > len(msg) {
347		return len(msg), &Error{err: "overflow packing base64"}
348	}
349	copy(msg[off:off+len(b64)], b64)
350	off += len(b64)
351	return off, nil
352}
353
354func unpackStringHex(msg []byte, off, end int) (string, int, error) {
355	// Rest of the RR is hex encoded value, so we don't need an explicit length
356	// to be set. NSEC and TSIG have hex fields with a length field.
357	// What we do need is the end of the RR!
358	if end > len(msg) {
359		return "", len(msg), &Error{err: "overflow unpacking hex"}
360	}
361
362	s := hex.EncodeToString(msg[off:end])
363	return s, end, nil
364}
365
366func packStringHex(s string, msg []byte, off int) (int, error) {
367	h, err := hex.DecodeString(s)
368	if err != nil {
369		return len(msg), err
370	}
371	if off+len(h) > len(msg) {
372		return len(msg), &Error{err: "overflow packing hex"}
373	}
374	copy(msg[off:off+len(h)], h)
375	off += len(h)
376	return off, nil
377}
378
379func unpackStringAny(msg []byte, off, end int) (string, int, error) {
380	if end > len(msg) {
381		return "", len(msg), &Error{err: "overflow unpacking anything"}
382	}
383	return string(msg[off:end]), end, nil
384}
385
386func packStringAny(s string, msg []byte, off int) (int, error) {
387	if off+len(s) > len(msg) {
388		return len(msg), &Error{err: "overflow packing anything"}
389	}
390	copy(msg[off:off+len(s)], s)
391	off += len(s)
392	return off, nil
393}
394
395func unpackStringTxt(msg []byte, off int) ([]string, int, error) {
396	txt, off, err := unpackTxt(msg, off)
397	if err != nil {
398		return nil, len(msg), err
399	}
400	return txt, off, nil
401}
402
403func packStringTxt(s []string, msg []byte, off int) (int, error) {
404	txtTmp := make([]byte, 256*4+1) // If the whole string consists out of \DDD we need this many.
405	off, err := packTxt(s, msg, off, txtTmp)
406	if err != nil {
407		return len(msg), err
408	}
409	return off, nil
410}
411
412func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) {
413	var edns []EDNS0
414Option:
415	var code uint16
416	if off+4 > len(msg) {
417		return nil, len(msg), &Error{err: "overflow unpacking opt"}
418	}
419	code = binary.BigEndian.Uint16(msg[off:])
420	off += 2
421	optlen := binary.BigEndian.Uint16(msg[off:])
422	off += 2
423	if off+int(optlen) > len(msg) {
424		return nil, len(msg), &Error{err: "overflow unpacking opt"}
425	}
426	switch code {
427	case EDNS0NSID:
428		e := new(EDNS0_NSID)
429		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
430			return nil, len(msg), err
431		}
432		edns = append(edns, e)
433		off += int(optlen)
434	case EDNS0SUBNET:
435		e := new(EDNS0_SUBNET)
436		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
437			return nil, len(msg), err
438		}
439		edns = append(edns, e)
440		off += int(optlen)
441	case EDNS0COOKIE:
442		e := new(EDNS0_COOKIE)
443		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
444			return nil, len(msg), err
445		}
446		edns = append(edns, e)
447		off += int(optlen)
448	case EDNS0UL:
449		e := new(EDNS0_UL)
450		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
451			return nil, len(msg), err
452		}
453		edns = append(edns, e)
454		off += int(optlen)
455	case EDNS0LLQ:
456		e := new(EDNS0_LLQ)
457		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
458			return nil, len(msg), err
459		}
460		edns = append(edns, e)
461		off += int(optlen)
462	case EDNS0DAU:
463		e := new(EDNS0_DAU)
464		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
465			return nil, len(msg), err
466		}
467		edns = append(edns, e)
468		off += int(optlen)
469	case EDNS0DHU:
470		e := new(EDNS0_DHU)
471		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
472			return nil, len(msg), err
473		}
474		edns = append(edns, e)
475		off += int(optlen)
476	case EDNS0N3U:
477		e := new(EDNS0_N3U)
478		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
479			return nil, len(msg), err
480		}
481		edns = append(edns, e)
482		off += int(optlen)
483	case EDNS0PADDING:
484		e := new(EDNS0_PADDING)
485		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
486			return nil, len(msg), err
487		}
488		edns = append(edns, e)
489		off += int(optlen)
490	default:
491		e := new(EDNS0_LOCAL)
492		e.Code = code
493		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
494			return nil, len(msg), err
495		}
496		edns = append(edns, e)
497		off += int(optlen)
498	}
499
500	if off < len(msg) {
501		goto Option
502	}
503
504	return edns, off, nil
505}
506
507func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
508	for _, el := range options {
509		b, err := el.pack()
510		if err != nil || off+4 > len(msg) {
511			return len(msg), &Error{err: "overflow packing opt"}
512		}
513		binary.BigEndian.PutUint16(msg[off:], el.Option())      // Option code
514		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
515		off += 4
516		if off+len(b) > len(msg) {
517			copy(msg[off:], b)
518			off = len(msg)
519			continue
520		}
521		// Actual data
522		copy(msg[off:off+len(b)], b)
523		off += len(b)
524	}
525	return off, nil
526}
527
528func unpackStringOctet(msg []byte, off int) (string, int, error) {
529	s := string(msg[off:])
530	return s, len(msg), nil
531}
532
533func packStringOctet(s string, msg []byte, off int) (int, error) {
534	txtTmp := make([]byte, 256*4+1)
535	off, err := packOctetString(s, msg, off, txtTmp)
536	if err != nil {
537		return len(msg), err
538	}
539	return off, nil
540}
541
542func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
543	var nsec []uint16
544	length, window, lastwindow := 0, 0, -1
545	for off < len(msg) {
546		if off+2 > len(msg) {
547			return nsec, len(msg), &Error{err: "overflow unpacking nsecx"}
548		}
549		window = int(msg[off])
550		length = int(msg[off+1])
551		off += 2
552		if window <= lastwindow {
553			// RFC 4034: Blocks are present in the NSEC RR RDATA in
554			// increasing numerical order.
555			return nsec, len(msg), &Error{err: "out of order NSEC block"}
556		}
557		if length == 0 {
558			// RFC 4034: Blocks with no types present MUST NOT be included.
559			return nsec, len(msg), &Error{err: "empty NSEC block"}
560		}
561		if length > 32 {
562			return nsec, len(msg), &Error{err: "NSEC block too long"}
563		}
564		if off+length > len(msg) {
565			return nsec, len(msg), &Error{err: "overflowing NSEC block"}
566		}
567
568		// Walk the bytes in the window and extract the type bits
569		for j, b := range msg[off : off+length] {
570			// Check the bits one by one, and set the type
571			if b&0x80 == 0x80 {
572				nsec = append(nsec, uint16(window*256+j*8+0))
573			}
574			if b&0x40 == 0x40 {
575				nsec = append(nsec, uint16(window*256+j*8+1))
576			}
577			if b&0x20 == 0x20 {
578				nsec = append(nsec, uint16(window*256+j*8+2))
579			}
580			if b&0x10 == 0x10 {
581				nsec = append(nsec, uint16(window*256+j*8+3))
582			}
583			if b&0x8 == 0x8 {
584				nsec = append(nsec, uint16(window*256+j*8+4))
585			}
586			if b&0x4 == 0x4 {
587				nsec = append(nsec, uint16(window*256+j*8+5))
588			}
589			if b&0x2 == 0x2 {
590				nsec = append(nsec, uint16(window*256+j*8+6))
591			}
592			if b&0x1 == 0x1 {
593				nsec = append(nsec, uint16(window*256+j*8+7))
594			}
595		}
596		off += length
597		lastwindow = window
598	}
599	return nsec, off, nil
600}
601
// typeBitMapLen computes the "maximum" length of the NSEC Type Bit Maps field
// for bitmap: for each window, 2 header octets plus the octets up to the
// highest type seen in it. It is liberal: types that go backwards (which
// packDataNsec would reject with "nsec bits out of order") are skipped rather
// than reported. An empty bitmap still counts one 2-octet header.
func typeBitMapLen(bitmap []uint16) int {
	total := 0
	var curWindow, curLen uint16
	for _, typ := range bitmap {
		w := typ / 256
		n := (typ%256)/8 + 1
		if w > curWindow && curLen != 0 {
			// New window: account for the one just finished.
			total += int(curLen) + 2
			curLen = 0
		}
		if w >= curWindow && n >= curLen {
			curWindow, curLen = w, n
		}
	}
	return total + int(curLen) + 2
}
624
625func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
626	if len(bitmap) == 0 {
627		return off, nil
628	}
629	var lastwindow, lastlength uint16
630	for _, t := range bitmap {
631		window := t / 256
632		length := (t-window*256)/8 + 1
633		if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
634			off += int(lastlength) + 2
635			lastlength = 0
636		}
637		if window < lastwindow || length < lastlength {
638			return len(msg), &Error{err: "nsec bits out of order"}
639		}
640		if off+2+int(length) > len(msg) {
641			return len(msg), &Error{err: "overflow packing nsec"}
642		}
643		// Setting the window #
644		msg[off] = byte(window)
645		// Setting the octets length
646		msg[off+1] = byte(length)
647		// Setting the bit value for the type in the right octet
648		msg[off+1+int(length)] |= byte(1 << (7 - t%8))
649		lastwindow, lastlength = window, length
650	}
651	off += int(lastlength) + 2
652	return off, nil
653}
654
655func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
656	var (
657		servers []string
658		s       string
659		err     error
660	)
661	if end > len(msg) {
662		return nil, len(msg), &Error{err: "overflow unpacking domain names"}
663	}
664	for off < end {
665		s, off, err = UnpackDomainName(msg, off)
666		if err != nil {
667			return servers, len(msg), err
668		}
669		servers = append(servers, s)
670	}
671	return servers, off, nil
672}
673
674func packDataDomainNames(names []string, msg []byte, off int, compression compressionMap, compress bool) (int, error) {
675	var err error
676	for _, name := range names {
677		off, err = packDomainName(name, msg, off, compression, compress)
678		if err != nil {
679			return len(msg), err
680		}
681	}
682	return off, nil
683}
684