1package huff0
2
3import (
4	"errors"
5	"fmt"
6	"io"
7
8	"github.com/klauspost/compress/fse"
9)
10
// dTable holds the lookup tables used for decoding.
// ReadTable fills single (always at full 1<<tableLogMax size); double is
// reserved for double-symbol decoding, which ReadTable does not build yet
// (see the TODO there).
type dTable struct {
	single []dEntrySingle
	double []dEntryDouble
}
15
// dEntrySingle is one slot of the single-symbol decoding table.
// entry packs two values: the low byte is the number of bits the code
// consumes, the high byte is the decoded symbol.
type dEntrySingle struct {
	entry uint16
}
20
// dEntryDouble is one slot of the double-symbol decoding table.
// NOTE(review): nothing in the visible code reads or writes this type.
// Presumably seq holds two packed symbols, nBits the bits consumed and len
// the number of valid symbols in seq — confirm before relying on this.
type dEntryDouble struct {
	seq   uint16
	nBits uint8
	len   uint8
}
27
// use8BitTables enables the specialized decoders that are selected when the
// table log is <= 8 bits.
const use8BitTables = true
30
31// ReadTable will read a table from the input.
32// The size of the input may be larger than the table definition.
33// Any content remaining after the table definition will be returned.
34// If no Scratch is provided a new one is allocated.
35// The returned Scratch can be used for encoding or decoding input using this table.
36func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
37	s, err = s.prepare(in)
38	if err != nil {
39		return s, nil, err
40	}
41	if len(in) <= 1 {
42		return s, nil, errors.New("input too small for table")
43	}
44	iSize := in[0]
45	in = in[1:]
46	if iSize >= 128 {
47		// Uncompressed
48		oSize := iSize - 127
49		iSize = (oSize + 1) / 2
50		if int(iSize) > len(in) {
51			return s, nil, errors.New("input too small for table")
52		}
53		for n := uint8(0); n < oSize; n += 2 {
54			v := in[n/2]
55			s.huffWeight[n] = v >> 4
56			s.huffWeight[n+1] = v & 15
57		}
58		s.symbolLen = uint16(oSize)
59		in = in[iSize:]
60	} else {
61		if len(in) < int(iSize) {
62			return s, nil, fmt.Errorf("input too small for table, want %d bytes, have %d", iSize, len(in))
63		}
64		// FSE compressed weights
65		s.fse.DecompressLimit = 255
66		hw := s.huffWeight[:]
67		s.fse.Out = hw
68		b, err := fse.Decompress(in[:iSize], s.fse)
69		s.fse.Out = nil
70		if err != nil {
71			return s, nil, err
72		}
73		if len(b) > 255 {
74			return s, nil, errors.New("corrupt input: output table too large")
75		}
76		s.symbolLen = uint16(len(b))
77		in = in[iSize:]
78	}
79
80	// collect weight stats
81	var rankStats [16]uint32
82	weightTotal := uint32(0)
83	for _, v := range s.huffWeight[:s.symbolLen] {
84		if v > tableLogMax {
85			return s, nil, errors.New("corrupt input: weight too large")
86		}
87		v2 := v & 15
88		rankStats[v2]++
89		// (1 << (v2-1)) is slower since the compiler cannot prove that v2 isn't 0.
90		weightTotal += (1 << v2) >> 1
91	}
92	if weightTotal == 0 {
93		return s, nil, errors.New("corrupt input: weights zero")
94	}
95
96	// get last non-null symbol weight (implied, total must be 2^n)
97	{
98		tableLog := highBit32(weightTotal) + 1
99		if tableLog > tableLogMax {
100			return s, nil, errors.New("corrupt input: tableLog too big")
101		}
102		s.actualTableLog = uint8(tableLog)
103		// determine last weight
104		{
105			total := uint32(1) << tableLog
106			rest := total - weightTotal
107			verif := uint32(1) << highBit32(rest)
108			lastWeight := highBit32(rest) + 1
109			if verif != rest {
110				// last value must be a clean power of 2
111				return s, nil, errors.New("corrupt input: last value not power of two")
112			}
113			s.huffWeight[s.symbolLen] = uint8(lastWeight)
114			s.symbolLen++
115			rankStats[lastWeight]++
116		}
117	}
118
119	if (rankStats[1] < 2) || (rankStats[1]&1 != 0) {
120		// by construction : at least 2 elts of rank 1, must be even
121		return s, nil, errors.New("corrupt input: min elt size, even check failed ")
122	}
123
124	// TODO: Choose between single/double symbol decoding
125
126	// Calculate starting value for each rank
127	{
128		var nextRankStart uint32
129		for n := uint8(1); n < s.actualTableLog+1; n++ {
130			current := nextRankStart
131			nextRankStart += rankStats[n] << (n - 1)
132			rankStats[n] = current
133		}
134	}
135
136	// fill DTable (always full size)
137	tSize := 1 << tableLogMax
138	if len(s.dt.single) != tSize {
139		s.dt.single = make([]dEntrySingle, tSize)
140	}
141	cTable := s.prevTable
142	if cap(cTable) < maxSymbolValue+1 {
143		cTable = make([]cTableEntry, 0, maxSymbolValue+1)
144	}
145	cTable = cTable[:maxSymbolValue+1]
146	s.prevTable = cTable[:s.symbolLen]
147	s.prevTableLog = s.actualTableLog
148
149	for n, w := range s.huffWeight[:s.symbolLen] {
150		if w == 0 {
151			cTable[n] = cTableEntry{
152				val:   0,
153				nBits: 0,
154			}
155			continue
156		}
157		length := (uint32(1) << w) >> 1
158		d := dEntrySingle{
159			entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
160		}
161
162		rank := &rankStats[w]
163		cTable[n] = cTableEntry{
164			val:   uint16(*rank >> (w - 1)),
165			nBits: uint8(d.entry),
166		}
167
168		single := s.dt.single[*rank : *rank+length]
169		for i := range single {
170			single[i] = d
171		}
172		*rank += length
173	}
174
175	return s, in, nil
176}
177
178// Decompress1X will decompress a 1X encoded stream.
179// The length of the supplied input must match the end of a block exactly.
180// Before this is called, the table must be initialized with ReadTable unless
181// the encoder re-used the table.
182// deprecated: Use the stateless Decoder() to get a concurrent version.
183func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
184	if cap(s.Out) < s.MaxDecodedSize {
185		s.Out = make([]byte, s.MaxDecodedSize)
186	}
187	s.Out = s.Out[:0:s.MaxDecodedSize]
188	s.Out, err = s.Decoder().Decompress1X(s.Out, in)
189	return s.Out, err
190}
191
192// Decompress4X will decompress a 4X encoded stream.
193// Before this is called, the table must be initialized with ReadTable unless
194// the encoder re-used the table.
195// The length of the supplied input must match the end of a block exactly.
196// The destination size of the uncompressed data must be known and provided.
197// deprecated: Use the stateless Decoder() to get a concurrent version.
198func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
199	if dstSize > s.MaxDecodedSize {
200		return nil, ErrMaxDecodedSizeExceeded
201	}
202	if cap(s.Out) < dstSize {
203		s.Out = make([]byte, s.MaxDecodedSize)
204	}
205	s.Out = s.Out[:0:dstSize]
206	s.Out, err = s.Decoder().Decompress4X(s.Out, in)
207	return s.Out, err
208}
209
210// Decoder will return a stateless decoder that can be used by multiple
211// decompressors concurrently.
212// Before this is called, the table must be initialized with ReadTable.
213// The Decoder is still linked to the scratch buffer so that cannot be reused.
214// However, it is safe to discard the scratch.
215func (s *Scratch) Decoder() *Decoder {
216	return &Decoder{
217		dt:             s.dt,
218		actualTableLog: s.actualTableLog,
219	}
220}
221
// Decoder provides stateless decoding.
// It carries the decoding table (dt) and the table log (actualTableLog)
// copied from the Scratch that created it via Scratch.Decoder.
type Decoder struct {
	dt             dTable
	actualTableLog uint8
}
227
228// Decompress1X will decompress a 1X encoded stream.
229// The cap of the output buffer will be the maximum decompressed size.
230// The length of the supplied input must match the end of a block exactly.
231func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) {
232	if len(d.dt.single) == 0 {
233		return nil, errors.New("no table loaded")
234	}
235	if use8BitTables && d.actualTableLog <= 8 {
236		return d.decompress1X8Bit(dst, src)
237	}
238	var br bitReaderShifted
239	err := br.init(src)
240	if err != nil {
241		return dst, err
242	}
243	maxDecodedSize := cap(dst)
244	dst = dst[:0]
245
246	// Avoid bounds check by always having full sized table.
247	const tlSize = 1 << tableLogMax
248	const tlMask = tlSize - 1
249	dt := d.dt.single[:tlSize]
250
251	// Use temp table to avoid bound checks/append penalty.
252	var buf [256]byte
253	var off uint8
254
255	for br.off >= 8 {
256		br.fillFast()
257		v := dt[br.peekBitsFast(d.actualTableLog)&tlMask]
258		br.advance(uint8(v.entry))
259		buf[off+0] = uint8(v.entry >> 8)
260
261		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
262		br.advance(uint8(v.entry))
263		buf[off+1] = uint8(v.entry >> 8)
264
265		// Refill
266		br.fillFast()
267
268		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
269		br.advance(uint8(v.entry))
270		buf[off+2] = uint8(v.entry >> 8)
271
272		v = dt[br.peekBitsFast(d.actualTableLog)&tlMask]
273		br.advance(uint8(v.entry))
274		buf[off+3] = uint8(v.entry >> 8)
275
276		off += 4
277		if off == 0 {
278			if len(dst)+256 > maxDecodedSize {
279				br.close()
280				return nil, ErrMaxDecodedSizeExceeded
281			}
282			dst = append(dst, buf[:]...)
283		}
284	}
285
286	if len(dst)+int(off) > maxDecodedSize {
287		br.close()
288		return nil, ErrMaxDecodedSizeExceeded
289	}
290	dst = append(dst, buf[:off]...)
291
292	// br < 8, so uint8 is fine
293	bitsLeft := uint8(br.off)*8 + 64 - br.bitsRead
294	for bitsLeft > 0 {
295		br.fill()
296		if false && br.bitsRead >= 32 {
297			if br.off >= 4 {
298				v := br.in[br.off-4:]
299				v = v[:4]
300				low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
301				br.value = (br.value << 32) | uint64(low)
302				br.bitsRead -= 32
303				br.off -= 4
304			} else {
305				for br.off > 0 {
306					br.value = (br.value << 8) | uint64(br.in[br.off-1])
307					br.bitsRead -= 8
308					br.off--
309				}
310			}
311		}
312		if len(dst) >= maxDecodedSize {
313			br.close()
314			return nil, ErrMaxDecodedSizeExceeded
315		}
316		v := d.dt.single[br.peekBitsFast(d.actualTableLog)&tlMask]
317		nBits := uint8(v.entry)
318		br.advance(nBits)
319		bitsLeft -= nBits
320		dst = append(dst, uint8(v.entry>>8))
321	}
322	return dst, br.close()
323}
324
325// decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8.
326// The cap of the output buffer will be the maximum decompressed size.
327// The length of the supplied input must match the end of a block exactly.
328func (d *Decoder) decompress1X8Bit(dst, src []byte) ([]byte, error) {
329	if d.actualTableLog == 8 {
330		return d.decompress1X8BitExactly(dst, src)
331	}
332	var br bitReaderBytes
333	err := br.init(src)
334	if err != nil {
335		return dst, err
336	}
337	maxDecodedSize := cap(dst)
338	dst = dst[:0]
339
340	// Avoid bounds check by always having full sized table.
341	dt := d.dt.single[:256]
342
343	// Use temp table to avoid bound checks/append penalty.
344	var buf [256]byte
345	var off uint8
346
347	switch d.actualTableLog {
348	case 8:
349		const shift = 8 - 8
350		for br.off >= 4 {
351			br.fillFast()
352			v := dt[uint8(br.value>>(56+shift))]
353			br.advance(uint8(v.entry))
354			buf[off+0] = uint8(v.entry >> 8)
355
356			v = dt[uint8(br.value>>(56+shift))]
357			br.advance(uint8(v.entry))
358			buf[off+1] = uint8(v.entry >> 8)
359
360			v = dt[uint8(br.value>>(56+shift))]
361			br.advance(uint8(v.entry))
362			buf[off+2] = uint8(v.entry >> 8)
363
364			v = dt[uint8(br.value>>(56+shift))]
365			br.advance(uint8(v.entry))
366			buf[off+3] = uint8(v.entry >> 8)
367
368			off += 4
369			if off == 0 {
370				if len(dst)+256 > maxDecodedSize {
371					br.close()
372					return nil, ErrMaxDecodedSizeExceeded
373				}
374				dst = append(dst, buf[:]...)
375			}
376		}
377	case 7:
378		const shift = 8 - 7
379		for br.off >= 4 {
380			br.fillFast()
381			v := dt[uint8(br.value>>(56+shift))]
382			br.advance(uint8(v.entry))
383			buf[off+0] = uint8(v.entry >> 8)
384
385			v = dt[uint8(br.value>>(56+shift))]
386			br.advance(uint8(v.entry))
387			buf[off+1] = uint8(v.entry >> 8)
388
389			v = dt[uint8(br.value>>(56+shift))]
390			br.advance(uint8(v.entry))
391			buf[off+2] = uint8(v.entry >> 8)
392
393			v = dt[uint8(br.value>>(56+shift))]
394			br.advance(uint8(v.entry))
395			buf[off+3] = uint8(v.entry >> 8)
396
397			off += 4
398			if off == 0 {
399				if len(dst)+256 > maxDecodedSize {
400					br.close()
401					return nil, ErrMaxDecodedSizeExceeded
402				}
403				dst = append(dst, buf[:]...)
404			}
405		}
406	case 6:
407		const shift = 8 - 6
408		for br.off >= 4 {
409			br.fillFast()
410			v := dt[uint8(br.value>>(56+shift))]
411			br.advance(uint8(v.entry))
412			buf[off+0] = uint8(v.entry >> 8)
413
414			v = dt[uint8(br.value>>(56+shift))]
415			br.advance(uint8(v.entry))
416			buf[off+1] = uint8(v.entry >> 8)
417
418			v = dt[uint8(br.value>>(56+shift))]
419			br.advance(uint8(v.entry))
420			buf[off+2] = uint8(v.entry >> 8)
421
422			v = dt[uint8(br.value>>(56+shift))]
423			br.advance(uint8(v.entry))
424			buf[off+3] = uint8(v.entry >> 8)
425
426			off += 4
427			if off == 0 {
428				if len(dst)+256 > maxDecodedSize {
429					br.close()
430					return nil, ErrMaxDecodedSizeExceeded
431				}
432				dst = append(dst, buf[:]...)
433			}
434		}
435	case 5:
436		const shift = 8 - 5
437		for br.off >= 4 {
438			br.fillFast()
439			v := dt[uint8(br.value>>(56+shift))]
440			br.advance(uint8(v.entry))
441			buf[off+0] = uint8(v.entry >> 8)
442
443			v = dt[uint8(br.value>>(56+shift))]
444			br.advance(uint8(v.entry))
445			buf[off+1] = uint8(v.entry >> 8)
446
447			v = dt[uint8(br.value>>(56+shift))]
448			br.advance(uint8(v.entry))
449			buf[off+2] = uint8(v.entry >> 8)
450
451			v = dt[uint8(br.value>>(56+shift))]
452			br.advance(uint8(v.entry))
453			buf[off+3] = uint8(v.entry >> 8)
454
455			off += 4
456			if off == 0 {
457				if len(dst)+256 > maxDecodedSize {
458					br.close()
459					return nil, ErrMaxDecodedSizeExceeded
460				}
461				dst = append(dst, buf[:]...)
462			}
463		}
464	case 4:
465		const shift = 8 - 4
466		for br.off >= 4 {
467			br.fillFast()
468			v := dt[uint8(br.value>>(56+shift))]
469			br.advance(uint8(v.entry))
470			buf[off+0] = uint8(v.entry >> 8)
471
472			v = dt[uint8(br.value>>(56+shift))]
473			br.advance(uint8(v.entry))
474			buf[off+1] = uint8(v.entry >> 8)
475
476			v = dt[uint8(br.value>>(56+shift))]
477			br.advance(uint8(v.entry))
478			buf[off+2] = uint8(v.entry >> 8)
479
480			v = dt[uint8(br.value>>(56+shift))]
481			br.advance(uint8(v.entry))
482			buf[off+3] = uint8(v.entry >> 8)
483
484			off += 4
485			if off == 0 {
486				if len(dst)+256 > maxDecodedSize {
487					br.close()
488					return nil, ErrMaxDecodedSizeExceeded
489				}
490				dst = append(dst, buf[:]...)
491			}
492		}
493	case 3:
494		const shift = 8 - 3
495		for br.off >= 4 {
496			br.fillFast()
497			v := dt[uint8(br.value>>(56+shift))]
498			br.advance(uint8(v.entry))
499			buf[off+0] = uint8(v.entry >> 8)
500
501			v = dt[uint8(br.value>>(56+shift))]
502			br.advance(uint8(v.entry))
503			buf[off+1] = uint8(v.entry >> 8)
504
505			v = dt[uint8(br.value>>(56+shift))]
506			br.advance(uint8(v.entry))
507			buf[off+2] = uint8(v.entry >> 8)
508
509			v = dt[uint8(br.value>>(56+shift))]
510			br.advance(uint8(v.entry))
511			buf[off+3] = uint8(v.entry >> 8)
512
513			off += 4
514			if off == 0 {
515				if len(dst)+256 > maxDecodedSize {
516					br.close()
517					return nil, ErrMaxDecodedSizeExceeded
518				}
519				dst = append(dst, buf[:]...)
520			}
521		}
522	case 2:
523		const shift = 8 - 2
524		for br.off >= 4 {
525			br.fillFast()
526			v := dt[uint8(br.value>>(56+shift))]
527			br.advance(uint8(v.entry))
528			buf[off+0] = uint8(v.entry >> 8)
529
530			v = dt[uint8(br.value>>(56+shift))]
531			br.advance(uint8(v.entry))
532			buf[off+1] = uint8(v.entry >> 8)
533
534			v = dt[uint8(br.value>>(56+shift))]
535			br.advance(uint8(v.entry))
536			buf[off+2] = uint8(v.entry >> 8)
537
538			v = dt[uint8(br.value>>(56+shift))]
539			br.advance(uint8(v.entry))
540			buf[off+3] = uint8(v.entry >> 8)
541
542			off += 4
543			if off == 0 {
544				if len(dst)+256 > maxDecodedSize {
545					br.close()
546					return nil, ErrMaxDecodedSizeExceeded
547				}
548				dst = append(dst, buf[:]...)
549			}
550		}
551	case 1:
552		const shift = 8 - 1
553		for br.off >= 4 {
554			br.fillFast()
555			v := dt[uint8(br.value>>(56+shift))]
556			br.advance(uint8(v.entry))
557			buf[off+0] = uint8(v.entry >> 8)
558
559			v = dt[uint8(br.value>>(56+shift))]
560			br.advance(uint8(v.entry))
561			buf[off+1] = uint8(v.entry >> 8)
562
563			v = dt[uint8(br.value>>(56+shift))]
564			br.advance(uint8(v.entry))
565			buf[off+2] = uint8(v.entry >> 8)
566
567			v = dt[uint8(br.value>>(56+shift))]
568			br.advance(uint8(v.entry))
569			buf[off+3] = uint8(v.entry >> 8)
570
571			off += 4
572			if off == 0 {
573				if len(dst)+256 > maxDecodedSize {
574					br.close()
575					return nil, ErrMaxDecodedSizeExceeded
576				}
577				dst = append(dst, buf[:]...)
578			}
579		}
580	default:
581		return nil, fmt.Errorf("invalid tablelog: %d", d.actualTableLog)
582	}
583
584	if len(dst)+int(off) > maxDecodedSize {
585		br.close()
586		return nil, ErrMaxDecodedSizeExceeded
587	}
588	dst = append(dst, buf[:off]...)
589
590	// br < 4, so uint8 is fine
591	bitsLeft := int8(uint8(br.off)*8 + (64 - br.bitsRead))
592	shift := (8 - d.actualTableLog) & 7
593
594	for bitsLeft > 0 {
595		if br.bitsRead >= 64-8 {
596			for br.off > 0 {
597				br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
598				br.bitsRead -= 8
599				br.off--
600			}
601		}
602		if len(dst) >= maxDecodedSize {
603			br.close()
604			return nil, ErrMaxDecodedSizeExceeded
605		}
606		v := dt[br.peekByteFast()>>shift]
607		nBits := uint8(v.entry)
608		br.advance(nBits)
609		bitsLeft -= int8(nBits)
610		dst = append(dst, uint8(v.entry>>8))
611	}
612	return dst, br.close()
613}
614
615// decompress1X8Bit will decompress a 1X encoded stream with tablelog <= 8.
616// The cap of the output buffer will be the maximum decompressed size.
617// The length of the supplied input must match the end of a block exactly.
618func (d *Decoder) decompress1X8BitExactly(dst, src []byte) ([]byte, error) {
619	var br bitReaderBytes
620	err := br.init(src)
621	if err != nil {
622		return dst, err
623	}
624	maxDecodedSize := cap(dst)
625	dst = dst[:0]
626
627	// Avoid bounds check by always having full sized table.
628	dt := d.dt.single[:256]
629
630	// Use temp table to avoid bound checks/append penalty.
631	var buf [256]byte
632	var off uint8
633
634	const shift = 56
635
636	//fmt.Printf("mask: %b, tl:%d\n", mask, d.actualTableLog)
637	for br.off >= 4 {
638		br.fillFast()
639		v := dt[uint8(br.value>>shift)]
640		br.advance(uint8(v.entry))
641		buf[off+0] = uint8(v.entry >> 8)
642
643		v = dt[uint8(br.value>>shift)]
644		br.advance(uint8(v.entry))
645		buf[off+1] = uint8(v.entry >> 8)
646
647		v = dt[uint8(br.value>>shift)]
648		br.advance(uint8(v.entry))
649		buf[off+2] = uint8(v.entry >> 8)
650
651		v = dt[uint8(br.value>>shift)]
652		br.advance(uint8(v.entry))
653		buf[off+3] = uint8(v.entry >> 8)
654
655		off += 4
656		if off == 0 {
657			if len(dst)+256 > maxDecodedSize {
658				br.close()
659				return nil, ErrMaxDecodedSizeExceeded
660			}
661			dst = append(dst, buf[:]...)
662		}
663	}
664
665	if len(dst)+int(off) > maxDecodedSize {
666		br.close()
667		return nil, ErrMaxDecodedSizeExceeded
668	}
669	dst = append(dst, buf[:off]...)
670
671	// br < 4, so uint8 is fine
672	bitsLeft := int8(uint8(br.off)*8 + (64 - br.bitsRead))
673	for bitsLeft > 0 {
674		if br.bitsRead >= 64-8 {
675			for br.off > 0 {
676				br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
677				br.bitsRead -= 8
678				br.off--
679			}
680		}
681		if len(dst) >= maxDecodedSize {
682			br.close()
683			return nil, ErrMaxDecodedSizeExceeded
684		}
685		v := dt[br.peekByteFast()]
686		nBits := uint8(v.entry)
687		br.advance(nBits)
688		bitsLeft -= int8(nBits)
689		dst = append(dst, uint8(v.entry>>8))
690	}
691	return dst, br.close()
692}
693
694// Decompress4X will decompress a 4X encoded stream.
695// The length of the supplied input must match the end of a block exactly.
696// The *capacity* of the dst slice must match the destination size of
697// the uncompressed data exactly.
698func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
699	if len(d.dt.single) == 0 {
700		return nil, errors.New("no table loaded")
701	}
702	if len(src) < 6+(4*1) {
703		return nil, errors.New("input too small")
704	}
705	if use8BitTables && d.actualTableLog <= 8 {
706		return d.decompress4X8bit(dst, src)
707	}
708
709	var br [4]bitReaderShifted
710	start := 6
711	for i := 0; i < 3; i++ {
712		length := int(src[i*2]) | (int(src[i*2+1]) << 8)
713		if start+length >= len(src) {
714			return nil, errors.New("truncated input (or invalid offset)")
715		}
716		err := br[i].init(src[start : start+length])
717		if err != nil {
718			return nil, err
719		}
720		start += length
721	}
722	err := br[3].init(src[start:])
723	if err != nil {
724		return nil, err
725	}
726
727	// destination, offset to match first output
728	dstSize := cap(dst)
729	dst = dst[:dstSize]
730	out := dst
731	dstEvery := (dstSize + 3) / 4
732
733	const tlSize = 1 << tableLogMax
734	const tlMask = tlSize - 1
735	single := d.dt.single[:tlSize]
736
737	// Use temp table to avoid bound checks/append penalty.
738	var buf [256]byte
739	var off uint8
740	var decoded int
741
742	// Decode 2 values from each decoder/loop.
743	const bufoff = 256 / 4
744	for {
745		if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
746			break
747		}
748
749		{
750			const stream = 0
751			const stream2 = 1
752			br[stream].fillFast()
753			br[stream2].fillFast()
754
755			val := br[stream].peekBitsFast(d.actualTableLog)
756			v := single[val&tlMask]
757			br[stream].advance(uint8(v.entry))
758			buf[off+bufoff*stream] = uint8(v.entry >> 8)
759
760			val2 := br[stream2].peekBitsFast(d.actualTableLog)
761			v2 := single[val2&tlMask]
762			br[stream2].advance(uint8(v2.entry))
763			buf[off+bufoff*stream2] = uint8(v2.entry >> 8)
764
765			val = br[stream].peekBitsFast(d.actualTableLog)
766			v = single[val&tlMask]
767			br[stream].advance(uint8(v.entry))
768			buf[off+bufoff*stream+1] = uint8(v.entry >> 8)
769
770			val2 = br[stream2].peekBitsFast(d.actualTableLog)
771			v2 = single[val2&tlMask]
772			br[stream2].advance(uint8(v2.entry))
773			buf[off+bufoff*stream2+1] = uint8(v2.entry >> 8)
774		}
775
776		{
777			const stream = 2
778			const stream2 = 3
779			br[stream].fillFast()
780			br[stream2].fillFast()
781
782			val := br[stream].peekBitsFast(d.actualTableLog)
783			v := single[val&tlMask]
784			br[stream].advance(uint8(v.entry))
785			buf[off+bufoff*stream] = uint8(v.entry >> 8)
786
787			val2 := br[stream2].peekBitsFast(d.actualTableLog)
788			v2 := single[val2&tlMask]
789			br[stream2].advance(uint8(v2.entry))
790			buf[off+bufoff*stream2] = uint8(v2.entry >> 8)
791
792			val = br[stream].peekBitsFast(d.actualTableLog)
793			v = single[val&tlMask]
794			br[stream].advance(uint8(v.entry))
795			buf[off+bufoff*stream+1] = uint8(v.entry >> 8)
796
797			val2 = br[stream2].peekBitsFast(d.actualTableLog)
798			v2 = single[val2&tlMask]
799			br[stream2].advance(uint8(v2.entry))
800			buf[off+bufoff*stream2+1] = uint8(v2.entry >> 8)
801		}
802
803		off += 2
804
805		if off == bufoff {
806			if bufoff > dstEvery {
807				return nil, errors.New("corruption detected: stream overrun 1")
808			}
809			copy(out, buf[:bufoff])
810			copy(out[dstEvery:], buf[bufoff:bufoff*2])
811			copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
812			copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
813			off = 0
814			out = out[bufoff:]
815			decoded += 256
816			// There must at least be 3 buffers left.
817			if len(out) < dstEvery*3 {
818				return nil, errors.New("corruption detected: stream overrun 2")
819			}
820		}
821	}
822	if off > 0 {
823		ioff := int(off)
824		if len(out) < dstEvery*3+ioff {
825			return nil, errors.New("corruption detected: stream overrun 3")
826		}
827		copy(out, buf[:off])
828		copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
829		copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
830		copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
831		decoded += int(off) * 4
832		out = out[off:]
833	}
834
835	// Decode remaining.
836	for i := range br {
837		offset := dstEvery * i
838		br := &br[i]
839		bitsLeft := br.off*8 + uint(64-br.bitsRead)
840		for bitsLeft > 0 {
841			br.fill()
842			if false && br.bitsRead >= 32 {
843				if br.off >= 4 {
844					v := br.in[br.off-4:]
845					v = v[:4]
846					low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
847					br.value = (br.value << 32) | uint64(low)
848					br.bitsRead -= 32
849					br.off -= 4
850				} else {
851					for br.off > 0 {
852						br.value = (br.value << 8) | uint64(br.in[br.off-1])
853						br.bitsRead -= 8
854						br.off--
855					}
856				}
857			}
858			// end inline...
859			if offset >= len(out) {
860				return nil, errors.New("corruption detected: stream overrun 4")
861			}
862
863			// Read value and increment offset.
864			val := br.peekBitsFast(d.actualTableLog)
865			v := single[val&tlMask].entry
866			nBits := uint8(v)
867			br.advance(nBits)
868			bitsLeft -= uint(nBits)
869			out[offset] = uint8(v >> 8)
870			offset++
871		}
872		decoded += offset - dstEvery*i
873		err = br.close()
874		if err != nil {
875			return nil, err
876		}
877	}
878	if dstSize != decoded {
879		return nil, errors.New("corruption detected: short output block")
880	}
881	return dst, nil
882}
883
// decompress4X8bit will decompress a 4X encoded stream with tablelog < 8.
// A tablelog of exactly 8 is forwarded to decompress4X8bitExactly.
// The length of the supplied input must match the end of a block exactly.
// The *capacity* of the dst slice must match the destination size of
// the uncompressed data exactly.
//
// src starts with three little-endian 16 bit lengths for streams 0-2;
// stream 3 takes the rest of src. Each stream decodes into its own
// quarter of dst, (cap(dst)+3)/4 bytes apart.
func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
	if d.actualTableLog == 8 {
		return d.decompress4X8bitExactly(dst, src)
	}

	// Set up one bit reader per stream from the 6 byte jump table.
	var br [4]bitReaderBytes
	start := 6
	for i := 0; i < 3; i++ {
		length := int(src[i*2]) | (int(src[i*2+1]) << 8)
		if start+length >= len(src) {
			return nil, errors.New("truncated input (or invalid offset)")
		}
		err := br[i].init(src[start : start+length])
		if err != nil {
			return nil, err
		}
		start += length
	}
	err := br[3].init(src[start:])
	if err != nil {
		return nil, err
	}

	// destination, offset to match first output
	dstSize := cap(dst)
	dst = dst[:dstSize]
	out := dst
	dstEvery := (dstSize + 3) / 4

	// The table is indexed by the top actualTableLog bits of the peeked byte.
	shift := (8 - d.actualTableLog) & 7

	const tlSize = 1 << 8
	single := d.dt.single[:tlSize]

	// Use temp table to avoid bound checks/append penalty.
	// Each stream owns one quarter (bufoff bytes) of buf.
	var buf [256]byte
	var off uint8
	var decoded int

	// Decode 4 values from each decoder/loop.
	const bufoff = 256 / 4
	for {
		if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
			break
		}

		{
			// Interleave 2 decodes.
			const stream = 0
			const stream2 = 1
			br[stream].fillFast()
			br[stream2].fillFast()

			v := single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 := single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+1] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+2] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+3] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))
		}

		{
			// Same interleaving for the other two streams.
			const stream = 2
			const stream2 = 3
			br[stream].fillFast()
			br[stream2].fillFast()

			v := single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 := single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+1] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+2] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+3] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))
		}

		off += 4

		if off == bufoff {
			// Every quarter of buf is full; flush each to its stream's region.
			if bufoff > dstEvery {
				return nil, errors.New("corruption detected: stream overrun 1")
			}
			copy(out, buf[:bufoff])
			copy(out[dstEvery:], buf[bufoff:bufoff*2])
			copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
			copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
			off = 0
			out = out[bufoff:]
			decoded += 256
			// There must at least be 3 buffers left.
			if len(out) < dstEvery*3 {
				return nil, errors.New("corruption detected: stream overrun 2")
			}
		}
	}
	if off > 0 {
		// Flush the partially filled quarters.
		ioff := int(off)
		if len(out) < dstEvery*3+ioff {
			return nil, errors.New("corruption detected: stream overrun 3")
		}
		copy(out, buf[:off])
		copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
		copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
		copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
		decoded += int(off) * 4
		out = out[off:]
	}

	// Decode remaining.
	for i := range br {
		offset := dstEvery * i
		br := &br[i]
		// Signed count of the bits not yet consumed from this stream.
		bitsLeft := int(br.off*8) + int(64-br.bitsRead)
		for bitsLeft > 0 {
			if br.finished() {
				return nil, io.ErrUnexpectedEOF
			}
			// Inlined refill: load 4 bytes when available, otherwise
			// whatever bytes remain.
			if br.bitsRead >= 56 {
				if br.off >= 4 {
					v := br.in[br.off-4:]
					v = v[:4]
					low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
					br.value |= uint64(low) << (br.bitsRead - 32)
					br.bitsRead -= 32
					br.off -= 4
				} else {
					for br.off > 0 {
						br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
						br.bitsRead -= 8
						br.off--
					}
				}
			}
			// end inline...
			if offset >= len(out) {
				return nil, errors.New("corruption detected: stream overrun 4")
			}

			// Read value and increment offset.
			v := single[br.peekByteFast()>>shift].entry
			nBits := uint8(v)
			br.advance(nBits)
			bitsLeft -= int(nBits)
			out[offset] = uint8(v >> 8)
			offset++
		}
		decoded += offset - dstEvery*i
		err = br.close()
		if err != nil {
			return nil, err
		}
	}
	// All four streams together must produce exactly dstSize bytes.
	if dstSize != decoded {
		return nil, errors.New("corruption detected: short output block")
	}
	return dst, nil
}
1094
// decompress4X8bitExactly decodes a 4X encoded block, i.e. four separate
// bit streams that each fill one quarter of the output, using the first 256
// entries of the single-symbol decoding table.
//
// src starts with a 6 byte header holding three little-endian 16 bit lengths
// for streams 0-2, followed by the stream payloads. Stream 3 is whatever
// remains, so the length of the supplied input must match the end of a block
// exactly.
// The *capacity* of the dst slice must match the destination size of
// the uncompressed data exactly.
//
// On success dst, resliced to its full capacity, is returned. A nil slice and
// an error are returned when a stream length points outside src, when a bit
// reader fails to init or close, when output would be written past the
// destination, or when the decoded size differs from cap(dst).
//
// NOTE(review): src[0:6] is indexed without a length check in this function;
// presumably the caller guarantees len(src) >= 6 - confirm.
func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
	var br [4]bitReaderBytes
	// Payload of the first stream begins right after the 6 byte length header.
	start := 6
	for i := 0; i < 3; i++ {
		// Little-endian 16 bit length of stream i.
		length := int(src[i*2]) | (int(src[i*2+1]) << 8)
		// ">=" rather than ">": at least one byte must remain for the streams that follow.
		if start+length >= len(src) {
			return nil, errors.New("truncated input (or invalid offset)")
		}
		err := br[i].init(src[start : start+length])
		if err != nil {
			return nil, err
		}
		start += length
	}
	// The last stream has no stored length; it takes the remainder of src.
	err := br[3].init(src[start:])
	if err != nil {
		return nil, err
	}

	// destination, offset to match first output
	dstSize := cap(dst)
	dst = dst[:dstSize]
	out := dst
	// Output distance between the start of two consecutive streams (a quarter, rounded up).
	dstEvery := (dstSize + 3) / 4

	// shift is 0 here, so the peeked byte is used as the table index unchanged.
	const shift = 0
	const tlSize = 1 << 8
	// tlMask is not referenced anywhere in this function.
	const tlMask = tlSize - 1
	// Each entry holds the bit length in the low byte and the decoded symbol in the high byte.
	single := d.dt.single[:tlSize]

	// Use temp table to avoid bound checks/append penalty.
	// buf is split in four regions of bufoff bytes, one per stream.
	var buf [256]byte
	// off is the write position inside each of the four regions of buf.
	var off uint8
	// decoded counts the total number of output bytes produced by all streams.
	var decoded int

	// Decode 4 values from each decoder/loop.
	const bufoff = 256 / 4
	for {
		// Leave the fast loop as soon as any reader has off below 4;
		// the rest is handled by the per-symbol loop further down.
		if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
			break
		}

		{
			// Interleave 2 decodes.
			const stream = 0
			const stream2 = 1
			br[stream].fillFast()
			br[stream2].fillFast()

			// Symbol 1 of 4 for both streams: look up, store symbol, consume its bits.
			v := single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 := single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			// Symbol 2 of 4.
			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+1] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			// Symbol 3 of 4.
			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+2] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			// Symbol 4 of 4.
			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+3] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))
		}

		{
			// Same interleaved decode of 4 symbols each, for the last two streams.
			const stream = 2
			const stream2 = 3
			br[stream].fillFast()
			br[stream2].fillFast()

			// Symbol 1 of 4.
			v := single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 := single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			// Symbol 2 of 4.
			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+1] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+1] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			// Symbol 3 of 4.
			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+2] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+2] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))

			// Symbol 4 of 4.
			v = single[br[stream].peekByteFast()>>shift].entry
			buf[off+bufoff*stream+3] = uint8(v >> 8)
			br[stream].advance(uint8(v))

			v2 = single[br[stream2].peekByteFast()>>shift].entry
			buf[off+bufoff*stream2+3] = uint8(v2 >> 8)
			br[stream2].advance(uint8(v2))
		}

		off += 4

		// All four regions of buf are full: flush them to their quarters of the output.
		if off == bufoff {
			// A full region cannot fit when a quarter is smaller than bufoff.
			if bufoff > dstEvery {
				return nil, errors.New("corruption detected: stream overrun 1")
			}
			copy(out, buf[:bufoff])
			copy(out[dstEvery:], buf[bufoff:bufoff*2])
			copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3])
			copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4])
			off = 0
			// Advancing out moves the write position of all four streams at once.
			out = out[bufoff:]
			// 256 equals bufoff*4, the number of bytes just flushed.
			decoded += 256
			// There must at least be 3 buffers left.
			if len(out) < dstEvery*3 {
				return nil, errors.New("corruption detected: stream overrun 2")
			}
		}
	}
	// Flush what is left in a partially filled buffer: off bytes per stream.
	if off > 0 {
		ioff := int(off)
		if len(out) < dstEvery*3+ioff {
			return nil, errors.New("corruption detected: stream overrun 3")
		}
		copy(out, buf[:off])
		// The destination slices are ioff long, so only ioff bytes of each region are copied.
		copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2])
		copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3])
		copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4])
		decoded += int(off) * 4
		out = out[off:]
	}

	// Decode remaining.
	// Each stream is finished one symbol at a time, writing at its own
	// dstEvery*i offset relative to out.
	for i := range br {
		offset := dstEvery * i
		br := &br[i]
		// Bits still to consume: the br.off bytes not yet loaded plus the
		// 64-bitsRead bits still held in br.value.
		bitsLeft := int(br.off*8) + int(64-br.bitsRead)
		for bitsLeft > 0 {
			if br.finished() {
				return nil, io.ErrUnexpectedEOF
			}
			// Inlined refill of the bit reader, done once 56 or more bits have been consumed.
			if br.bitsRead >= 56 {
				if br.off >= 4 {
					// Load 4 more bytes (little-endian) from the end of the unread input.
					v := br.in[br.off-4:]
					v = v[:4]
					low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
					br.value |= uint64(low) << (br.bitsRead - 32)
					br.bitsRead -= 32
					br.off -= 4
				} else {
					// Fewer than 4 bytes left: load them one at a time.
					for br.off > 0 {
						br.value |= uint64(br.in[br.off-1]) << (br.bitsRead - 8)
						br.bitsRead -= 8
						br.off--
					}
				}
			}
			// end inline...
			if offset >= len(out) {
				return nil, errors.New("corruption detected: stream overrun 4")
			}

			// Read value and increment offset.
			v := single[br.peekByteFast()>>shift].entry
			nBits := uint8(v)
			br.advance(nBits)
			bitsLeft -= int(nBits)
			out[offset] = uint8(v >> 8)
			offset++
		}
		// Count the bytes this stream produced in the loop above.
		decoded += offset - dstEvery*i
		err = br.close()
		if err != nil {
			return nil, err
		}
	}
	// The streams together must have produced exactly the destination size.
	if dstSize != decoded {
		return nil, errors.New("corruption detected: short output block")
	}
	return dst, nil
}
1301
1302// matches will compare a decoding table to a coding table.
1303// Errors are written to the writer.
1304// Nothing will be written if table is ok.
1305func (s *Scratch) matches(ct cTable, w io.Writer) {
1306	if s == nil || len(s.dt.single) == 0 {
1307		return
1308	}
1309	dt := s.dt.single[:1<<s.actualTableLog]
1310	tablelog := s.actualTableLog
1311	ok := 0
1312	broken := 0
1313	for sym, enc := range ct {
1314		errs := 0
1315		broken++
1316		if enc.nBits == 0 {
1317			for _, dec := range dt {
1318				if uint8(dec.entry>>8) == byte(sym) {
1319					fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
1320					errs++
1321					break
1322				}
1323			}
1324			if errs == 0 {
1325				broken--
1326			}
1327			continue
1328		}
1329		// Unused bits in input
1330		ub := tablelog - enc.nBits
1331		top := enc.val << ub
1332		// decoder looks at top bits.
1333		dec := dt[top]
1334		if uint8(dec.entry) != enc.nBits {
1335			fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
1336			errs++
1337		}
1338		if uint8(dec.entry>>8) != uint8(sym) {
1339			fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
1340			errs++
1341		}
1342		if errs > 0 {
1343			fmt.Fprintf(w, "%d errros in base, stopping\n", errs)
1344			continue
1345		}
1346		// Ensure that all combinations are covered.
1347		for i := uint16(0); i < (1 << ub); i++ {
1348			vval := top | i
1349			dec := dt[vval]
1350			if uint8(dec.entry) != enc.nBits {
1351				fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
1352				errs++
1353			}
1354			if uint8(dec.entry>>8) != uint8(sym) {
1355				fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
1356				errs++
1357			}
1358			if errs > 20 {
1359				fmt.Fprintf(w, "%d errros, stopping\n", errs)
1360				break
1361			}
1362		}
1363		if errs == 0 {
1364			ok++
1365			broken--
1366		}
1367	}
1368	if broken > 0 {
1369		fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok)
1370	}
1371}
1372