1package dns
2
3import (
4	"encoding/base32"
5	"encoding/base64"
6	"encoding/binary"
7	"encoding/hex"
8	"net"
9	"strings"
10)
11
12// helper functions called from the generated zmsg.go
13
// These functions are named after the tag they help pack/unpack; if there is no tag, they are
// named after the type they pack/unpack (string, int, etc). We prefix all of them with unpackData
// or packData, e.g. packDataA or packDataDomainName.
17
18func unpackDataA(msg []byte, off int) (net.IP, int, error) {
19	if off+net.IPv4len > len(msg) {
20		return nil, len(msg), &Error{err: "overflow unpacking a"}
21	}
22	a := append(make(net.IP, 0, net.IPv4len), msg[off:off+net.IPv4len]...)
23	off += net.IPv4len
24	return a, off, nil
25}
26
27func packDataA(a net.IP, msg []byte, off int) (int, error) {
28	// It must be a slice of 4, even if it is 16, we encode only the first 4
29	if off+net.IPv4len > len(msg) {
30		return len(msg), &Error{err: "overflow packing a"}
31	}
32	switch len(a) {
33	case net.IPv4len, net.IPv6len:
34		copy(msg[off:], a.To4())
35		off += net.IPv4len
36	case 0:
37		// Allowed, for dynamic updates.
38	default:
39		return len(msg), &Error{err: "overflow packing a"}
40	}
41	return off, nil
42}
43
44func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) {
45	if off+net.IPv6len > len(msg) {
46		return nil, len(msg), &Error{err: "overflow unpacking aaaa"}
47	}
48	aaaa := append(make(net.IP, 0, net.IPv6len), msg[off:off+net.IPv6len]...)
49	off += net.IPv6len
50	return aaaa, off, nil
51}
52
53func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) {
54	if off+net.IPv6len > len(msg) {
55		return len(msg), &Error{err: "overflow packing aaaa"}
56	}
57
58	switch len(aaaa) {
59	case net.IPv6len:
60		copy(msg[off:], aaaa)
61		off += net.IPv6len
62	case 0:
63		// Allowed, dynamic updates.
64	default:
65		return len(msg), &Error{err: "overflow packing aaaa"}
66	}
67	return off, nil
68}
69
70// unpackHeader unpacks an RR header, returning the offset to the end of the header and a
71// re-sliced msg according to the expected length of the RR.
72func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, err error) {
73	hdr := RR_Header{}
74	if off == len(msg) {
75		return hdr, off, msg, nil
76	}
77
78	hdr.Name, off, err = UnpackDomainName(msg, off)
79	if err != nil {
80		return hdr, len(msg), msg, err
81	}
82	hdr.Rrtype, off, err = unpackUint16(msg, off)
83	if err != nil {
84		return hdr, len(msg), msg, err
85	}
86	hdr.Class, off, err = unpackUint16(msg, off)
87	if err != nil {
88		return hdr, len(msg), msg, err
89	}
90	hdr.Ttl, off, err = unpackUint32(msg, off)
91	if err != nil {
92		return hdr, len(msg), msg, err
93	}
94	hdr.Rdlength, off, err = unpackUint16(msg, off)
95	if err != nil {
96		return hdr, len(msg), msg, err
97	}
98	msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength)
99	return hdr, off, msg, err
100}
101
102// packHeader packs an RR header, returning the offset to the end of the header.
103// See PackDomainName for documentation about the compression.
104func (hdr RR_Header) packHeader(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
105	if off == len(msg) {
106		return off, nil
107	}
108
109	off, err := packDomainName(hdr.Name, msg, off, compression, compress)
110	if err != nil {
111		return len(msg), err
112	}
113	off, err = packUint16(hdr.Rrtype, msg, off)
114	if err != nil {
115		return len(msg), err
116	}
117	off, err = packUint16(hdr.Class, msg, off)
118	if err != nil {
119		return len(msg), err
120	}
121	off, err = packUint32(hdr.Ttl, msg, off)
122	if err != nil {
123		return len(msg), err
124	}
125	off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR.
126	if err != nil {
127		return len(msg), err
128	}
129	return off, nil
130}
131
132// helper helper functions.
133
134// truncateMsgFromRdLength truncates msg to match the expected length of the RR.
135// Returns an error if msg is smaller than the expected size.
136func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []byte, err error) {
137	lenrd := off + int(rdlength)
138	if lenrd > len(msg) {
139		return msg, &Error{err: "overflowing header size"}
140	}
141	return msg[:lenrd], nil
142}
143
// base32HexNoPadEncoding is base32.HexEncoding with '=' padding disabled.
var base32HexNoPadEncoding = base32.HexEncoding.WithPadding(base32.NoPadding)

