1package dns
2
3import (
4	"bytes"
5	"net"
6	"testing"
7)
8
9// TestPacketDataNsec tests generated using fuzz.go and with a message pack
10// containing the following bytes: 0000\x00\x00000000\x00\x002000000\x0060000\x00\x130000000000000000000"
11// That bytes sequence created the overflow error and further permutations of that sequence were able to trigger
12// the other code paths.
13func TestPackDataNsec(t *testing.T) {
14	type args struct {
15		bitmap []uint16
16		msg    []byte
17		off    int
18	}
19	tests := []struct {
20		name       string
21		args       args
22		want       int
23		wantErr    bool
24		wantErrMsg string
25	}{
26		{
27			name: "overflow",
28			args: args{
29				bitmap: []uint16{
30					8962, 8963, 8970, 8971, 8978, 8979,
31					8986, 8987, 8994, 8995, 9002, 9003,
32					9010, 9011, 9018, 9019, 9026, 9027,
33					9034, 9035, 9042, 9043, 9050, 9051,
34					9058, 9059, 9066,
35				},
36				msg: []byte{
37					48, 48, 48, 48, 0, 0, 0,
38					1, 0, 0, 0, 0, 0, 0, 50,
39					48, 48, 48, 48, 48, 48,
40					0, 54, 48, 48, 48, 48,
41					0, 19, 48, 48,
42				},
43				off: 48,
44			},
45			wantErr:    true,
46			wantErrMsg: "dns: overflow packing nsec",
47			want:       31,
48		},
49		{
50			name: "disordered nsec bits",
51			args: args{
52				bitmap: []uint16{
53					8962,
54					0,
55				},
56				msg: []byte{
57					48, 48, 48, 48, 0, 0, 0, 1, 0, 0, 0, 0,
58					0, 0, 50, 48, 48, 48, 48, 48, 48, 0, 54, 48,
59					48, 48, 48, 0, 19, 48, 48, 48, 48, 48, 48, 0,
60					0, 0, 1, 0, 0, 0, 0, 0, 0, 50, 48, 48,
61					48, 48, 48, 48, 0, 54, 48, 48, 48, 48, 0, 19,
62					48, 48, 48, 48, 48, 48, 0, 0, 0, 1, 0, 0,
63					0, 0, 0, 0, 50, 48, 48, 48, 48, 48, 48, 0,
64					54, 48, 48, 48, 48, 0, 19, 48, 48, 48, 48, 48,
65					48, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 50,
66					48, 48, 48, 48, 48, 48, 0, 54, 48, 48, 48, 48,
67					0, 19, 48, 48, 48, 48, 48, 48, 0, 0, 0, 1,
68					0, 0, 0, 0, 0, 0, 50, 48, 48, 48, 48, 48,
69					48, 0, 54, 48, 48, 48, 48, 0, 19, 48, 48,
70				},
71				off: 0,
72			},
73			wantErr:    true,
74			wantErrMsg: "dns: nsec bits out of order",
75			want:       155,
76		},
77		{
78			name: "simple message with only one window",
79			args: args{
80				bitmap: []uint16{
81					0,
82				},
83				msg: []byte{
84					48, 48, 48, 48, 0, 0,
85					0, 1, 0, 0, 0, 0,
86					0, 0, 50, 48, 48, 48,
87					48, 48, 48, 0, 54, 48,
88					48, 48, 48, 0, 19, 48, 48,
89				},
90				off: 0,
91			},
92			wantErr: false,
93			want:    3,
94		},
95	}
96	for _, tt := range tests {
97		t.Run(tt.name, func(t *testing.T) {
98			got, err := packDataNsec(tt.args.bitmap, tt.args.msg, tt.args.off)
99			if (err != nil) != tt.wantErr {
100				t.Errorf("packDataNsec() error = %v, wantErr %v", err, tt.wantErr)
101				return
102			}
103			if err != nil && tt.wantErrMsg != err.Error() {
104				t.Errorf("packDataNsec() error msg = %v, wantErrMsg %v", err.Error(), tt.wantErrMsg)
105				return
106			}
107			if got != tt.want {
108				t.Errorf("packDataNsec() = %v, want %v", got, tt.want)
109			}
110		})
111	}
112}
113
114func TestUnpackString(t *testing.T) {
115	msg := []byte("\x00abcdef\x0f\\\"ghi\x04mmm\x7f")
116	msg[0] = byte(len(msg) - 1)
117
118	got, _, err := unpackString(msg, 0)
119	if err != nil {
120		t.Fatal(err)
121	}
122
123	if want := `abcdef\015\\\"ghi\004mmm\127`; want != got {
124		t.Errorf("expected %q, got %q", want, got)
125	}
126}
127
128func BenchmarkUnpackString(b *testing.B) {
129	b.Run("Escaped", func(b *testing.B) {
130		msg := []byte("\x00abcdef\x0f\\\"ghi\x04mmm")
131		msg[0] = byte(len(msg) - 1)
132
133		for n := 0; n < b.N; n++ {
134			got, _, err := unpackString(msg, 0)
135			if err != nil {
136				b.Fatal(err)
137			}
138
139			if want := `abcdef\015\\\"ghi\004mmm`; want != got {
140				b.Errorf("expected %q, got %q", want, got)
141			}
142		}
143	})
144	b.Run("Unescaped", func(b *testing.B) {
145		msg := []byte("\x00large.example.com")
146		msg[0] = byte(len(msg) - 1)
147
148		for n := 0; n < b.N; n++ {
149			got, _, err := unpackString(msg, 0)
150			if err != nil {
151				b.Fatal(err)
152			}
153
154			if want := "large.example.com"; want != got {
155				b.Errorf("expected %q, got %q", want, got)
156			}
157		}
158	})
159}
160
161func TestPackDataAplPrefix(t *testing.T) {
162	tests := []struct {
163		name     string
164		negation bool
165		ip       net.IP
166		mask     net.IPMask
167		expect   []byte
168	}{
169		{
170			"1:192.0.2.0/24",
171			false,
172			net.ParseIP("192.0.2.0").To4(),
173			net.CIDRMask(24, 32),
174			[]byte{0x00, 0x01, 0x18, 0x03, 192, 0, 2},
175		},
176		{
177			"2:2001:db8:cafe::0/48",
178			false,
179			net.ParseIP("2001:db8:cafe::"),
180			net.CIDRMask(48, 128),
181			[]byte{0x00, 0x02, 0x30, 0x06, 0x20, 0x01, 0x0d, 0xb8, 0xca, 0xfe},
182		},
183		{
184			"!2:2001:db8::/32",
185			true,
186			net.ParseIP("2001:db8::"),
187			net.CIDRMask(32, 128),
188			[]byte{0x00, 0x02, 0x20, 0x84, 0x20, 0x01, 0x0d, 0xb8},
189		},
190		{
191			"normalize 1:198.51.103.255/22",
192			false,
193			net.ParseIP("198.51.103.255").To4(),
194			net.CIDRMask(22, 32),
195			[]byte{0x00, 0x01, 0x16, 0x03, 198, 51, 100}, // 1:198.51.100.0/22
196		},
197	}
198	for _, tt := range tests {
199		t.Run(tt.name, func(t *testing.T) {
200			ap := &APLPrefix{
201				Negation: tt.negation,
202				Network:  net.IPNet{IP: tt.ip, Mask: tt.mask},
203			}
204			out := make([]byte, 16)
205			off, err := packDataAplPrefix(ap, out, 0)
206			if err != nil {
207				t.Fatalf("expected no error, got %q", err)
208			}
209			if !bytes.Equal(tt.expect, out[:off]) {
210				t.Fatalf("expected output %02x, got %02x", tt.expect, out[:off])
211			}
212		})
213	}
214}
215
216func TestPackDataAplPrefix_Failures(t *testing.T) {
217	tests := []struct {
218		name string
219		ip   net.IP
220		mask net.IPMask
221	}{
222		{
223			"family mismatch",
224			net.ParseIP("2001:db8::"),
225			net.CIDRMask(24, 32),
226		},
227		{
228			"unrecognized family",
229			[]byte{0x42},
230			[]byte{0xff},
231		},
232	}
233	for _, tt := range tests {
234		t.Run(tt.name, func(t *testing.T) {
235			ap := &APLPrefix{Network: net.IPNet{IP: tt.ip, Mask: tt.mask}}
236			msg := make([]byte, 16)
237			off, err := packDataAplPrefix(ap, msg, 0)
238			if err == nil {
239				t.Fatal("expected error, got none")
240			}
241			if off != len(msg) {
242				t.Fatalf("expected %d, got %d", len(msg), off)
243			}
244		})
245	}
246}
247
248func TestPackDataAplPrefix_BufferBounds(t *testing.T) {
249	ap := &APLPrefix{
250		Negation: false,
251		Network: net.IPNet{
252			IP:   net.ParseIP("2001:db8::"),
253			Mask: net.CIDRMask(32, 128),
254		},
255	}
256	wire := []byte{0x00, 0x02, 0x20, 0x04, 0x20, 0x01, 0x0d, 0xb8}
257
258	t.Run("small", func(t *testing.T) {
259		msg := make([]byte, len(wire))
260		_, err := packDataAplPrefix(ap, msg, 1) // offset
261		if err == nil {
262			t.Fatal("expected error, got none")
263		}
264	})
265
266	t.Run("exact fit", func(t *testing.T) {
267		msg := make([]byte, len(wire))
268		off, err := packDataAplPrefix(ap, msg, 0)
269		if err != nil {
270			t.Fatalf("expected no error, got %q", err)
271		}
272		if !bytes.Equal(wire, msg[:off]) {
273			t.Fatalf("expected %02x, got %02x", wire, msg[:off])
274		}
275	})
276}
277
278func TestPackDataApl(t *testing.T) {
279	in := []APLPrefix{
280		APLPrefix{
281			Negation: true,
282			Network: net.IPNet{
283				IP:   net.ParseIP("198.51.0.0").To4(),
284				Mask: net.CIDRMask(16, 32),
285			},
286		},
287		APLPrefix{
288			Negation: false,
289			Network: net.IPNet{
290				IP:   net.ParseIP("2001:db8:beef::"),
291				Mask: net.CIDRMask(48, 128),
292			},
293		},
294	}
295	expect := []byte{
296		// 1:192.51.0.0/16
297		0x00, 0x01, 0x10, 0x82, 0xc6, 0x33,
298		// 2:2001:db8:beef::0/48
299		0x00, 0x02, 0x30, 0x06, 0x20, 0x01, 0x0d, 0xb8, 0xbe, 0xef,
300	}
301
302	msg := make([]byte, 32)
303	off, err := packDataApl(in, msg, 0)
304	if err != nil {
305		t.Fatalf("expected no error, got %q", err)
306	}
307	if !bytes.Equal(expect, msg[:off]) {
308		t.Fatalf("expected %02x, got %02x", expect, msg[:off])
309	}
310}
311
312func TestUnpackDataAplPrefix(t *testing.T) {
313	tests := []struct {
314		name     string
315		wire     []byte
316		negation bool
317		ip       net.IP
318		mask     net.IPMask
319	}{
320		{
321			"1:192.0.2.0/24",
322			[]byte{0x00, 0x01, 0x18, 0x03, 192, 0, 2},
323			false,
324			net.ParseIP("192.0.2.0").To4(),
325			net.CIDRMask(24, 32),
326		},
327		{
328			"2:2001:db8::/32",
329			[]byte{0x00, 0x02, 0x20, 0x04, 0x20, 0x01, 0x0d, 0xb8},
330			false,
331			net.ParseIP("2001:db8::"),
332			net.CIDRMask(32, 128),
333		},
334		{
335			"!2:2001:db8:8000::/33",
336			[]byte{0x00, 0x02, 0x21, 0x85, 0x20, 0x01, 0x0d, 0xb8, 0x80},
337			true,
338			net.ParseIP("2001:db8:8000::"),
339			net.CIDRMask(33, 128),
340		},
341		{
342			"1:0.0.0.0/0",
343			[]byte{0x00, 0x01, 0x00, 0x00},
344			false,
345			net.ParseIP("0.0.0.0").To4(),
346			net.CIDRMask(0, 32),
347		},
348	}
349	for _, tt := range tests {
350		t.Run(tt.name, func(t *testing.T) {
351			got, off, err := unpackDataAplPrefix(tt.wire, 0)
352			if err != nil {
353				t.Fatalf("expected no error, got %q", err)
354			}
355			if off != len(tt.wire) {
356				t.Fatalf("expected offset %d, got %d", len(tt.wire), off)
357			}
358			if got.Negation != tt.negation {
359				t.Errorf("expected negation %v, got %v", tt.negation, got.Negation)
360			}
361			if !bytes.Equal(got.Network.IP, tt.ip) {
362				t.Errorf("expected IP %02x, got %02x", tt.ip, got.Network.IP)
363			}
364			if !bytes.Equal(got.Network.Mask, tt.mask) {
365				t.Errorf("expected mask %02x, got %02x", tt.mask, got.Network.Mask)
366			}
367		})
368	}
369}
370
371func TestUnpackDataAplPrefix_Errors(t *testing.T) {
372	tests := []struct {
373		name string
374		wire []byte
375	}{
376		{
377			"incomplete header",
378			[]byte{0x00, 0x01, 0x18},
379		},
380		{
381			"unrecognized family",
382			[]byte{0x00, 0x03, 0x00, 0x00},
383		},
384		{
385			"prefix length exceeded",
386			[]byte{0x00, 0x01, 0x21, 0x04, 192, 0, 2, 0},
387		},
388		{
389			"address with extra byte",
390			[]byte{0x00, 0x01, 0x10, 0x03, 192, 0, 2},
391		},
392		{
393			"incomplete buffer",
394			[]byte{0x00, 0x01, 0x10, 0x02, 192},
395		},
396		{
397			"extra bits set",
398			[]byte{0x00, 0x01, 22, 0x03, 192, 0, 2},
399		},
400	}
401	for _, tt := range tests {
402		t.Run(tt.name, func(t *testing.T) {
403			_, _, err := unpackDataAplPrefix(tt.wire, 0)
404			if err == nil {
405				t.Fatal("expected error, got none")
406			}
407		})
408	}
409}
410
411func TestUnpackDataApl(t *testing.T) {
412	wire := []byte{
413		// 2:2001:db8:cafe:4200:0/56
414		0x00, 0x02, 0x38, 0x07, 0x20, 0x01, 0x0d, 0xb8, 0xca, 0xfe, 0x42,
415		// 1:192.0.2.0/24
416		0x00, 0x01, 0x18, 0x03, 192, 0, 2,
417		// !1:192.0.2.128/25
418		0x00, 0x01, 0x19, 0x84, 192, 0, 2, 128,
419	}
420	expect := []APLPrefix{
421		{
422			Negation: false,
423			Network: net.IPNet{
424				IP:   net.ParseIP("2001:db8:cafe:4200::"),
425				Mask: net.CIDRMask(56, 128),
426			},
427		},
428		{
429			Negation: false,
430			Network: net.IPNet{
431				IP:   net.ParseIP("192.0.2.0").To4(),
432				Mask: net.CIDRMask(24, 32),
433			},
434		},
435		{
436			Negation: true,
437			Network: net.IPNet{
438				IP:   net.ParseIP("192.0.2.128").To4(),
439				Mask: net.CIDRMask(25, 32),
440			},
441		},
442	}
443
444	got, off, err := unpackDataApl(wire, 0)
445	if err != nil {
446		t.Fatalf("expected no error, got %q", err)
447	}
448	if off != len(wire) {
449		t.Fatalf("expected offset %d, got %d", len(wire), off)
450	}
451	if len(got) != len(expect) {
452		t.Fatalf("expected %d prefixes, got %d", len(expect), len(got))
453	}
454	for i, exp := range expect {
455		if got[i].Negation != exp.Negation {
456			t.Errorf("[%d] expected negation %v, got %v", i, exp.Negation, got[i].Negation)
457		}
458		if !bytes.Equal(got[i].Network.IP, exp.Network.IP) {
459			t.Errorf("[%d] expected IP %02x, got %02x", i, exp.Network.IP, got[i].Network.IP)
460		}
461		if !bytes.Equal(got[i].Network.Mask, exp.Network.Mask) {
462			t.Errorf("[%d] expected mask %02x, got %02x", i, exp.Network.Mask, got[i].Network.Mask)
463		}
464	}
465}
466