1package dns
2
3import (
4	"bytes"
5	"net"
6	"testing"
7)
8
9// TestPacketDataNsec tests generated using fuzz.go and with a message pack
10// containing the following bytes: 0000\x00\x00000000\x00\x002000000\x0060000\x00\x130000000000000000000"
11// That bytes sequence created the overflow error and further permutations of that sequence were able to trigger
12// the other code paths.
13func TestPackDataNsec(t *testing.T) {
14	type args struct {
15		bitmap []uint16
16		msg    []byte
17		off    int
18	}
19	tests := []struct {
20		name       string
21		args       args
22		want       int
23		wantErr    bool
24		wantErrMsg string
25	}{
26		{
27			name: "overflow",
28			args: args{
29				bitmap: []uint16{
30					8962, 8963, 8970, 8971, 8978, 8979,
31					8986, 8987, 8994, 8995, 9002, 9003,
32					9010, 9011, 9018, 9019, 9026, 9027,
33					9034, 9035, 9042, 9043, 9050, 9051,
34					9058, 9059, 9066,
35				},
36				msg: []byte{
37					48, 48, 48, 48, 0, 0, 0,
38					1, 0, 0, 0, 0, 0, 0, 50,
39					48, 48, 48, 48, 48, 48,
40					0, 54, 48, 48, 48, 48,
41					0, 19, 48, 48,
42				},
43				off: 48,
44			},
45			wantErr:    true,
46			wantErrMsg: "dns: overflow packing nsec",
47			want:       31,
48		},
49		{
50			name: "disordered nsec bits",
51			args: args{
52				bitmap: []uint16{
53					8962,
54					0,
55				},
56				msg: []byte{
57					48, 48, 48, 48, 0, 0, 0, 1, 0, 0, 0, 0,
58					0, 0, 50, 48, 48, 48, 48, 48, 48, 0, 54, 48,
59					48, 48, 48, 0, 19, 48, 48, 48, 48, 48, 48, 0,
60					0, 0, 1, 0, 0, 0, 0, 0, 0, 50, 48, 48,
61					48, 48, 48, 48, 0, 54, 48, 48, 48, 48, 0, 19,
62					48, 48, 48, 48, 48, 48, 0, 0, 0, 1, 0, 0,
63					0, 0, 0, 0, 50, 48, 48, 48, 48, 48, 48, 0,
64					54, 48, 48, 48, 48, 0, 19, 48, 48, 48, 48, 48,
65					48, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 50,
66					48, 48, 48, 48, 48, 48, 0, 54, 48, 48, 48, 48,
67					0, 19, 48, 48, 48, 48, 48, 48, 0, 0, 0, 1,
68					0, 0, 0, 0, 0, 0, 50, 48, 48, 48, 48, 48,
69					48, 0, 54, 48, 48, 48, 48, 0, 19, 48, 48,
70				},
71				off: 0,
72			},
73			wantErr:    true,
74			wantErrMsg: "dns: nsec bits out of order",
75			want:       155,
76		},
77		{
78			name: "simple message with only one window",
79			args: args{
80				bitmap: []uint16{
81					0,
82				},
83				msg: []byte{
84					48, 48, 48, 48, 0, 0,
85					0, 1, 0, 0, 0, 0,
86					0, 0, 50, 48, 48, 48,
87					48, 48, 48, 0, 54, 48,
88					48, 48, 48, 0, 19, 48, 48,
89				},
90				off: 0,
91			},
92			wantErr: false,
93			want:    3,
94		},
95	}
96	for _, tt := range tests {
97		t.Run(tt.name, func(t *testing.T) {
98			got, err := packDataNsec(tt.args.bitmap, tt.args.msg, tt.args.off)
99			if (err != nil) != tt.wantErr {
100				t.Errorf("packDataNsec() error = %v, wantErr %v", err, tt.wantErr)
101				return
102			}
103			if err != nil && tt.wantErrMsg != err.Error() {
104				t.Errorf("packDataNsec() error msg = %v, wantErrMsg %v", err.Error(), tt.wantErrMsg)
105				return
106			}
107			if got != tt.want {
108				t.Errorf("packDataNsec() = %v, want %v", got, tt.want)
109			}
110		})
111	}
112}
113
114func TestUnpackString(t *testing.T) {
115	msg := []byte("\x00abcdef\x0f\\\"ghi\x04mmm\x7f")
116	msg[0] = byte(len(msg) - 1)
117
118	got, _, err := unpackString(msg, 0)
119	if err != nil {
120		t.Fatal(err)
121	}
122
123	if want := `abcdef\015\\\"ghi\004mmm\127`; want != got {
124		t.Errorf("expected %q, got %q", want, got)
125	}
126}
127
128func BenchmarkUnpackString(b *testing.B) {
129	b.Run("Escaped", func(b *testing.B) {
130		msg := []byte("\x00abcdef\x0f\\\"ghi\x04mmm")
131		msg[0] = byte(len(msg) - 1)
132
133		for n := 0; n < b.N; n++ {
134			got, _, err := unpackString(msg, 0)
135			if err != nil {
136				b.Fatal(err)
137			}
138
139			if want := `abcdef\015\\\"ghi\004mmm`; want != got {
140				b.Errorf("expected %q, got %q", want, got)
141			}
142		}
143	})
144	b.Run("Unescaped", func(b *testing.B) {
145		msg := []byte("\x00large.example.com")
146		msg[0] = byte(len(msg) - 1)
147
148		for n := 0; n < b.N; n++ {
149			got, _, err := unpackString(msg, 0)
150			if err != nil {
151				b.Fatal(err)
152			}
153
154			if want := "large.example.com"; want != got {
155				b.Errorf("expected %q, got %q", want, got)
156			}
157		}
158	})
159}
160
161func TestPackDataAplPrefix(t *testing.T) {
162	tests := []struct {
163		name     string
164		negation bool
165		ip       net.IP
166		mask     net.IPMask
167		expect   []byte
168	}{
169		{
170			"1:192.0.2.0/24",
171			false,
172			net.ParseIP("192.0.2.0").To4(),
173			net.CIDRMask(24, 32),
174			[]byte{0x00, 0x01, 0x18, 0x03, 192, 0, 2},
175		},
176		{
177			"2:2001:db8:cafe::0/48",
178			false,
179			net.ParseIP("2001:db8:cafe::"),
180			net.CIDRMask(48, 128),
181			[]byte{0x00, 0x02, 0x30, 0x06, 0x20, 0x01, 0x0d, 0xb8, 0xca, 0xfe},
182		},
183		{
184			"with trailing zero bytes 2:2001:db8:cafe::0/64",
185			false,
186			net.ParseIP("2001:db8:cafe::"),
187			net.CIDRMask(64, 128),
188			[]byte{0x00, 0x02, 0x40, 0x06, 0x20, 0x01, 0x0d, 0xb8, 0xca, 0xfe},
189		},
190		{
191			"no non-zero bytes 2::/16",
192			false,
193			net.ParseIP("::"),
194			net.CIDRMask(16, 128),
195			[]byte{0x00, 0x02, 0x10, 0x00},
196		},
197		{
198			"!2:2001:db8::/32",
199			true,
200			net.ParseIP("2001:db8::"),
201			net.CIDRMask(32, 128),
202			[]byte{0x00, 0x02, 0x20, 0x84, 0x20, 0x01, 0x0d, 0xb8},
203		},
204		{
205			"normalize 1:198.51.103.255/22",
206			false,
207			net.ParseIP("198.51.103.255").To4(),
208			net.CIDRMask(22, 32),
209			[]byte{0x00, 0x01, 0x16, 0x03, 198, 51, 100}, // 1:198.51.100.0/22
210		},
211	}
212	for _, tt := range tests {
213		t.Run(tt.name, func(t *testing.T) {
214			ap := &APLPrefix{
215				Negation: tt.negation,
216				Network:  net.IPNet{IP: tt.ip, Mask: tt.mask},
217			}
218			out := make([]byte, 16)
219			off, err := packDataAplPrefix(ap, out, 0)
220			if err != nil {
221				t.Fatalf("expected no error, got %q", err)
222			}
223			if !bytes.Equal(tt.expect, out[:off]) {
224				t.Fatalf("expected output %02x, got %02x", tt.expect, out[:off])
225			}
226			// Make sure the packed bytes would be accepted by its own unpack
227			_, _, err = unpackDataAplPrefix(out, 0)
228			if err != nil {
229				t.Fatalf("expected no error, got %q", err)
230			}
231		})
232	}
233}
234
235func TestPackDataAplPrefix_Failures(t *testing.T) {
236	tests := []struct {
237		name string
238		ip   net.IP
239		mask net.IPMask
240	}{
241		{
242			"family mismatch",
243			net.ParseIP("2001:db8::"),
244			net.CIDRMask(24, 32),
245		},
246		{
247			"unrecognized family",
248			[]byte{0x42},
249			[]byte{0xff},
250		},
251	}
252	for _, tt := range tests {
253		t.Run(tt.name, func(t *testing.T) {
254			ap := &APLPrefix{Network: net.IPNet{IP: tt.ip, Mask: tt.mask}}
255			msg := make([]byte, 16)
256			off, err := packDataAplPrefix(ap, msg, 0)
257			if err == nil {
258				t.Fatal("expected error, got none")
259			}
260			if off != len(msg) {
261				t.Fatalf("expected %d, got %d", len(msg), off)
262			}
263		})
264	}
265}
266
267func TestPackDataAplPrefix_BufferBounds(t *testing.T) {
268	ap := &APLPrefix{
269		Negation: false,
270		Network: net.IPNet{
271			IP:   net.ParseIP("2001:db8::"),
272			Mask: net.CIDRMask(32, 128),
273		},
274	}
275	wire := []byte{0x00, 0x02, 0x20, 0x04, 0x20, 0x01, 0x0d, 0xb8}
276
277	t.Run("small", func(t *testing.T) {
278		msg := make([]byte, len(wire))
279		_, err := packDataAplPrefix(ap, msg, 1) // offset
280		if err == nil {
281			t.Fatal("expected error, got none")
282		}
283	})
284
285	t.Run("exact fit", func(t *testing.T) {
286		msg := make([]byte, len(wire))
287		off, err := packDataAplPrefix(ap, msg, 0)
288		if err != nil {
289			t.Fatalf("expected no error, got %q", err)
290		}
291		if !bytes.Equal(wire, msg[:off]) {
292			t.Fatalf("expected %02x, got %02x", wire, msg[:off])
293		}
294	})
295}
296
297func TestPackDataApl(t *testing.T) {
298	in := []APLPrefix{
299		{
300			Negation: true,
301			Network: net.IPNet{
302				IP:   net.ParseIP("198.51.0.0").To4(),
303				Mask: net.CIDRMask(16, 32),
304			},
305		},
306		{
307			Negation: false,
308			Network: net.IPNet{
309				IP:   net.ParseIP("2001:db8:beef::"),
310				Mask: net.CIDRMask(48, 128),
311			},
312		},
313	}
314	expect := []byte{
315		// 1:192.51.0.0/16
316		0x00, 0x01, 0x10, 0x82, 0xc6, 0x33,
317		// 2:2001:db8:beef::0/48
318		0x00, 0x02, 0x30, 0x06, 0x20, 0x01, 0x0d, 0xb8, 0xbe, 0xef,
319	}
320
321	msg := make([]byte, 32)
322	off, err := packDataApl(in, msg, 0)
323	if err != nil {
324		t.Fatalf("expected no error, got %q", err)
325	}
326	if !bytes.Equal(expect, msg[:off]) {
327		t.Fatalf("expected %02x, got %02x", expect, msg[:off])
328	}
329}
330
331func TestUnpackDataAplPrefix(t *testing.T) {
332	tests := []struct {
333		name     string
334		wire     []byte
335		negation bool
336		ip       net.IP
337		mask     net.IPMask
338	}{
339		{
340			"1:192.0.2.0/24",
341			[]byte{0x00, 0x01, 0x18, 0x03, 192, 0, 2},
342			false,
343			net.ParseIP("192.0.2.0").To4(),
344			net.CIDRMask(24, 32),
345		},
346		{
347			"2:2001:db8::/32",
348			[]byte{0x00, 0x02, 0x20, 0x04, 0x20, 0x01, 0x0d, 0xb8},
349			false,
350			net.ParseIP("2001:db8::"),
351			net.CIDRMask(32, 128),
352		},
353		{
354			"!2:2001:db8:8000::/33",
355			[]byte{0x00, 0x02, 0x21, 0x85, 0x20, 0x01, 0x0d, 0xb8, 0x80},
356			true,
357			net.ParseIP("2001:db8:8000::"),
358			net.CIDRMask(33, 128),
359		},
360		{
361			"1:0.0.0.0/0",
362			[]byte{0x00, 0x01, 0x00, 0x00},
363			false,
364			net.ParseIP("0.0.0.0").To4(),
365			net.CIDRMask(0, 32),
366		},
367	}
368	for _, tt := range tests {
369		t.Run(tt.name, func(t *testing.T) {
370			got, off, err := unpackDataAplPrefix(tt.wire, 0)
371			if err != nil {
372				t.Fatalf("expected no error, got %q", err)
373			}
374			if off != len(tt.wire) {
375				t.Fatalf("expected offset %d, got %d", len(tt.wire), off)
376			}
377			if got.Negation != tt.negation {
378				t.Errorf("expected negation %v, got %v", tt.negation, got.Negation)
379			}
380			if !bytes.Equal(got.Network.IP, tt.ip) {
381				t.Errorf("expected IP %02x, got %02x", tt.ip, got.Network.IP)
382			}
383			if !bytes.Equal(got.Network.Mask, tt.mask) {
384				t.Errorf("expected mask %02x, got %02x", tt.mask, got.Network.Mask)
385			}
386		})
387	}
388}
389
390func TestUnpackDataAplPrefix_Errors(t *testing.T) {
391	tests := []struct {
392		name string
393		wire []byte
394	}{
395		{
396			"incomplete header",
397			[]byte{0x00, 0x01, 0x18},
398		},
399		{
400			"unrecognized family",
401			[]byte{0x00, 0x03, 0x00, 0x00},
402		},
403		{
404			"prefix length exceeded",
405			[]byte{0x00, 0x01, 0x21, 0x04, 192, 0, 2, 0},
406		},
407		{
408			"address with extra byte",
409			[]byte{0x00, 0x01, 0x10, 0x03, 192, 0, 2},
410		},
411		{
412			"incomplete buffer",
413			[]byte{0x00, 0x01, 0x10, 0x02, 192},
414		},
415		{
416			"extra bits set",
417			[]byte{0x00, 0x01, 22, 0x03, 192, 0, 2},
418		},
419		{
420			"afdlen invalid",
421			[]byte{0x00, 0x01, 22, 0x05, 192, 0, 2, 0, 0},
422		},
423	}
424	for _, tt := range tests {
425		t.Run(tt.name, func(t *testing.T) {
426			_, _, err := unpackDataAplPrefix(tt.wire, 0)
427			if err == nil {
428				t.Fatal("expected error, got none")
429			}
430		})
431	}
432}
433
434func TestUnpackDataApl(t *testing.T) {
435	wire := []byte{
436		// 2:2001:db8:cafe:4200:0/56
437		0x00, 0x02, 0x38, 0x07, 0x20, 0x01, 0x0d, 0xb8, 0xca, 0xfe, 0x42,
438		// 1:192.0.2.0/24
439		0x00, 0x01, 0x18, 0x03, 192, 0, 2,
440		// !1:192.0.2.128/25
441		0x00, 0x01, 0x19, 0x84, 192, 0, 2, 128,
442		// 1:10.0.0.0/24
443		0x00, 0x01, 0x18, 0x01, 0x0a,
444		// !1:10.0.0.1/32
445		0x00, 0x01, 0x20, 0x84, 0x0a, 0, 0, 1,
446		// !1:0.0.0.0/0
447		0x00, 0x01, 0x00, 0x80,
448		// 2::0/0
449		0x00, 0x02, 0x00, 0x00,
450	}
451	expect := []APLPrefix{
452		{
453			Negation: false,
454			Network: net.IPNet{
455				IP:   net.ParseIP("2001:db8:cafe:4200::"),
456				Mask: net.CIDRMask(56, 128),
457			},
458		},
459		{
460			Negation: false,
461			Network: net.IPNet{
462				IP:   net.ParseIP("192.0.2.0").To4(),
463				Mask: net.CIDRMask(24, 32),
464			},
465		},
466		{
467			Negation: true,
468			Network: net.IPNet{
469				IP:   net.ParseIP("192.0.2.128").To4(),
470				Mask: net.CIDRMask(25, 32),
471			},
472		},
473		{
474			Negation: false,
475			Network: net.IPNet{
476				IP:   net.ParseIP("10.0.0.0").To4(),
477				Mask: net.CIDRMask(24, 32),
478			},
479		},
480		{
481			Negation: true,
482			Network: net.IPNet{
483				IP:   net.ParseIP("10.0.0.1").To4(),
484				Mask: net.CIDRMask(32, 32),
485			},
486		},
487		{
488			Negation: true,
489			Network: net.IPNet{
490				IP:   net.ParseIP("0.0.0.0").To4(),
491				Mask: net.CIDRMask(0, 32),
492			},
493		},
494		{
495			Negation: false,
496			Network: net.IPNet{
497				IP:   net.ParseIP("::").To16(),
498				Mask: net.CIDRMask(0, 128),
499			},
500		},
501	}
502
503	got, off, err := unpackDataApl(wire, 0)
504	if err != nil {
505		t.Fatalf("expected no error, got %q", err)
506	}
507	if off != len(wire) {
508		t.Fatalf("expected offset %d, got %d", len(wire), off)
509	}
510	if len(got) != len(expect) {
511		t.Fatalf("expected %d prefixes, got %d", len(expect), len(got))
512	}
513	for i, exp := range expect {
514		if got[i].Negation != exp.Negation {
515			t.Errorf("[%d] expected negation %v, got %v", i, exp.Negation, got[i].Negation)
516		}
517		if !bytes.Equal(got[i].Network.IP, exp.Network.IP) {
518			t.Errorf("[%d] expected IP %02x, got %02x", i, exp.Network.IP, got[i].Network.IP)
519		}
520		if !bytes.Equal(got[i].Network.Mask, exp.Network.Mask) {
521			t.Errorf("[%d] expected mask %02x, got %02x", i, exp.Network.Mask, got[i].Network.Mask)
522		}
523	}
524}
525