1package huff0
2
3import (
4	"errors"
5	"fmt"
6	"io"
7
8	"github.com/klauspost/compress/fse"
9)
10
type (
	// dTable holds the decoding tables built by ReadTable.
	// The code in this file only fills and reads single; double is declared
	// for double-symbol decoding but is not populated here.
	dTable struct {
		single []dEntrySingle
		double []dEntryDouble
	}

	// dEntrySingle is one slot of the single-symbol decoding table.
	// The low byte of entry is the number of bits the code consumes and the
	// high byte is the symbol it decodes to.
	dEntrySingle struct {
		entry uint16
	}

	// dEntryDouble is one slot of a double-symbol decoding table.
	// NOTE(review): not referenced by the decoders in this file; the field
	// meanings (seq = packed output symbols, nBits = bits consumed,
	// len = number of symbols produced) are inferred from the names — confirm.
	dEntryDouble struct {
		seq   uint16
		nBits uint8
		len   uint8
	}
)

// use8BitTables selects the specialised byte-oriented decoding loops, which
// are used for every table whose log is 8 or less.
const use8BitTables = true
30
31// ReadTable will read a table from the input.
32// The size of the input may be larger than the table definition.
33// Any content remaining after the table definition will be returned.
34// If no Scratch is provided a new one is allocated.
35// The returned Scratch can be used for encoding or decoding input using this table.
36func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
37	s, err = s.prepare(in)
38	if err != nil {
39		return s, nil, err
40	}
41	if len(in) <= 1 {
42		return s, nil, errors.New("input too small for table")
43	}
44	iSize := in[0]
45	in = in[1:]
46	if iSize >= 128 {
47		// Uncompressed
48		oSize := iSize - 127
49		iSize = (oSize + 1) / 2
50		if int(iSize) > len(in) {
51			return s, nil, errors.New("input too small for table")
52		}
53		for n := uint8(0); n < oSize; n += 2 {
54			v := in[n/2]
55			s.huffWeight[n] = v >> 4
56			s.huffWeight[n+1] = v & 15
57		}
58		s.symbolLen = uint16(oSize)
59		in = in[iSize:]
60	} else {
61		if len(in) < int(iSize) {
62			return s, nil, fmt.Errorf("input too small for table, want %d bytes, have %d", iSize, len(in))
63		}
64		// FSE compressed weights
65		s.fse.DecompressLimit = 255
66		hw := s.huffWeight[:]
67		s.fse.Out = hw
68		b, err := fse.Decompress(in[:iSize], s.fse)
69		s.fse.Out = nil
70		if err != nil {
71			return s, nil, err
72		}
73		if len(b) > 255 {
74			return s, nil, errors.New("corrupt input: output table too large")
75		}
76		s.symbolLen = uint16(len(b))
77		in = in[iSize:]
78	}
79
80	// collect weight stats
81	var rankStats [16]uint32
82	weightTotal := uint32(0)
83	for _, v := range s.huffWeight[:s.symbolLen] {
84		if v > tableLogMax {
85			return s, nil, errors.New("corrupt input: weight too large")
86		}
87		v2 := v & 15
88		rankStats[v2]++
89		// (1 << (v2-1)) is slower since the compiler cannot prove that v2 isn't 0.
90		weightTotal += (1 << v2) >> 1
91	}
92	if weightTotal == 0 {
93		return s, nil, errors.New("corrupt input: weights zero")
94	}
95
96	// get last non-null symbol weight (implied, total must be 2^n)
97	{
98		tableLog := highBit32(weightTotal) + 1
99		if tableLog > tableLogMax {
100			return s, nil, errors.New("corrupt input: tableLog too big")
101		}
102		s.actualTableLog = uint8(tableLog)
103		// determine last weight
104		{
105			total := uint32(1) << tableLog
106			rest := total - weightTotal
107			verif := uint32(1) << highBit32(rest)
108			lastWeight := highBit32(rest) + 1
109			if verif != rest {
110				// last value must be a clean power of 2
111				return s, nil, errors.New("corrupt input: last value not power of two")
112			}
113			s.huffWeight[s.symbolLen] = uint8(lastWeight)
114			s.symbolLen++
115			rankStats[lastWeight]++
116		}
117	}
118
119	if (rankStats[1] < 2) || (rankStats[1]&1 != 0) {
120		// by construction : at least 2 elts of rank 1, must be even
121		return s, nil, errors.New("corrupt input: min elt size, even check failed ")
122	}
123
124	// TODO: Choose between single/double symbol decoding
125
126	// Calculate starting value for each rank
127	{
128		var nextRankStart uint32
129		for n := uint8(1); n < s.actualTableLog+1; n++ {
130			current := nextRankStart
131			nextRankStart += rankStats[n] << (n - 1)
132			rankStats[n] = current
133		}
134	}
135
136	// fill DTable (always full size)
137	tSize := 1 << tableLogMax
138	if len(s.dt.single) != tSize {
139		s.dt.single = make([]dEntrySingle, tSize)
140	}
141	cTable := s.prevTable
142	if cap(cTable) < maxSymbolValue+1 {
143		cTable = make([]cTableEntry, 0, maxSymbolValue+1)
144	}
145	cTable = cTable[:maxSymbolValue+1]
146	s.prevTable = cTable[:s.symbolLen]
147	s.prevTableLog = s.actualTableLog
148
149	for n, w := range s.huffWeight[:s.symbolLen] {
150		if w == 0 {
151			cTable[n] = cTableEntry{
152				val:   0,
153				nBits: 0,
154			}
155			continue
156		}
157		length := (uint32(1) << w) >> 1
158		d := dEntrySingle{
159			entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
160		}
161
162		rank := &rankStats[w]
163		cTable[n] = cTableEntry{
164			val:   uint16(*rank >> (w - 1)),
165			nBits: uint8(d.entry),
166		}
167
168		single := s.dt.single[*rank : *rank+length]
169		for i := range single {
170			single[i] = d
171		}
172		*rank += length
173	}
174
175	return s, in, nil
176}
177
178// Decompress1X will decompress a 1X encoded stream.
179// The length of the supplied input must match the end of a block exactly.
180// Before this is called, the table must be initialized with ReadTable unless
181// the encoder re-used the table.
182// deprecated: Use the stateless Decoder() to get a concurrent version.
183func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
184	if cap(s.Out) < s.MaxDecodedSize {
185		s.Out = make([]byte, s.MaxDecodedSize)
186	}
187	s.Out = s.Out[:0:s.MaxDecodedSize]
188	s.Out, err = s.Decoder().Decompress1X(s.Out, in)
189	return s.Out, err
190}
191
192// Decompress4X will decompress a 4X encoded stream.
193// Before this is called, the table must be initialized with ReadTable unless
194// the encoder re-used the table.
195// The length of the supplied input must match the end of a block exactly.
196// The destination size of the uncompressed data must be known and provided.
197// deprecated: Use the stateless Decoder() to get a concurrent version.
198func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
199	if dstSize > s.MaxDecodedSize {
200		return nil, ErrMaxDecodedSizeExceeded
201	}
202	if cap(s.Out) < dstSize {
203		s.Out = make([]byte, s.MaxDecodedSize)
204	}
205	s.Out = s.Out[:0:dstSize]
206	s.Out, err = s.Decoder().Decompress4X(s.Out, in)
207	return s.Out, err
208}
209
210// Decoder will return a stateless decoder that can be used by multiple
211// decompressors concurrently.
212// Before this is called, the table must be initialized with ReadTable.
213// The Decoder is still linked to the scratch buffer so that cannot be reused.
214// However, it is safe to discard the scratch.
215func (s *Scratch) Decoder() *Decoder {
216	return &Decoder{
217		dt:             s.dt,
218		actualTableLog: s.actualTableLog,
219	}
220}
221
222// Decoder provides stateless decoding.
223type Decoder struct {
224	dt             dTable
225	actualTableLog uint8
226}
227
228// Decompress1X will decompress a 1X encoded stream.
229// The cap of the output buffer will be the maximum decompressed size.
230// The length of the supplied input must match the end of a block exactly.
231func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
232	if len(d.dt.single) == 0 {
233		return nil, errors.New("no table loaded")
234	}
235	if use8BitTables && d.actualTableLog <= 8 {
236		return d.decompress1X8Bit(dst, src)
237	}
238	var br bitReaderShifted
239	err := br.init(src)
240	if err != nil {
241		return dst, err
242	}
243	maxDecodedSize := cap(dst)
244	dst = dst[:0]
245
246	// Avoid bounds check by always having full sized table.
247	const tlSize = 1 << tableLogMax
248	const tlMask = tlSize - 1
249	dt := d.dt.single[:tlSize]
250
251	// Use temp table to avoid bound checks/append penalty.
252	var buf [256]byte
253	var off uint8
254
255	for br.off >= 8 {
256		br.fillFast()
257		v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
258		br.advance(uint8(v.entry))
259		buf[off+0] = uint8(v.entry >> 8)
260
261		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
262		br.advance(uint8(v.entry))
263		buf[off+1] = uint8(v.entry >> 8)
264
265		// Refill
266		br.fillFast()
267
268		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
269		br.advance(uint8(v.entry))
270		buf[off+2] = uint8(v.entry >> 8)
271
272		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
273		br.advance(uint8(v.entry))
274		buf[off+3] = uint8(v.entry >> 8)
275
276		off += 4
277		if off == 0 {
278			if len(dst)+256 > maxDecodedSize {
279				br.close()
280				return nil, ErrMaxDecodedSizeExceeded
281			}
282			dst = append(dst, buf[:]...)
283		}
284	}
285
286	if len(dst)+int(off) > maxDecodedSize {
287		br.close()
288		return nil, ErrMaxDecodedSizeExceeded
289	}
290	dst = append(dst, buf[:off]...)
291
292	// br < 8, so uint8 is fine
293	bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
294	for bitsLeft > 0 {
295		br.fill()
296		if false && br.bitsRead >= 32 {
297			if br.off >= 4 {
298				v := br.in[br.off-4:]
299				v = v[:4]
300				low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
301				br.value = (br.value << 32) | uint64(low)
302				br.bitsRead -= 32
303				br.off -= 4
304			} else {
305				for br.off > 0 {
306					br.value = (br.value << 8) | uint64(br.in[br.off-1])
307					br.bitsRead -= 8
308					br.off--
309				}
310			}
311		}
312		if len(dst) >= maxDecodedSize {
313			br.close()
314			return nil, ErrMaxDecodedSizeExceeded
315		}
316		v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
317		nBits := uint8(v.entry)
318		br.advance(nBits)
319		bitsLeft -= nBits
320		dst = append(dst, uint8(v.entry>>8))
321	}
322	return dst, br.close()
323}
324
// decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
//
// Tables with a log of exactly 8 are forwarded to decompress1X8BitExactly,
// so the body below handles table logs below 8. The decoded bytes are
// written into dst's backing array starting at dst[:0].
// ErrMaxDecodedSizeExceeded is returned if they would not fit within
// cap(dst).
func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) {
	if d.actualTableLog == 8 {
		return d.decompress1X8BitExactly(dst, src)
	}
	var br bitReaderBytes
	err := br.init(src)
	if err != nil {
		return dst, err
	}
	maxDecodedSize := cap(dst)
	dst = dst[:0]

	// Avoid bounds check by always having full sized table.
	dt := d.dt.single[:256]

	// Use temp table to avoid bound checks/append penalty.
	// off is a uint8, so it wraps to 0 after 256 bytes, which triggers a
	// flush of the whole buf.
	var buf [256]byte
	var off uint8

	// Symbols are looked up from the top byte of the bit container. Shifting
	// it right keeps only the top actualTableLog bits as the table index.
	shift := (8 - d.actualTableLog) & 7

	// Fast path: one refill, then four symbols of at most actualTableLog
	// (< 8) bits each.
	// NOTE(review): assumes fillFast leaves at least 32 bits buffered while
	// br.off >= 4 — confirm in bitReaderBytes.
	for br.off >= 4 {
		br.fillFast()
		v := dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+0] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+1] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+2] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+3] = uint8(v.entry >> 8)

		off += 4
		if off == 0 {
			if len(dst)+256 > maxDecodedSize {
				br.close()
				return nil, ErrMaxDecodedSizeExceeded
			}
			dst = append(dst, buf[:]...)
		}
	}

	if len(dst)+int(off) > maxDecodedSize {
		br.close()
		return nil, ErrMaxDecodedSizeExceeded
	}
	dst = append(dst, buf[:off]...)

	// br < 4, so uint8 is fine
	// bitsLeft is signed. If corrupt input makes the last symbol consume
	// more bits than remain, the count goes negative and the loop stops.
	bitsLeft := int8(uint8(br.off)*8 + (64 - br.bitsRead))
	for bitsLeft > 0 {
		// Once at least 56 bits are consumed, load the last (< 4) input
		// bytes into the container one at a time.
		if br.bitsRead >= 64-8 {
			for br.off > 0 {
				br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
				br.bitsRead -= 8
				br.off--
			}
		}
		if len(dst) >= maxDecodedSize {
			br.close()
			return nil, ErrMaxDecodedSizeExceeded
		}
		v := dt[br.peekByteFast()>>shift]
		nBits := uint8(v.entry)
		br.advance(nBits)
		bitsLeft -= int8(nBits)
		dst = append(dst, uint8(v.entry>>8))
	}
	return dst, br.close()
}
406
// decompress1X8BitExactly will decompress a 1X encoded stream whose table
// log is exactly 8. It is a copy of decompress1X8Bit with the index shift
// fixed at 0; decompress1X8Bit forwards to it when actualTableLog == 8.
// The cap of the output buffer will be the maximum decompressed size.
// The length of the supplied input must match the end of a block exactly.
func (d *Decoder) decompress1X8BitExactly(dst, src []byte) ([]byte, error) {
	var br bitReaderBytes
	err := br.init(src)
	if err != nil {
		return dst, err
	}
	maxDecodedSize := cap(dst)
	dst = dst[:0]

	// Avoid bounds check by always having full sized table.
	dt := d.dt.single[:256]

	// Use temp table to avoid bound checks/append penalty.
	// off is a uint8, so it wraps to 0 after 256 bytes, which triggers a
	// flush of the whole buf.
	var buf [256]byte
	var off uint8

	// With a table log of 8, the top byte of the bit container is the table
	// index as-is.
	const shift = 0

	// Fast path: one refill, then four symbols of at most 8 bits each.
	// NOTE(review): assumes fillFast leaves at least 32 bits buffered while
	// br.off >= 4 — confirm in bitReaderBytes.
	for br.off >= 4 {
		br.fillFast()
		v := dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+0] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+1] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+2] = uint8(v.entry >> 8)

		v = dt[br.peekByteFast()>>shift]
		br.advance(uint8(v.entry))
		buf[off+3] = uint8(v.entry >> 8)

		off += 4
		if off == 0 {
			if len(dst)+256 > maxDecodedSize {
				br.close()
				return nil, ErrMaxDecodedSizeExceeded
			}
			dst = append(dst, buf[:]...)
		}
	}

	if len(dst)+int(off) > maxDecodedSize {
		br.close()
		return nil, ErrMaxDecodedSizeExceeded
	}
	dst = append(dst, buf[:off]...)

	// br < 4, so uint8 is fine
	// bitsLeft is signed. If corrupt input makes the last symbol consume
	// more bits than remain, the count goes negative and the loop stops.
	bitsLeft := int8(uint8(br.off)*8 + (64 - br.bitsRead))
	for bitsLeft > 0 {
		// Once at least 56 bits are consumed, load the last (< 4) input
		// bytes into the container one at a time.
		if br.bitsRead >= 64-8 {
			for br.off > 0 {
				br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
				br.bitsRead -= 8
				br.off--
			}
		}
		if len(dst) >= maxDecodedSize {
			br.close()
			return nil, ErrMaxDecodedSizeExceeded
		}
		v := dt[br.peekByteFast()>>shift]
		nBits := uint8(v.entry)
		br.advance(nBits)
		bitsLeft -= int8(nBits)
		dst = append(dst, uint8(v.entry>>8))
	}
	return dst, br.close()
}
485
486// Decompress4X will decompress a 4X encoded stream.
487// The length of the supplied input must match the end of a block exactly.
488// The *capacity* of the dst slice must match the destination size of
489// the uncompressed data exactly.
490func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
491	if len(d.dt.single) == 0 {
492		return nil, errors.New("no table loaded")
493	}
494	if len(src) < 6+(4*1) {
495		return nil, errors.New("input too small")
496	}
497	if use8BitTables && d.actualTableLog <= 8 {
498		return d.decompress4X8bit(dst, src)
499	}
500
501	var br [4]bitReaderShifted
502	start := 6
503	for i := 0; i < 3; i++ {
504		length := int(src[i*2]) | (int(src[i*2+1]) << 8)
505		if start+length >= len(src) {
506			return nil, errors.New("truncated input (or invalid offset)")
507		}
508		err := br[i].init(src[start : start+length])
509		if err != nil {
510			return nil, err
511		}
512		start += length
513	}
514	err := br[3].init(src[start:])
515	if err != nil {
516		return nil, err
517	}
518
519	// destination, offset to match first output
520	dstSize := cap(dst)
521	dst = dst[:dstSize]
522	out := dst
523	dstEvery := (dstSize + 3) / 4
524
525	const tlSize = 1 << tableLogMax
526	const tlMask = tlSize - 1
527	single := d.dt.single[:tlSize]
528
529	// Use temp table to avoid bound checks/append penalty.
530	var buf [256]byte
531	var off uint8
532	var decoded int
533
534	// Decode 2 values from each decoder/loop.
535	const bufoff = 256 / 4
536	for {
537		if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
538			break
539		}
540
541		{
542			const stream = 0
543			const stream2 = 1
544			br[stream].fillFast()
545			br[stream2].fillFast()
546
547			val := br[stream].peekBitsFast(d.actualTableLog)
548			v := single[val&tlMask]
549			br[stream].advance(uint8(v.entry))
550			buf[off+bufoff*stream] = uint8(v.entry >> 8)
551
552			val2 := br[stream2].peekBitsFast(d.actualTableLog)
553			v2 := single[val2&tlMask]
554			br[stream2].advance(uint8(v2.entry))
555			buf[off+bufoff*stream2] = uint8(v2.entry >> 8)
556
557			val = br[stream].peekBitsFast(d.actualTableLog)
558			v = single[val&tlMask]
559			br[stream].advance(uint8(v.entry))
560			buf[off+bufoff*stream+1] = uint8(v.entry >> 8)
561
562			val2 = br[stream2].peekBitsFast(d.actualTableLog)
563			v2 = single[val2&tlMask]
564			br[stream2].advance(uint8(v2.entry))
565			buf[off+bufoff*stream2+1] = uint8(v2.entry >> 8)
566		}
567
568		{
569			const stream = 2
570			const stream2 = 3
571			br[stream].fillFast()
572			br[stream2].fillFast()
573
574			val := br[stream].peekBitsFast(d.actualTableLog)
575			v := single[val&tlMask]
576			br[stream].advance(uint8(v.entry))
577			buf[off+bufoff*stream] = uint8(v.entry >> 8)
578
579			val2 := br[stream2].peekBitsFast(d.actualTableLog)
580			v2 := single[val2&tlMask]
581			br[stream2].advance(uint8(v2.entry))
582			buf[off+bufoff*stream2] = uint8(v2.entry >> 8)
583
584			val = br[stream].peekBitsFast(d.actualTableLog)
585			v = single[val&tlMask]
586			br[stream].advance(uint8(v.entry))
587			buf[off+bufoff*stream+1] = uint8(v.entry >> 8)
588
589			val2 = br[stream2].peekBitsFast(d.actualTableLog)
590			v2 = single[val2&tlMask]
591			br[stream2].advance(uint8(v2.entry))
592			buf[off+bufoff*stream2+1] = uint8(v2.entry >> 8)
593		}
594
595		off += 2
596
597		if off == bufoff {
598			if bufoff > dstEvery {
599				return nil, errors.New("corruption detected: stream overrun 1")
600			}
601			copy(out, buf[:bufoff])
602			copy(out[dstEvery:], buf[bufoff:bufoff*2])
603			copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
604			copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
605			off = 0
606			out = out[bufoff:]
607			decoded += 256
608			// There must at least be 3 buffers left.
609			if len(out) < dstEvery*3 {
610				return nil, errors.New("corruption detected: stream overrun 2")
611			}
612		}
613	}
614	if off > 0 {
615		ioff := int(off)
616		if len(out) < dstEvery*3+ioff {
617			return nil, errors.New("corruption detected: stream overrun 3")
618		}
619		copy(out, buf[:off])
620		copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
621		copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
622		copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
623		decoded += int(off) * 4
624		out = out[off:]
625	}
626
627	// Decode remaining.
628	for i := range br {
629		offset := dstEvery * i
630		br := &br[i]
631		bitsLeft := br.off*8 + uint(64-br.bitsRead)
632		for bitsLeft > 0 {
633			br.fill()
634			if false && br.bitsRead >= 32 {
635				if br.off >= 4 {
636					v := br.in[br.off-4:]
637					v = v[:4]
638					low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
639					br.value = (br.value << 32) | uint64(low)
640					br.bitsRead -= 32
641					br.off -= 4
642				} else {
643					for br.off > 0 {
644						br.value = (br.value << 8) | uint64(br.in[br.off-1])
645						br.bitsRead -= 8
646						br.off--
647					}
648				}
649			}
650			// end inline...
651			if offset >= len(out) {
652				return nil, errors.New("corruption detected: stream overrun 4")
653			}
654
655			// Read value and increment offset.
656			val := br.peekBitsFast(d.actualTableLog)
657			v := single[val&tlMask].entry
658			nBits := uint8(v)
659			br.advance(nBits)
660			bitsLeft -= uint(nBits)
661			out[offset] = uint8(v >> 8)
662			offset++
663		}
664		decoded += offset - dstEvery*i
665		err = br.close()
666		if err != nil {
667			return nil, err
668		}
669	}
670	if dstSize != decoded {
671		return nil, errors.New("corruption detected: short output block")
672	}
673	return dst, nil
674}
675
676// Decompress4X will decompress a 4X encoded stream.
677// The length of the supplied input must match the end of a block exactly.
678// The *capacity* of the dst slice must match the destination size of
679// the uncompressed data exactly.
680func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
681	if d.actualTableLog == 8 {
682		return d.decompress4X8bitExactly(dst, src)
683	}
684
685	var br [4]bitReaderBytes
686	start := 6
687	for i := 0; i < 3; i++ {
688		length := int(src[i*2]) | (int(src[i*2+1]) << 8)
689		if start+length >= len(src) {
690			return nil, errors.New("truncated input (or invalid offset)")
691		}
692		err := br[i].init(src[start : start+length])
693		if err != nil {
694			return nil, err
695		}
696		start += length
697	}
698	err := br[3].init(src[start:])
699	if err != nil {
700		return nil, err
701	}
702
703	// destination, offset to match first output
704	dstSize := cap(dst)
705	dst = dst[:dstSize]
706	out := dst
707	dstEvery := (dstSize + 3) / 4
708
709	shift := (8 - d.actualTableLog) & 7
710
711	const tlSize = 1 << 8
712	const tlMask = tlSize - 1
713	single := d.dt.single[:tlSize]
714
715	// Use temp table to avoid bound checks/append penalty.
716	var buf [256]byte
717	var off uint8
718	var decoded int
719
720	// Decode 4 values from each decoder/loop.
721	const bufoff = 256 / 4
722	for {
723		if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
724			break
725		}
726
727		{
728			// Interleave 2 decodes.
729			const stream = 0
730			const stream2 = 1
731			br[stream].fillFast()
732			br[stream2].fillFast()
733
734			v := single[br[stream].peekByteFast()>>shift].entry
735			buf[off+bufoff*stream] = uint8(v >> 8)
736			br[stream].advance(uint8(v))
737
738			v2 := single[br[stream2].peekByteFast()>>shift].entry
739			buf[off+bufoff*stream2] = uint8(v2 >> 8)
740			br[stream2].advance(uint8(v2))
741
742			v = single[br[stream].peekByteFast()>>shift].entry
743			buf[off+bufoff*stream+1] = uint8(v >> 8)
744			br[stream].advance(uint8(v))
745
746			v2 = single[br[stream2].peekByteFast()>>shift].entry
747			buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
748			br[stream2].advance(uint8(v2))
749
750			v = single[br[stream].peekByteFast()>>shift].entry
751			buf[off+bufoff*stream+2] = uint8(v >> 8)
752			br[stream].advance(uint8(v))
753
754			v2 = single[br[stream2].peekByteFast()>>shift].entry
755			buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
756			br[stream2].advance(uint8(v2))
757
758			v = single[br[stream].peekByteFast()>>shift].entry
759			buf[off+bufoff*stream+3] = uint8(v >> 8)
760			br[stream].advance(uint8(v))
761
762			v2 = single[br[stream2].peekByteFast()>>shift].entry
763			buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
764			br[stream2].advance(uint8(v2))
765		}
766
767		{
768			const stream = 2
769			const stream2 = 3
770			br[stream].fillFast()
771			br[stream2].fillFast()
772
773			v := single[br[stream].peekByteFast()>>shift].entry
774			buf[off+bufoff*stream] = uint8(v >> 8)
775			br[stream].advance(uint8(v))
776
777			v2 := single[br[stream2].peekByteFast()>>shift].entry
778			buf[off+bufoff*stream2] = uint8(v2 >> 8)
779			br[stream2].advance(uint8(v2))
780
781			v = single[br[stream].peekByteFast()>>shift].entry
782			buf[off+bufoff*stream+1] = uint8(v >> 8)
783			br[stream].advance(uint8(v))
784
785			v2 = single[br[stream2].peekByteFast()>>shift].entry
786			buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
787			br[stream2].advance(uint8(v2))
788
789			v = single[br[stream].peekByteFast()>>shift].entry
790			buf[off+bufoff*stream+2] = uint8(v >> 8)
791			br[stream].advance(uint8(v))
792
793			v2 = single[br[stream2].peekByteFast()>>shift].entry
794			buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
795			br[stream2].advance(uint8(v2))
796
797			v = single[br[stream].peekByteFast()>>shift].entry
798			buf[off+bufoff*stream+3] = uint8(v >> 8)
799			br[stream].advance(uint8(v))
800
801			v2 = single[br[stream2].peekByteFast()>>shift].entry
802			buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
803			br[stream2].advance(uint8(v2))
804		}
805
806		off += 4
807
808		if off == bufoff {
809			if bufoff > dstEvery {
810				return nil, errors.New("corruption detected: stream overrun 1")
811			}
812			copy(out, buf[:bufoff])
813			copy(out[dstEvery:], buf[bufoff:bufoff*2])
814			copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
815			copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
816			off = 0
817			out = out[bufoff:]
818			decoded += 256
819			// There must at least be 3 buffers left.
820			if len(out) < dstEvery*3 {
821				return nil, errors.New("corruption detected: stream overrun 2")
822			}
823		}
824	}
825	if off > 0 {
826		ioff := int(off)
827		if len(out) < dstEvery*3+ioff {
828			return nil, errors.New("corruption detected: stream overrun 3")
829		}
830		copy(out, buf[:off])
831		copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
832		copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
833		copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
834		decoded += int(off) * 4
835		out = out[off:]
836	}
837
838	// Decode remaining.
839	for i := range br {
840		offset := dstEvery * i
841		br := &br[i]
842		bitsLeft := int(br.off*8) + int(64-br.bitsRead)
843		for bitsLeft > 0 {
844			if br.finished() {
845				return nil, io.ErrUnexpectedEOF
846			}
847			if br.bitsRead >= 56 {
848				if br.off >= 4 {
849					v := br.in[br.off-4:]
850					v = v[:4]
851					low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
852					br.value |= uint64(low) << (br.bitsRead - 32)
853					br.bitsRead -= 32
854					br.off -= 4
855				} else {
856					for br.off > 0 {
857						br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
858						br.bitsRead -= 8
859						br.off--
860					}
861				}
862			}
863			// end inline...
864			if offset >= len(out) {
865				return nil, errors.New("corruption detected: stream overrun 4")
866			}
867
868			// Read value and increment offset.
869			v := single[br.peekByteFast()>>shift].entry
870			nBits := uint8(v)
871			br.advance(nBits)
872			bitsLeft -= int(nBits)
873			out[offset] = uint8(v >> 8)
874			offset++
875		}
876		decoded += offset - dstEvery*i
877		err = br.close()
878		if err != nil {
879			return nil, err
880		}
881	}
882	if dstSize != decoded {
883		return nil, errors.New("corruption detected: short output block")
884	}
885	return dst, nil
886}
887
888// Decompress4X will decompress a 4X encoded stream.
889// The length of the supplied input must match the end of a block exactly.
890// The *capacity* of the dst slice must match the destination size of
891// the uncompressed data exactly.
892func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
893	var br [4]bitReaderBytes
894	start := 6
895	for i := 0; i < 3; i++ {
896		length := int(src[i*2]) | (int(src[i*2+1]) << 8)
897		if start+length >= len(src) {
898			return nil, errors.New("truncated input (or invalid offset)")
899		}
900		err := br[i].init(src[start : start+length])
901		if err != nil {
902			return nil, err
903		}
904		start += length
905	}
906	err := br[3].init(src[start:])
907	if err != nil {
908		return nil, err
909	}
910
911	// destination, offset to match first output
912	dstSize := cap(dst)
913	dst = dst[:dstSize]
914	out := dst
915	dstEvery := (dstSize + 3) / 4
916
917	const shift = 0
918	const tlSize = 1 << 8
919	const tlMask = tlSize - 1
920	single := d.dt.single[:tlSize]
921
922	// Use temp table to avoid bound checks/append penalty.
923	var buf [256]byte
924	var off uint8
925	var decoded int
926
927	// Decode 4 values from each decoder/loop.
928	const bufoff = 256 / 4
929	for {
930		if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
931			break
932		}
933
934		{
935			// Interleave 2 decodes.
936			const stream = 0
937			const stream2 = 1
938			br[stream].fillFast()
939			br[stream2].fillFast()
940
941			v := single[br[stream].peekByteFast()>>shift].entry
942			buf[off+bufoff*stream] = uint8(v >> 8)
943			br[stream].advance(uint8(v))
944
945			v2 := single[br[stream2].peekByteFast()>>shift].entry
946			buf[off+bufoff*stream2] = uint8(v2 >> 8)
947			br[stream2].advance(uint8(v2))
948
949			v = single[br[stream].peekByteFast()>>shift].entry
950			buf[off+bufoff*stream+1] = uint8(v >> 8)
951			br[stream].advance(uint8(v))
952
953			v2 = single[br[stream2].peekByteFast()>>shift].entry
954			buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
955			br[stream2].advance(uint8(v2))
956
957			v = single[br[stream].peekByteFast()>>shift].entry
958			buf[off+bufoff*stream+2] = uint8(v >> 8)
959			br[stream].advance(uint8(v))
960
961			v2 = single[br[stream2].peekByteFast()>>shift].entry
962			buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
963			br[stream2].advance(uint8(v2))
964
965			v = single[br[stream].peekByteFast()>>shift].entry
966			buf[off+bufoff*stream+3] = uint8(v >> 8)
967			br[stream].advance(uint8(v))
968
969			v2 = single[br[stream2].peekByteFast()>>shift].entry
970			buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
971			br[stream2].advance(uint8(v2))
972		}
973
974		{
975			const stream = 2
976			const stream2 = 3
977			br[stream].fillFast()
978			br[stream2].fillFast()
979
980			v := single[br[stream].peekByteFast()>>shift].entry
981			buf[off+bufoff*stream] = uint8(v >> 8)
982			br[stream].advance(uint8(v))
983
984			v2 := single[br[stream2].peekByteFast()>>shift].entry
985			buf[off+bufoff*stream2] = uint8(v2 >> 8)
986			br[stream2].advance(uint8(v2))
987
988			v = single[br[stream].peekByteFast()>>shift].entry
989			buf[off+bufoff*stream+1] = uint8(v >> 8)
990			br[stream].advance(uint8(v))
991
992			v2 = single[br[stream2].peekByteFast()>>shift].entry
993			buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
994			br[stream2].advance(uint8(v2))
995
996			v = single[br[stream].peekByteFast()>>shift].entry
997			buf[off+bufoff*stream+2] = uint8(v >> 8)
998			br[stream].advance(uint8(v))
999
1000			v2 = single[br[stream2].peekByteFast()>>shift].entry
1001			buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
1002			br[stream2].advance(uint8(v2))
1003
1004			v = single[br[stream].peekByteFast()>>shift].entry
1005			buf[off+bufoff*stream+3] = uint8(v >> 8)
1006			br[stream].advance(uint8(v))
1007
1008			v2 = single[br[stream2].peekByteFast()>>shift].entry
1009			buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
1010			br[stream2].advance(uint8(v2))
1011		}
1012
1013		off += 4
1014
1015		if off == bufoff {
1016			if bufoff > dstEvery {
1017				return nil, errors.New("corruption detected: stream overrun 1")
1018			}
1019			copy(out, buf[:bufoff])
1020			copy(out[dstEvery:], buf[bufoff:bufoff*2])
1021			copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
1022			copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
1023			off = 0
1024			out = out[bufoff:]
1025			decoded += 256
1026			// There must at least be 3 buffers left.
1027			if len(out) < dstEvery*3 {
1028				return nil, errors.New("corruption detected: stream overrun 2")
1029			}
1030		}
1031	}
1032	if off > 0 {
1033		ioff := int(off)
1034		if len(out) < dstEvery*3+ioff {
1035			return nil, errors.New("corruption detected: stream overrun 3")
1036		}
1037		copy(out, buf[:off])
1038		copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
1039		copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
1040		copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
1041		decoded += int(off) * 4
1042		out = out[off:]
1043	}
1044
1045	// Decode remaining.
1046	for i := range br {
1047		offset := dstEvery * i
1048		br := &br[i]
1049		bitsLeft := int(br.off*8) + int(64-br.bitsRead)
1050		for bitsLeft > 0 {
1051			if br.finished() {
1052				return nil, io.ErrUnexpectedEOF
1053			}
1054			if br.bitsRead >= 56 {
1055				if br.off >= 4 {
1056					v := br.in[br.off-4:]
1057					v = v[:4]
1058					low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
1059					br.value |= uint64(low) << (br.bitsRead - 32)
1060					br.bitsRead -= 32
1061					br.off -= 4
1062				} else {
1063					for br.off > 0 {
1064						br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
1065						br.bitsRead -= 8
1066						br.off--
1067					}
1068				}
1069			}
1070			// end inline...
1071			if offset >= len(out) {
1072				return nil, errors.New("corruption detected: stream overrun 4")
1073			}
1074
1075			// Read value and increment offset.
1076			v := single[br.peekByteFast()>>shift].entry
1077			nBits := uint8(v)
1078			br.advance(nBits)
1079			bitsLeft -= int(nBits)
1080			out[offset] = uint8(v >> 8)
1081			offset++
1082		}
1083		decoded += offset - dstEvery*i
1084		err = br.close()
1085		if err != nil {
1086			return nil, err
1087		}
1088	}
1089	if dstSize != decoded {
1090		return nil, errors.New("corruption detected: short output block")
1091	}
1092	return dst, nil
1093}
1094
1095// matches will compare a decoding table to a coding table.
1096// Errors are written to the writer.
1097// Nothing will be written if table is ok.
1098func (s *Scratch) matches(ct cTable, w io.Writer) {
1099	if s == nil || len(s.dt.single) == 0 {
1100		return
1101	}
1102	dt := s.dt.single[:1<<s.actualTableLog]
1103	tablelog := s.actualTableLog
1104	ok := 0
1105	broken := 0
1106	for sym, enc := range ct {
1107		errs := 0
1108		broken++
1109		if enc.nBits == 0 {
1110			for _, dec := range dt {
1111				if uint8(dec.entry>>8) == byte(sym) {
1112					fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
1113					errs++
1114					break
1115				}
1116			}
1117			if errs == 0 {
1118				broken--
1119			}
1120			continue
1121		}
1122		// Unused bits in input
1123		ub := tablelog - enc.nBits
1124		top := enc.val << ub
1125		// decoder looks at top bits.
1126		dec := dt[top]
1127		if uint8(dec.entry) != enc.nBits {
1128			fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
1129			errs++
1130		}
1131		if uint8(dec.entry>>8) != uint8(sym) {
1132			fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
1133			errs++
1134		}
1135		if errs > 0 {
1136			fmt.Fprintf(w, "%d errros in base, stopping\n", errs)
1137			continue
1138		}
1139		// Ensure that all combinations are covered.
1140		for i := uint16(0); i < (1 << ub); i++ {
1141			vval := top | i
1142			dec := dt[vval]
1143			if uint8(dec.entry) != enc.nBits {
1144				fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
1145				errs++
1146			}
1147			if uint8(dec.entry>>8) != uint8(sym) {
1148				fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
1149				errs++
1150			}
1151			if errs > 20 {
1152				fmt.Fprintf(w, "%d errros, stopping\n", errs)
1153				break
1154			}
1155		}
1156		if errs == 0 {
1157			ok++
1158			broken--
1159		}
1160	}
1161	if broken > 0 {
1162		fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok)
1163	}
1164}
1165