// fromBase32 decodes s, which holds unpadded base32hex text, and returns the
// decoded bytes. ASCII lower-case letters in s are accepted and treated as
// their upper-case equivalents. The caller's slice is never modified: a
// private copy is made the first time a lower-case letter has to be folded.
// On a decoding error the bytes decoded so far are returned with the error.
func fromBase32(s []byte) (buf []byte, err error) {
	src := s
	copied := false
	for i, b := range s {
		if b >= 'a' && b <= 'z' {
			if !copied {
				// Copy lazily so all-upper-case input costs no allocation.
				src = append([]byte(nil), s...)
				copied = true
			}
			src[i] = b - ('a' - 'A')
		}
	}
	buf = make([]byte, base32HexNoPadEncoding.DecodedLen(len(src)))
	n, err := base32HexNoPadEncoding.Decode(buf, src)
	return buf[:n], err
}
158
159func toBase32(b []byte) string {
160	return base32HexNoPadEncoding.EncodeToString(b)
161}
162
// fromBase64 decodes s as standard (padded) base64 and returns the decoded
// bytes. On a decoding error the bytes decoded so far are returned together
// with the error.
func fromBase64(s []byte) (buf []byte, err error) {
	decoded := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
	written, decErr := base64.StdEncoding.Decode(decoded, s)
	return decoded[:written], decErr
}
170
// toBase64 returns b encoded as standard (padded) base64 text.
func toBase64(b []byte) string {
	enc := base64.StdEncoding
	return enc.EncodeToString(b)
}
172
173// dynamicUpdate returns true if the Rdlength is zero.
174func noRdata(h RR_Header) bool { return h.Rdlength == 0 }
175
176func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) {
177	if off+1 > len(msg) {
178		return 0, len(msg), &Error{err: "overflow unpacking uint8"}
179	}
180	return msg[off], off + 1, nil
181}
182
183func packUint8(i uint8, msg []byte, off int) (off1 int, err error) {
184	if off+1 > len(msg) {
185		return len(msg), &Error{err: "overflow packing uint8"}
186	}
187	msg[off] = i
188	return off + 1, nil
189}
190
191func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) {
192	if off+2 > len(msg) {
193		return 0, len(msg), &Error{err: "overflow unpacking uint16"}
194	}
195	return binary.BigEndian.Uint16(msg[off:]), off + 2, nil
196}
197
198func packUint16(i uint16, msg []byte, off int) (off1 int, err error) {
199	if off+2 > len(msg) {
200		return len(msg), &Error{err: "overflow packing uint16"}
201	}
202	binary.BigEndian.PutUint16(msg[off:], i)
203	return off + 2, nil
204}
205
206func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) {
207	if off+4 > len(msg) {
208		return 0, len(msg), &Error{err: "overflow unpacking uint32"}
209	}
210	return binary.BigEndian.Uint32(msg[off:]), off + 4, nil
211}
212
213func packUint32(i uint32, msg []byte, off int) (off1 int, err error) {
214	if off+4 > len(msg) {
215		return len(msg), &Error{err: "overflow packing uint32"}
216	}
217	binary.BigEndian.PutUint32(msg[off:], i)
218	return off + 4, nil
219}
220
221func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) {
222	if off+6 > len(msg) {
223		return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"}
224	}
225	// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
226	i = uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
227		uint64(msg[off+4])<<8 | uint64(msg[off+5])
228	off += 6
229	return i, off, nil
230}
231
232func packUint48(i uint64, msg []byte, off int) (off1 int, err error) {
233	if off+6 > len(msg) {
234		return len(msg), &Error{err: "overflow packing uint64 as uint48"}
235	}
236	msg[off] = byte(i >> 40)
237	msg[off+1] = byte(i >> 32)
238	msg[off+2] = byte(i >> 24)
239	msg[off+3] = byte(i >> 16)
240	msg[off+4] = byte(i >> 8)
241	msg[off+5] = byte(i)
242	off += 6
243	return off, nil
244}
245
246func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) {
247	if off+8 > len(msg) {
248		return 0, len(msg), &Error{err: "overflow unpacking uint64"}
249	}
250	return binary.BigEndian.Uint64(msg[off:]), off + 8, nil
251}
252
253func packUint64(i uint64, msg []byte, off int) (off1 int, err error) {
254	if off+8 > len(msg) {
255		return len(msg), &Error{err: "overflow packing uint64"}
256	}
257	binary.BigEndian.PutUint64(msg[off:], i)
258	off += 8
259	return off, nil
260}
261
262func unpackString(msg []byte, off int) (string, int, error) {
263	if off+1 > len(msg) {
264		return "", off, &Error{err: "overflow unpacking txt"}
265	}
266	l := int(msg[off])
267	if off+l+1 > len(msg) {
268		return "", off, &Error{err: "overflow unpacking txt"}
269	}
270	var s strings.Builder
271	s.Grow(l)
272	for _, b := range msg[off+1 : off+1+l] {
273		switch {
274		case b == '"' || b == '\\':
275			s.WriteByte('\\')
276			s.WriteByte(b)
277		case b < ' ' || b > '~': // unprintable
278			s.WriteString(escapeByte(b))
279		default:
280			s.WriteByte(b)
281		}
282	}
283	off += 1 + l
284	return s.String(), off, nil
285}
286
287func packString(s string, msg []byte, off int) (int, error) {
288	txtTmp := make([]byte, 256*4+1)
289	off, err := packTxtString(s, msg, off, txtTmp)
290	if err != nil {
291		return len(msg), err
292	}
293	return off, nil
294}
295
296func unpackStringBase32(msg []byte, off, end int) (string, int, error) {
297	if end > len(msg) {
298		return "", len(msg), &Error{err: "overflow unpacking base32"}
299	}
300	s := toBase32(msg[off:end])
301	return s, end, nil
302}
303
304func packStringBase32(s string, msg []byte, off int) (int, error) {
305	b32, err := fromBase32([]byte(s))
306	if err != nil {
307		return len(msg), err
308	}
309	if off+len(b32) > len(msg) {
310		return len(msg), &Error{err: "overflow packing base32"}
311	}
312	copy(msg[off:off+len(b32)], b32)
313	off += len(b32)
314	return off, nil
315}
316
317func unpackStringBase64(msg []byte, off, end int) (string, int, error) {
318	// Rest of the RR is base64 encoded value, so we don't need an explicit length
319	// to be set. Thus far all RR's that have base64 encoded fields have those as their
320	// last one. What we do need is the end of the RR!
321	if end > len(msg) {
322		return "", len(msg), &Error{err: "overflow unpacking base64"}
323	}
324	s := toBase64(msg[off:end])
325	return s, end, nil
326}
327
328func packStringBase64(s string, msg []byte, off int) (int, error) {
329	b64, err := fromBase64([]byte(s))
330	if err != nil {
331		return len(msg), err
332	}
333	if off+len(b64) > len(msg) {
334		return len(msg), &Error{err: "overflow packing base64"}
335	}
336	copy(msg[off:off+len(b64)], b64)
337	off += len(b64)
338	return off, nil
339}
340
341func unpackStringHex(msg []byte, off, end int) (string, int, error) {
342	// Rest of the RR is hex encoded value, so we don't need an explicit length
343	// to be set. NSEC and TSIG have hex fields with a length field.
344	// What we do need is the end of the RR!
345	if end > len(msg) {
346		return "", len(msg), &Error{err: "overflow unpacking hex"}
347	}
348
349	s := hex.EncodeToString(msg[off:end])
350	return s, end, nil
351}
352
353func packStringHex(s string, msg []byte, off int) (int, error) {
354	h, err := hex.DecodeString(s)
355	if err != nil {
356		return len(msg), err
357	}
358	if off+len(h) > len(msg) {
359		return len(msg), &Error{err: "overflow packing hex"}
360	}
361	copy(msg[off:off+len(h)], h)
362	off += len(h)
363	return off, nil
364}
365
366func unpackStringAny(msg []byte, off, end int) (string, int, error) {
367	if end > len(msg) {
368		return "", len(msg), &Error{err: "overflow unpacking anything"}
369	}
370	return string(msg[off:end]), end, nil
371}
372
373func packStringAny(s string, msg []byte, off int) (int, error) {
374	if off+len(s) > len(msg) {
375		return len(msg), &Error{err: "overflow packing anything"}
376	}
377	copy(msg[off:off+len(s)], s)
378	off += len(s)
379	return off, nil
380}
381
382func unpackStringTxt(msg []byte, off int) ([]string, int, error) {
383	txt, off, err := unpackTxt(msg, off)
384	if err != nil {
385		return nil, len(msg), err
386	}
387	return txt, off, nil
388}
389
390func packStringTxt(s []string, msg []byte, off int) (int, error) {
391	txtTmp := make([]byte, 256*4+1) // If the whole string consists out of \DDD we need this many.
392	off, err := packTxt(s, msg, off, txtTmp)
393	if err != nil {
394		return len(msg), err
395	}
396	return off, nil
397}
398
399func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) {
400	var edns []EDNS0
401Option:
402	var code uint16
403	if off+4 > len(msg) {
404		return nil, len(msg), &Error{err: "overflow unpacking opt"}
405	}
406	code = binary.BigEndian.Uint16(msg[off:])
407	off += 2
408	optlen := binary.BigEndian.Uint16(msg[off:])
409	off += 2
410	if off+int(optlen) > len(msg) {
411		return nil, len(msg), &Error{err: "overflow unpacking opt"}
412	}
413	switch code {
414	case EDNS0NSID:
415		e := new(EDNS0_NSID)
416		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
417			return nil, len(msg), err
418		}
419		edns = append(edns, e)
420		off += int(optlen)
421	case EDNS0SUBNET:
422		e := new(EDNS0_SUBNET)
423		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
424			return nil, len(msg), err
425		}
426		edns = append(edns, e)
427		off += int(optlen)
428	case EDNS0COOKIE:
429		e := new(EDNS0_COOKIE)
430		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
431			return nil, len(msg), err
432		}
433		edns = append(edns, e)
434		off += int(optlen)
435	case EDNS0UL:
436		e := new(EDNS0_UL)
437		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
438			return nil, len(msg), err
439		}
440		edns = append(edns, e)
441		off += int(optlen)
442	case EDNS0LLQ:
443		e := new(EDNS0_LLQ)
444		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
445			return nil, len(msg), err
446		}
447		edns = append(edns, e)
448		off += int(optlen)
449	case EDNS0DAU:
450		e := new(EDNS0_DAU)
451		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
452			return nil, len(msg), err
453		}
454		edns = append(edns, e)
455		off += int(optlen)
456	case EDNS0DHU:
457		e := new(EDNS0_DHU)
458		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
459			return nil, len(msg), err
460		}
461		edns = append(edns, e)
462		off += int(optlen)
463	case EDNS0N3U:
464		e := new(EDNS0_N3U)
465		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
466			return nil, len(msg), err
467		}
468		edns = append(edns, e)
469		off += int(optlen)
470	case EDNS0PADDING:
471		e := new(EDNS0_PADDING)
472		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
473			return nil, len(msg), err
474		}
475		edns = append(edns, e)
476		off += int(optlen)
477	default:
478		e := new(EDNS0_LOCAL)
479		e.Code = code
480		if err := e.unpack(msg[off : off+int(optlen)]); err != nil {
481			return nil, len(msg), err
482		}
483		edns = append(edns, e)
484		off += int(optlen)
485	}
486
487	if off < len(msg) {
488		goto Option
489	}
490
491	return edns, off, nil
492}
493
494func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
495	for _, el := range options {
496		b, err := el.pack()
497		if err != nil || off+3 > len(msg) {
498			return len(msg), &Error{err: "overflow packing opt"}
499		}
500		binary.BigEndian.PutUint16(msg[off:], el.Option())      // Option code
501		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
502		off += 4
503		if off+len(b) > len(msg) {
504			copy(msg[off:], b)
505			off = len(msg)
506			continue
507		}
508		// Actual data
509		copy(msg[off:off+len(b)], b)
510		off += len(b)
511	}
512	return off, nil
513}
514
// unpackStringOctet returns everything from msg[off] to the end of msg as a
// string, with len(msg) as the new offset. It never returns an error.
func unpackStringOctet(msg []byte, off int) (string, int, error) {
	end := len(msg)
	return string(msg[off:end]), end, nil
}
519
520func packStringOctet(s string, msg []byte, off int) (int, error) {
521	txtTmp := make([]byte, 256*4+1)
522	off, err := packOctetString(s, msg, off, txtTmp)
523	if err != nil {
524		return len(msg), err
525	}
526	return off, nil
527}
528
529func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
530	var nsec []uint16
531	length, window, lastwindow := 0, 0, -1
532	for off < len(msg) {
533		if off+2 > len(msg) {
534			return nsec, len(msg), &Error{err: "overflow unpacking nsecx"}
535		}
536		window = int(msg[off])
537		length = int(msg[off+1])
538		off += 2
539		if window <= lastwindow {
540			// RFC 4034: Blocks are present in the NSEC RR RDATA in
541			// increasing numerical order.
542			return nsec, len(msg), &Error{err: "out of order NSEC block"}
543		}
544		if length == 0 {
545			// RFC 4034: Blocks with no types present MUST NOT be included.
546			return nsec, len(msg), &Error{err: "empty NSEC block"}
547		}
548		if length > 32 {
549			return nsec, len(msg), &Error{err: "NSEC block too long"}
550		}
551		if off+length > len(msg) {
552			return nsec, len(msg), &Error{err: "overflowing NSEC block"}
553		}
554
555		// Walk the bytes in the window and extract the type bits
556		for j := 0; j < length; j++ {
557			b := msg[off+j]
558			// Check the bits one by one, and set the type
559			if b&0x80 == 0x80 {
560				nsec = append(nsec, uint16(window*256+j*8+0))
561			}
562			if b&0x40 == 0x40 {
563				nsec = append(nsec, uint16(window*256+j*8+1))
564			}
565			if b&0x20 == 0x20 {
566				nsec = append(nsec, uint16(window*256+j*8+2))
567			}
568			if b&0x10 == 0x10 {
569				nsec = append(nsec, uint16(window*256+j*8+3))
570			}
571			if b&0x8 == 0x8 {
572				nsec = append(nsec, uint16(window*256+j*8+4))
573			}
574			if b&0x4 == 0x4 {
575				nsec = append(nsec, uint16(window*256+j*8+5))
576			}
577			if b&0x2 == 0x2 {
578				nsec = append(nsec, uint16(window*256+j*8+6))
579			}
580			if b&0x1 == 0x1 {
581				nsec = append(nsec, uint16(window*256+j*8+7))
582			}
583		}
584		off += length
585		lastwindow = window
586	}
587	return nsec, off, nil
588}
589
590func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
591	if len(bitmap) == 0 {
592		return off, nil
593	}
594	var lastwindow, lastlength uint16
595	for j := 0; j < len(bitmap); j++ {
596		t := bitmap[j]
597		window := t / 256
598		length := (t-window*256)/8 + 1
599		if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
600			off += int(lastlength) + 2
601			lastlength = 0
602		}
603		if window < lastwindow || length < lastlength {
604			return len(msg), &Error{err: "nsec bits out of order"}
605		}
606		if off+2+int(length) > len(msg) {
607			return len(msg), &Error{err: "overflow packing nsec"}
608		}
609		// Setting the window #
610		msg[off] = byte(window)
611		// Setting the octets length
612		msg[off+1] = byte(length)
613		// Setting the bit value for the type in the right octet
614		msg[off+1+int(length)] |= byte(1 << (7 - t%8))
615		lastwindow, lastlength = window, length
616	}
617	off += int(lastlength) + 2
618	return off, nil
619}
620
621func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
622	var (
623		servers []string
624		s       string
625		err     error
626	)
627	if end > len(msg) {
628		return nil, len(msg), &Error{err: "overflow unpacking domain names"}
629	}
630	for off < end {
631		s, off, err = UnpackDomainName(msg, off)
632		if err != nil {
633			return servers, len(msg), err
634		}
635		servers = append(servers, s)
636	}
637	return servers, off, nil
638}
639
640func packDataDomainNames(names []string, msg []byte, off int, compression compressionMap, compress bool) (int, error) {
641	var err error
642	for j := 0; j < len(names); j++ {
643		off, err = packDomainName(names[j], msg, off, compression, compress)
644		if err != nil {
645			return len(msg), err
646		}
647	}
648	return off, nil
649}
650