1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5package prefix
6
7import (
8	"bufio"
9	"bytes"
10	"io"
11	"math"
12	"sort"
13	"strings"
14	"testing"
15
16	"github.com/dsnet/compress"
17	"github.com/dsnet/compress/internal"
18	"github.com/dsnet/compress/internal/testutil"
19)
20
21const testSize = 1000
22
23var (
24	testVector = testutil.MustDecodeHex("" +
25		"f795bd4a52e29ed713d3ff6eef91fc7f0735428e2bff9ee8e85f00bfa9c4aa3a" +
26		"beab8e3d59248dd6b5ff1e3cc1f49f024bff6edffd3f04ea797b45cc128eabe9" +
27		"9eb55197a2ff9ea118f2bf0c076ab6e095ffaea789ff063e7ff816d299c31411" +
28		"90baffeef5ff355b79ff9ed1d7d3fe8f01ff1e5a56f32f02b9bef6ae3932ff2e" +
29		"eeff057637f086ebbbff9e62f96915af14ffabcb1381f024a50bf86425ffee80" +
30		"7bffeefadf0327534c04b82835a07e09821915ce8cff6ea4dfe60de7f29e7b0f" +
31		"16ffae5f8190ff16ff9e64f81f0285ffee0daa7f04dd383667c3ff1eebffff4f" +
32		"2cc3553c4236ff8efd6b57e7c69932351ebf858722b2ffae2399d913ffce43ff" +
33		"29e86bc831330a66ff1e982fa98286ccbc29c6d03429b492349bff1e83ff4700" +
34		"37ec814173da6a0c79ffee99ff47e0e632e85ccdc44f21ff6eedbf0094e11b07" +
35		"99a0cfe5b789ee58ff1e462bfe75eb1aae7754bc62b37f0e699094ff9e95ff5f" +
36		"0461b381c3557076df1f4a6d2ff3faff9ef9af0246f6e9abf9e4cd0be7ff9ee9" +
37		"90eeff1bd24a0d7b67ffee2dff5b49b8b0828b0d51669fcb04ff2e2cbbd30181" +
38		"c06c1d39137a4bff97a92cff4e1ffe639c0589759657c71a0bceff1ed83f0662" +
39		"6f6eea066aff6e1476483462a2c0a213ffaefa7f00faecc24e592e9eff6e72da" +
40		"18ff23f9a598c26d479f52a6ce97e7891556ff1efcbf05422263c89384f18b25" +
41		"f1c5031c78ff9e9dcbc4178f639e298a0e8e8a40ff2e0042300b76b9ff6e2af5" +
42		"fadf023288d043fcff1e0ffe39ff9e75ff7ee8c8a68dd063e4ed7d5b7fbbe8c3" +
43		"3aff1e66f18f03ceb43fb0a49fcb78cb086b107cf8d3ffae6981bdff1f744f57" +
44		"a3d990881b50c3ffee74bb2ffee89d3ef0419f278eff1e22710f3ed9ff9e3b33" +
45		"f98f02c5b575cfb5ecd408ffeefed701ff9eb7019a7f2fd610c4670d667dc0f5" +
46		"bbff6ed0578cff2200bffbd22521f5c59430ff6ef63f0c1c5f64bce753f261ff" +
47		"2d2e9de9b4feff1eefc0fe9b8dc0d5d91cff1e8b84c9ffb701883442f5ff1e18" +
48		"0aff7da3d0c04abe9b9fbda906f55fe5f6ceff4e21fa0f000185990b348ba2d5" +
49		"a88903207be2ffeefe6f0287bcfbffee5abcf39f030a95b5d6ff9e8bcb7f0ef0" +
50		"a26bb9fa463591589de79bd633ffaecf7f122fb470144fff9e08ecc57f16f2a3" +
51		"6ab646e6f883b0d46b2e7eafadff1ea1ff5f07be86ff1e8a61f504b469527523" +
52		"85339cff9e90e2f316549a364bf62469d053ad0b1a552c1affeecb27f13f0875" +
53		"5a008a5f064098afff1efe2f01ce0808826b4d5f1fa3d71c1726ff6e9064f13f" +
54		"01068c5269fb7a99e55c7869ff9e22f8bf02f4cb0d5e3ae00d3dc3e1ff9eddbd" +
55		"520ffac3f8a16bb0b1e7f5d050e0ffee950d10ff238fb3daade442040dd0e8cf" +
56		"f50cffeed2fa5f008ebbd4dacbeb8fcd7c5ba440bf",
57	)
58
59	testCodes = func() (codes PrefixCodes) {
60		for i := 0; i < 100; i++ {
61			codes = append(codes, PrefixCode{Sym: uint32(len(codes)), Cnt: 0})
62		}
63		for i := 0; i < 25; i++ {
64			codes = append(codes, PrefixCode{Sym: uint32(len(codes)), Cnt: 10})
65		}
66		for i := 0; i < 5; i++ {
67			codes = append(codes, PrefixCode{Sym: uint32(len(codes)), Cnt: 1000})
68		}
69		codes.SortByCount()
70		if err := GenerateLengths(codes, 15); err != nil {
71			panic(err)
72		}
73		codes.SortBySymbol()
74		if err := GeneratePrefixes(codes); err != nil {
75			panic(err)
76		}
77		return codes
78	}()
79
80	testRanges = MakeRangeCodes(0, []uint{0, 1, 2, 3, 4})
81)
82
83func TestReader(t *testing.T) {
84	readers := map[string]func([]byte) io.Reader{
85		"io.Reader": func(b []byte) io.Reader {
86			return struct{ io.Reader }{bytes.NewReader(b)}
87		},
88		"bytes.Buffer": func(b []byte) io.Reader {
89			return bytes.NewBuffer(b)
90		},
91		"bytes.Reader": func(b []byte) io.Reader {
92			return bytes.NewReader(b)
93		},
94		"string.Reader": func(b []byte) io.Reader {
95			return strings.NewReader(string(b))
96		},
97		"compress.ByteReader": func(b []byte) io.Reader {
98			return struct{ compress.ByteReader }{bytes.NewReader(b)}
99		},
100		"compress.BufferedReader": func(b []byte) io.Reader {
101			return struct{ compress.BufferedReader }{bufio.NewReader(bytes.NewReader(b))}
102		},
103	}
104	endians := map[string]bool{"littleEndian": false, "bigEndian": true}
105
106	var i int
107	for ne, endian := range endians {
108		for nr, newReader := range readers {
109			var br Reader
110			buf := make([]byte, len(testVector))
111			copy(buf, testVector)
112			if endian {
113				for i, c := range buf {
114					buf[i] = internal.ReverseLUT[c]
115				}
116			}
117			rd := newReader(buf)
118			br.Init(rd, endian)
119
120			var pd Decoder
121			pd.Init(testCodes)
122
123			r := testutil.NewRand(0)
124		loop:
125			for j := 0; br.BitsRead() < 8*testSize; j++ {
126				switch j % 4 {
127				case 0:
128					// Test unaligned Read.
129					if br.numBits%8 != 0 {
130						cnt, err := br.Read([]byte{0})
131						if cnt != 0 {
132							t.Errorf("test %d, %s %s, write count mismatch: got %d, want 0", i, ne, nr, cnt)
133							break loop
134						}
135						if err == nil {
136							t.Errorf("test %d, %s %s, unexpected write success", i, ne, nr)
137							break loop
138						}
139					}
140
141					pads := br.ReadPads()
142					if pads != 0 {
143						t.Errorf("test %d, %s %s, bit padding mismatch: got %d, want 0", i, ne, nr, pads)
144						break loop
145					}
146					want := r.Bytes(r.Intn(16))
147					if endian {
148						for i, c := range want {
149							want[i] = internal.ReverseLUT[c]
150						}
151					}
152					got := make([]byte, len(want))
153					cnt, err := io.ReadFull(&br, got)
154					if cnt != len(want) {
155						t.Errorf("test %d, %s %s, read count mismatch: got %d, want %d", i, ne, nr, cnt, len(want))
156						break loop
157					}
158					if err != nil {
159						t.Errorf("test %d, %s %s, unexpected read error: got %v", i, ne, nr, err)
160						break loop
161					}
162					if bytes.Compare(want, got) != 0 {
163						t.Errorf("test %d, %s %s, read bytes mismatch:\ngot  %x\nwant %x", i, ne, nr, got, want)
164						break loop
165					}
166				case 1:
167					n := int(testRanges.End() - testRanges.Base())
168					want := uint(testRanges.Base() + uint32(r.Intn(n)))
169					got := br.ReadOffset(&pd, testRanges)
170					if got != want {
171						t.Errorf("test %d, %s %s, read offset mismatch: got %d, want %d", i, ne, nr, got, want)
172						break loop
173					}
174				case 2:
175					nb := uint(r.Intn(24))
176					want := uint(r.Int() & (1<<nb - 1))
177					got, ok := br.TryReadBits(nb)
178					if !ok {
179						got = br.ReadBits(nb)
180					}
181					if got != want {
182						t.Errorf("test %d, %s %s, read bits mismatch: got %d, want %d", i, ne, nr, got, want)
183						break loop
184					}
185				case 3:
186					want := uint(testCodes[r.Intn(len(testCodes))].Sym)
187					got, ok := br.TryReadSymbol(&pd)
188					if !ok {
189						got = br.ReadSymbol(&pd)
190					}
191					if got != want {
192						t.Errorf("test %d, %s %s, read symbol mismatch: got %d, want %d", i, ne, nr, got, want)
193						break loop
194					}
195				}
196			}
197
198			pads := br.ReadPads()
199			if pads != 0 {
200				t.Errorf("test %d, %s %s, bit padding mismatch: got %d, want 0", i, ne, nr, pads)
201			}
202			ofs, err := br.Flush()
203			if br.numBits != 0 {
204				t.Errorf("test %d, %s, bit buffer not drained: got %d, want < 8", i, ne, br.numBits)
205			}
206			if ofs != int64(len(testVector)) {
207				t.Errorf("test %d, %s, offset mismatch: got %d, want %d", i, ne, ofs, len(testVector))
208			}
209			if err != nil {
210				t.Errorf("test %d, %s, unexpected flush error: got %v", i, ne, err)
211			}
212			i++
213		}
214	}
215}
216
217func TestWriter(t *testing.T) {
218	endians := map[string]bool{"littleEndian": false, "bigEndian": true}
219
220	var i int
221	for ne, endian := range endians {
222		var bw Writer
223		wr := bytes.NewBuffer(nil)
224		bw.Init(wr, endian)
225
226		var pe Encoder
227		pe.Init(testCodes)
228
229		var re RangeEncoder
230		re.Init(testRanges)
231
232		r := testutil.NewRand(0)
233	loop:
234		for j := 0; bw.BitsWritten() < 8*testSize; j++ {
235			switch j % 4 {
236			case 0:
237				// Test unaligned Write.
238				if bw.numBits%8 != 0 {
239					cnt, err := bw.Write([]byte{0})
240					if cnt != 0 {
241						t.Errorf("test %d, %s, write count mismatch: got %d, want 0", i, ne, cnt)
242						break loop
243					}
244					if err == nil {
245						t.Errorf("test %d, %s, unexpected write success", i, ne)
246						break loop
247					}
248				}
249
250				bw.WritePads(0)
251				b := r.Bytes(r.Intn(16))
252				if endian {
253					for i, c := range b {
254						b[i] = internal.ReverseLUT[c]
255					}
256				}
257				cnt, err := bw.Write(b)
258				if cnt != len(b) {
259					t.Errorf("test %d, %s, write count mismatch: got %d, want %d", i, ne, cnt, len(b))
260					break loop
261				}
262				if err != nil {
263					t.Errorf("test %d, %s, unexpected write error: got %v", i, ne, err)
264					break loop
265				}
266			case 1:
267				n := int(testRanges.End() - testRanges.Base())
268				ofs := uint(testRanges.Base() + uint32(r.Intn(n)))
269				bw.WriteOffset(ofs, &pe, &re)
270			case 2:
271				nb := uint(r.Intn(24))
272				val := uint(r.Int() & (1<<nb - 1))
273				ok := bw.TryWriteBits(val, nb)
274				if !ok {
275					bw.WriteBits(val, nb)
276				}
277			case 3:
278				sym := uint(testCodes[r.Intn(len(testCodes))].Sym)
279				ok := bw.TryWriteSymbol(sym, &pe)
280				if !ok {
281					bw.WriteSymbol(sym, &pe)
282				}
283			}
284		}
285
286		// Flush the Writer.
287		bw.WritePads(0)
288		ofs, err := bw.Flush()
289		if bw.numBits != 0 {
290			t.Errorf("test %d, %s, bit buffer not drained: got %d, want 0", i, ne, bw.numBits)
291		}
292		if bw.cntBuf != 0 {
293			t.Errorf("test %d, %s, byte buffer not drained: got %d, want 0", i, ne, bw.cntBuf)
294		}
295		if ofs != int64(wr.Len()) {
296			t.Errorf("test %d, %s, offset mismatch: got %d, want %d", i, ne, ofs, wr.Len())
297		}
298		if err != nil {
299			t.Errorf("test %d, %s, unexpected flush error: got %v", i, ne, err)
300		}
301
302		// Check that output matches expected.
303		buf := wr.Bytes()
304		if endian {
305			for i, c := range buf {
306				buf[i] = internal.ReverseLUT[c]
307			}
308		}
309		if bytes.Compare(buf, testVector) != 0 {
310			t.Errorf("test %d, %s, output string mismatch:\ngot  %x\nwant %x", i, ne, buf, testVector)
311		}
312		i++
313	}
314}
315
316func TestGenerate(t *testing.T) {
317	r := testutil.NewRand(0)
318	makeCodes := func(freqs []uint) PrefixCodes {
319		codes := make(PrefixCodes, len(freqs))
320		for i, j := range r.Perm(len(freqs)) {
321			codes[i] = PrefixCode{Sym: uint32(i), Cnt: uint32(freqs[j])}
322		}
323		codes.SortByCount()
324		return codes
325	}
326
327	vectors := []struct {
328		maxBits uint // Maximum prefix bit-length (0 to skip GenerateLengths)
329		input   PrefixCodes
330		valid   bool
331	}{{
332		maxBits: 15,
333		input:   makeCodes([]uint{}),
334		valid:   true,
335	}, {
336		maxBits: 15,
337		input:   makeCodes([]uint{0}),
338		valid:   true,
339	}, {
340		maxBits: 15,
341		input:   makeCodes([]uint{5}),
342		valid:   true,
343	}, {
344		maxBits: 15,
345		input:   makeCodes([]uint{0, 0}),
346		valid:   true,
347	}, {
348		maxBits: 15,
349		input:   makeCodes([]uint{5, 15}),
350		valid:   true,
351	}, {
352		maxBits: 15,
353		input:   makeCodes([]uint{1, 1, 2, 4}),
354		valid:   true,
355	}, {
356		maxBits: 2,
357		input:   makeCodes([]uint{1, 1, 2, 4}),
358		valid:   true,
359	}, {
360		maxBits: 7,
361		input:   makeCodes([]uint{100, 101, 102, 103}),
362		valid:   true,
363	}, {
364		maxBits: 10,
365		input:   makeCodes([]uint{2, 2, 2, 2, 5, 5, 5}),
366		valid:   true,
367	}, {
368		maxBits: 15,
369		input:   makeCodes([]uint{1, 2, 3, 4, 5, 6, 7, 8, 9}),
370		valid:   true,
371	}, {
372		maxBits: 15,
373		input:   makeCodes([]uint{0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
374		valid:   true,
375	}, {
376		maxBits: 7,
377		input:   makeCodes([]uint{0, 0, 2, 3, 4, 4, 4, 5, 5, 6, 6, 7, 7, 9, 10, 11, 13, 15}),
378		valid:   true,
379	}, {
380		maxBits: 20,
381		input:   makeCodes([]uint{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}),
382		valid:   true,
383	}, {
384		maxBits: 12,
385		input:   makeCodes([]uint{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}),
386		valid:   true,
387	}, {
388		maxBits: 15,
389		input: makeCodes([]uint{
390			1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 6, 6, 7, 7, 8, 8, 9, 9, 11, 11,
391			11, 11, 14, 15, 15, 17, 17, 18, 19, 19, 19, 20, 20, 21, 24, 26, 26, 31,
392			32, 34, 35, 38, 40, 43, 47, 48, 50, 59, 62, 63, 75, 78, 79, 85, 86, 97,
393			100, 100, 102, 114, 119, 128, 128, 139, 153, 166, 170, 174, 182, 184,
394			185, 186, 205, 325, 536, 948, 1610, 2555, 2628, 3741,
395		}),
396		valid: true,
397	}, {
398		// Input counts are not sorted in ascending order.
399		maxBits: 15,
400		input: []PrefixCode{
401			{Sym: 0, Cnt: 3},
402			{Sym: 1, Cnt: 2},
403			{Sym: 2, Cnt: 1},
404		},
405		valid: false,
406	}, {
407		// Input symbols are not sorted in ascending order.
408		input: []PrefixCode{
409			{Sym: 2, Len: 1},
410			{Sym: 1, Len: 2},
411			{Sym: 0, Len: 2},
412		},
413		valid: false,
414	}, {
415		// Input symbols are not unique.
416		input: []PrefixCode{
417			{Sym: 5, Len: 1},
418			{Sym: 5, Len: 1},
419		},
420		valid: false,
421	}, {
422		// Invalid small tree.
423		input: []PrefixCode{
424			{Sym: 0, Len: 500},
425		},
426		valid: false,
427	}, {
428		// Some bit-length is too short.
429		input: []PrefixCode{
430			{Sym: 0, Len: 1},
431			{Sym: 1, Len: 2},
432			{Sym: 2, Len: 0},
433		},
434		valid: false,
435	}, {
436		// Under-subscribed tree.
437		input: []PrefixCode{
438			{Sym: 0, Len: 3},
439			{Sym: 1, Len: 4},
440			{Sym: 2, Len: 3},
441		},
442		valid: false,
443	}, {
444		// Over-subscribed tree.
445		input: []PrefixCode{
446			{Sym: 0, Len: 1},
447			{Sym: 1, Len: 3},
448			{Sym: 2, Len: 4},
449			{Sym: 3, Len: 3},
450			{Sym: 4, Len: 2},
451		},
452		valid: false,
453	}, {
454		// Over-subscribed tree (golang.org/issues/5915).
455		input: []PrefixCode{
456			{Sym: 0, Len: 4},
457			{Sym: 3, Len: 6},
458			{Sym: 4, Len: 4},
459			{Sym: 5, Len: 3},
460			{Sym: 6, Len: 2},
461			{Sym: 7, Len: 3},
462			{Sym: 8, Len: 3},
463			{Sym: 9, Len: 4},
464			{Sym: 10, Len: 4},
465			{Sym: 11, Len: 5},
466			{Sym: 16, Len: 5},
467			{Sym: 17, Len: 5},
468			{Sym: 18, Len: 6},
469			{Sym: 29, Len: 11},
470			{Sym: 51, Len: 7},
471			{Sym: 52, Len: 8},
472			{Sym: 53, Len: 6},
473			{Sym: 55, Len: 11},
474			{Sym: 57, Len: 8},
475			{Sym: 59, Len: 6},
476			{Sym: 60, Len: 6},
477			{Sym: 61, Len: 10},
478			{Sym: 62, Len: 8},
479		},
480		valid: false,
481	}, {
482		// Over-subscribed tree (golang.org/issues/5962).
483		input: []PrefixCode{
484			{Sym: 0, Len: 4},
485			{Sym: 3, Len: 6},
486			{Sym: 4, Len: 4},
487			{Sym: 5, Len: 3},
488			{Sym: 6, Len: 2},
489			{Sym: 7, Len: 3},
490			{Sym: 8, Len: 3},
491			{Sym: 9, Len: 4},
492			{Sym: 10, Len: 4},
493			{Sym: 11, Len: 5},
494			{Sym: 16, Len: 5},
495			{Sym: 17, Len: 5},
496			{Sym: 18, Len: 6},
497			{Sym: 29, Len: 11},
498		},
499		valid: false,
500	}, {
501		// Under-subscribed tree (golang.org/issues/6255).
502		input: []PrefixCode{
503			{Sym: 0, Len: 11},
504			{Sym: 1, Len: 13},
505		},
506		valid: false,
507	}}
508
509	for i, v := range vectors {
510		var sum uint32
511		var maxLen uint
512		var lens []int
513		var symBits [valueBits + 1]uint
514
515		codes := v.input
516		if v.maxBits == 0 {
517			goto genPrefixes
518		}
519
520		if err := GenerateLengths(codes, v.maxBits); err != nil {
521			if v.valid {
522				t.Errorf("test %d, unexpected failure", i)
523			}
524			continue
525		}
526
527		for _, c := range codes {
528			if maxLen < uint(c.Len) {
529				maxLen = uint(c.Len)
530			}
531			symBits[c.Len]++
532			lens = append(lens, int(c.Len))
533			sum += c.Cnt
534		}
535
536		if !codes.checkLengths() {
537			t.Errorf("test %d, incomplete tree generated", i)
538		}
539		if !sort.IsSorted(sort.Reverse(sort.IntSlice(lens))) {
540			t.Errorf("test %d, bit-lengths are not sorted:\ngot %v", i, lens)
541		}
542		if maxLen > v.maxBits {
543			t.Errorf("test %d, max bit-length exceeded: %d not in 1..%d", i, maxLen, v.maxBits)
544		}
545
546		// The whole point of prefix encoding is that the resulting bit-lengths
547		// produce an encoding with close to ideal entropy. Thus, compute the
548		// best-case entropy and check that we're not too far from it.
549		if len(codes) >= 4 && sum > 0 {
550			var worst, got, best float64
551			worst = math.Log2(float64(len(codes)))
552			got = float64(codes.Length()) / float64(sum)
553			for _, c := range codes {
554				if c.Cnt > 0 {
555					p := float64(c.Cnt) / float64(sum)
556					best += -(p * math.Log2(p))
557				}
558			}
559
560			if got > worst {
561				t.Errorf("test %d, actual entropy worst than worst-case: %0.3f > %0.3f", i, got, worst)
562			}
563			if got < best {
564				t.Errorf("test %d, actual entropy better than best-case: %0.3f < %0.3f", i, got, best)
565			}
566			if got > 1.15*best {
567				t.Errorf("test %d, actual entropy too high: %0.3f > %0.3f", i, got, 1.15*best)
568			}
569		}
570		codes.SortBySymbol()
571
572	genPrefixes:
573		if err := GeneratePrefixes(codes); err != nil {
574			if v.valid {
575				t.Errorf("test %d, unexpected failure", i)
576			}
577			continue
578		}
579
580		if !codes.checkPrefixes() {
581			t.Errorf("test %d, tree with non-unique prefixes generated", i)
582		}
583		if !codes.checkCanonical() {
584			t.Errorf("test %d, tree with non-canonical prefixes generated", i)
585		}
586		if !v.valid {
587			t.Errorf("test %d, unexpected success", i)
588		}
589	}
590}
591
592func TestPrefix(t *testing.T) {
593	makeCodes := func(freqs []uint) PrefixCodes {
594		codes := make(PrefixCodes, len(freqs))
595		for i, n := range freqs {
596			codes[i] = PrefixCode{Sym: uint32(i), Cnt: uint32(n)}
597		}
598		codes.SortByCount()
599		if err := GenerateLengths(codes, 15); err != nil {
600			t.Fatalf("unexpected error: %v", err)
601		}
602		codes.SortBySymbol()
603		if err := GeneratePrefixes(codes); err != nil {
604			t.Fatalf("unexpected error: %v", err)
605		}
606		return codes
607	}
608
609	vectors := []struct {
610		codes PrefixCodes
611	}{{
612		codes: makeCodes([]uint{}),
613	}, {
614		codes: makeCodes([]uint{0}),
615	}, {
616		codes: makeCodes([]uint{2, 4, 3, 2, 2, 4}),
617	}, {
618		codes: makeCodes([]uint{2, 2, 2, 2, 5, 5, 5}),
619	}, {
620		codes: makeCodes([]uint{100, 101, 102, 103}),
621	}, {
622		codes: makeCodes([]uint{
623			1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 9, 10, 11, 11, 12, 12,
624			14, 15, 15, 16, 18, 18, 19, 19, 20, 20, 20, 25, 25, 27, 29, 31, 32, 35,
625			39, 44, 47, 52, 60, 62, 71, 73, 74, 82, 86, 97, 98, 103, 108, 110, 112,
626			125, 130, 142, 154, 155, 160, 185, 198, 204, 204, 219, 222, 259, 262,
627			292, 296, 302, 334, 434, 450, 679, 697, 1032, 1441, 1888, 1892, 2188,
628		}),
629	}, {
630		codes: testCodes,
631	}, {
632		// Sparsely allocated symbols.
633		codes: []PrefixCode{
634			{Sym: 16, Val: 0, Len: 1},
635			{Sym: 32, Val: 1, Len: 2},
636			{Sym: 64, Val: 3, Len: 3},
637			{Sym: 128, Val: 7, Len: 3},
638		},
639	}, {
640		// Large number of symbols.
641		codes: func() PrefixCodes {
642			freqs := make([]uint, 4096)
643			for i := range freqs {
644				freqs[i] = uint(i)
645			}
646			return makeCodes(freqs)
647		}(),
648	}, {
649		// Max RLE codes from Brotli.
650		codes: func() (codes PrefixCodes) {
651			codes = PrefixCodes{{Sym: 0, Val: 0, Len: 1}}
652			for i := uint32(0); i < 16; i++ {
653				code := PrefixCode{Sym: i + 1, Val: i<<1 | 1, Len: 5}
654				codes = append(codes, code)
655			}
656			return codes
657		}(),
658	}, {
659		// Window bits codes from Brotli.
660		codes: func() (codes PrefixCodes) {
661			for i := uint32(9); i <= 24; i++ {
662				var code PrefixCode
663				switch {
664				case i == 16:
665					code = PrefixCode{Sym: i, Val: (i-16)<<0 | 0, Len: 1} // Symbols: 16
666				case i > 17:
667					code = PrefixCode{Sym: i, Val: (i-17)<<1 | 1, Len: 4} // Symbols: 18..24
668				case i < 17:
669					code = PrefixCode{Sym: i, Val: (i-8)<<4 | 1, Len: 7} // Symbols: 9..15
670				default:
671					code = PrefixCode{Sym: i, Val: (i-17)<<4 | 1, Len: 7} // Symbols: 17
672				}
673				codes = append(codes, code)
674			}
675			codes[0].Sym = 0
676			return codes
677		}(),
678	}, {
679		// Count codes from Brotli.
680		codes: func() (codes PrefixCodes) {
681			codes = PrefixCodes{{Sym: 1, Val: 0, Len: 1}}
682			c := codes[len(codes)-1]
683			for i := uint32(0); i < 8; i++ {
684				for j := uint32(0); j < 1<<i; j++ {
685					c.Sym = c.Sym + 1
686					c.Val = j<<4 | i<<1 | 1
687					c.Len = uint32(i + 4)
688					codes = append(codes, c)
689				}
690			}
691			return codes
692		}(),
693	}, {
694		// Fixed literal codes from DEFLATE.
695		codes: func() (codes PrefixCodes) {
696			for i := 0; i < 144; i++ {
697				codes = append(codes, PrefixCode{Sym: uint32(i), Len: 8})
698			}
699			for i := 144; i < 256; i++ {
700				codes = append(codes, PrefixCode{Sym: uint32(i), Len: 9})
701			}
702			for i := 256; i < 280; i++ {
703				codes = append(codes, PrefixCode{Sym: uint32(i), Len: 7})
704			}
705			for i := 280; i < 288; i++ {
706				codes = append(codes, PrefixCode{Sym: uint32(i), Len: 8})
707			}
708			if err := GeneratePrefixes(codes); err != nil {
709				t.Fatalf("unexpected error: %v", err)
710			}
711			return codes
712		}(),
713	}, {
714		// Fixed distance codes from DEFLATE.
715		codes: func() (codes PrefixCodes) {
716			for i := 0; i < 32; i++ {
717				codes = append(codes, PrefixCode{Sym: uint32(i), Len: 5})
718			}
719			if err := GeneratePrefixes(codes); err != nil {
720				t.Fatalf("unexpected error: %v", err)
721			}
722			return codes
723		}(),
724	}}
725
726	for i, v := range vectors {
727		// Generate an arbitrary prefix Decoder and Encoder.
728		var pd Decoder
729		var pe Encoder
730		pd.Init(v.codes)
731		pe.Init(v.codes)
732		if len(v.codes) == 0 {
733			continue
734		}
735
736		// Create an arbitrary list of symbols to encode.
737		r := testutil.NewRand(0)
738		syms := make([]uint, 1000)
739		for i := range syms {
740			syms[i] = uint(v.codes[r.Intn(len(v.codes))].Sym)
741		}
742
743		// Setup a Reader and Writer.
744		var buf bytes.Buffer
745		var rd Reader
746		var wr Writer
747		rdwr := struct {
748			io.Reader
749			io.ByteReader
750			io.Writer
751		}{&buf, &buf, &buf}
752		rd.Init(rdwr, false)
753		wr.Init(rdwr, false)
754
755		// Write some symbols.
756		for _, sym := range syms {
757			ok := wr.TryWriteSymbol(sym, &pe)
758			if !ok {
759				wr.WriteSymbol(sym, &pe)
760			}
761		}
762		wr.WritePads(0)
763		if _, err := wr.Flush(); err != nil {
764			t.Errorf("test %d, unexpected Writer error: %v", i, err)
765		}
766
767		// Verify some Writer statistics.
768		if wr.Offset != int64(buf.Len()) {
769			t.Errorf("test %d, offset mismatch: got %d, want %d", i, wr.Offset, buf.Len())
770		}
771		if wr.numBits != 0 {
772			t.Errorf("test %d, residual bits remaining: got %d, want 0", i, wr.numBits)
773		}
774		if wr.cntBuf != 0 {
775			t.Errorf("test %d, residual bytes remaining: got %d, want 0", i, wr.cntBuf)
776		}
777
778		// Read some symbols.
779		for i := range syms {
780			sym, ok := rd.TryReadSymbol(&pd)
781			if !ok {
782				sym = rd.ReadSymbol(&pd)
783			}
784			if sym != syms[i] {
785				t.Errorf("test %d, read back wrong symbol: got %d, want %d", i, sym, syms[i])
786			}
787			if rd.numBits >= 8 {
788				t.Errorf("test %d, residual bits remaining: got %d, want < 8", i, rd.numBits)
789			}
790		}
791		pads := rd.ReadPads()
792		if _, err := rd.Flush(); err != nil {
793			t.Errorf("test %d, unexpected Reader error: %v", i, err)
794		}
795
796		// Verify some Reader statistics.
797		if pads != 0 {
798			t.Errorf("test %d, unexpected padding bits: got %d, want 0", i, pads)
799		}
800		if rd.numBits != 0 {
801			t.Errorf("test %d, residual bits remaining: got %d, want 0", i, rd.numBits)
802		}
803		if rd.Offset != wr.Offset {
804			t.Errorf("test %d, offset mismatch: got %d, want %d", i, rd.Offset, wr.Offset)
805		}
806	}
807}
808
809func TestRange(t *testing.T) {
810	vectors := []struct {
811		input RangeCodes
812		valid bool
813	}{{
814		input: RangeCodes{},
815		valid: false,
816	}, {
817		input: RangeCodes{{5, 2}, {10, 5}}, // Gap in-between
818		valid: false,
819	}, {
820		input: RangeCodes{{5, 20}, {7, 5}}, // All-encompassing overlap
821		valid: false,
822	}, {
823		input: RangeCodes{{7, 5}, {5, 2}}, // Out-of-order
824		valid: false,
825	}, {
826		input: RangeCodes{{5, 10}, {6, 11}}, // Forward-overlap is okay
827		valid: true,
828	}, {
829		input: testRanges,
830		valid: true,
831	}, {
832		input: MakeRangeCodes(0, []uint{
833			0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10, 12, 14, 24,
834		}),
835		valid: true,
836	}, {
837		input: MakeRangeCodes(2, []uint{
838			0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 8, 9, 10, 24,
839		}),
840		valid: true,
841	}, {
842		input: MakeRangeCodes(1, []uint{
843			2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 8, 9, 10, 11, 12, 13, 24,
844		}),
845		valid: true,
846	}, {
847		input: MakeRangeCodes(2, []uint{
848			1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
849		}),
850		valid: true,
851	}, {
852		input: append(MakeRangeCodes(3, []uint{
853			0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,
854		}), RangeCode{Base: 258, Len: 0}),
855		valid: true,
856	}, {
857		input: MakeRangeCodes(1, []uint{
858			0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
859		}),
860		valid: true,
861	}}
862
863	r := testutil.NewRand(0)
864	for i, v := range vectors {
865		if valid := v.input.checkValid(); valid != v.valid {
866			t.Errorf("test %d, validity mismatch: got %v, want %v", i, valid, v.valid)
867		}
868		if !v.valid {
869			continue // No point further testing invalid ranges
870		}
871
872		var re RangeEncoder
873		re.Init(v.input)
874
875		for _, rc := range v.input {
876			offset := rc.Base + uint32(r.Intn(int(rc.End()-rc.Base)))
877			sym := re.Encode(uint(offset))
878			if int(sym) >= len(v.input) {
879				t.Errorf("test %d, invalid symbol: re.Encode(%d) = %d", i, offset, sym)
880			}
881			rc := v.input[sym]
882			if offset < rc.Base || offset >= rc.End() {
883				t.Errorf("test %d, symbol not in range: %d not in %d..%d", i, offset, rc.Base, rc.End()-1)
884			}
885		}
886	}
887}
888
889func BenchmarkBitReader(b *testing.B) {
890	var br Reader
891	nbs := []uint{1, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 7, 8, 9, 9, 13, 15}
892	n := 16 * b.N
893	bb := bytes.NewBuffer(make([]byte, n))
894	br.Init(bb, false)
895
896	b.SetBytes(16)
897	b.ResetTimer()
898	for i := 0; i < b.N; i++ {
899		for _, nb := range nbs {
900			_, ok := br.TryReadBits(nb)
901			if !ok {
902				_ = br.ReadBits(nb)
903			}
904		}
905	}
906}
907
908func BenchmarkBitWriter(b *testing.B) {
909	var bw Writer
910	nbs := []uint{1, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 7, 8, 9, 9, 13, 15}
911	n := 16 * b.N
912	bb := bytes.NewBuffer(make([]byte, 0, n))
913	bw.Init(bb, false)
914
915	b.SetBytes(16)
916	b.ResetTimer()
917	for i := 0; i < b.N; i++ {
918		for _, nb := range nbs {
919			ok := bw.TryWriteBits(0, nb)
920			if !ok {
921				bw.WriteBits(0, nb)
922			}
923		}
924	}
925}
926