1package huff0
2
3import (
4	"errors"
5	"fmt"
6	"io"
7
8	"github.com/klauspost/compress/fse"
9)
10
// dTable holds the decoding tables for a Huffman table.
// single is the single-symbol table that ReadTable fills (always allocated
// at full 1<<tableLogMax size) and that Decompress1X/Decompress4X read.
// NOTE(review): double is not populated by any code visible in this file;
// presumably reserved for a double-symbol decoder (see TODO in ReadTable).
type dTable struct {
	single []dEntrySingle
	double []dEntryDouble
}
15
// dEntrySingle is one entry of the single-symbol decoding table.
// The low byte of entry is the number of bits the code consumes, and the
// high byte is the decoded symbol: ReadTable stores
// (actualTableLog+1-weight) | symbol<<8, and the decoders add
// uint8(entry) to bitsRead and emit uint8(entry>>8).
type dEntrySingle struct {
	entry uint16
}
20
// dEntryDouble is an entry of a double-symbols decoding table.
// NOTE(review): nothing in this file reads or writes this type, so the field
// meanings are inferred from names only — presumably seq holds the decoded
// symbol sequence, nBits the bits consumed and len the number of symbols in
// seq. Confirm before relying on this.
type dEntryDouble struct {
	seq   uint16
	nBits uint8
	len   uint8
}
27
28// ReadTable will read a table from the input.
29// The size of the input may be larger than the table definition.
30// Any content remaining after the table definition will be returned.
31// If no Scratch is provided a new one is allocated.
32// The returned Scratch can be used for decoding input using this table.
33func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
34	s, err = s.prepare(in)
35	if err != nil {
36		return s, nil, err
37	}
38	if len(in) <= 1 {
39		return s, nil, errors.New("input too small for table")
40	}
41	iSize := in[0]
42	in = in[1:]
43	if iSize >= 128 {
44		// Uncompressed
45		oSize := iSize - 127
46		iSize = (oSize + 1) / 2
47		if int(iSize) > len(in) {
48			return s, nil, errors.New("input too small for table")
49		}
50		for n := uint8(0); n < oSize; n += 2 {
51			v := in[n/2]
52			s.huffWeight[n] = v >> 4
53			s.huffWeight[n+1] = v & 15
54		}
55		s.symbolLen = uint16(oSize)
56		in = in[iSize:]
57	} else {
58		if len(in) <= int(iSize) {
59			return s, nil, errors.New("input too small for table")
60		}
61		// FSE compressed weights
62		s.fse.DecompressLimit = 255
63		hw := s.huffWeight[:]
64		s.fse.Out = hw
65		b, err := fse.Decompress(in[:iSize], s.fse)
66		s.fse.Out = nil
67		if err != nil {
68			return s, nil, err
69		}
70		if len(b) > 255 {
71			return s, nil, errors.New("corrupt input: output table too large")
72		}
73		s.symbolLen = uint16(len(b))
74		in = in[iSize:]
75	}
76
77	// collect weight stats
78	var rankStats [16]uint32
79	weightTotal := uint32(0)
80	for _, v := range s.huffWeight[:s.symbolLen] {
81		if v > tableLogMax {
82			return s, nil, errors.New("corrupt input: weight too large")
83		}
84		v2 := v & 15
85		rankStats[v2]++
86		weightTotal += (1 << v2) >> 1
87	}
88	if weightTotal == 0 {
89		return s, nil, errors.New("corrupt input: weights zero")
90	}
91
92	// get last non-null symbol weight (implied, total must be 2^n)
93	{
94		tableLog := highBit32(weightTotal) + 1
95		if tableLog > tableLogMax {
96			return s, nil, errors.New("corrupt input: tableLog too big")
97		}
98		s.actualTableLog = uint8(tableLog)
99		// determine last weight
100		{
101			total := uint32(1) << tableLog
102			rest := total - weightTotal
103			verif := uint32(1) << highBit32(rest)
104			lastWeight := highBit32(rest) + 1
105			if verif != rest {
106				// last value must be a clean power of 2
107				return s, nil, errors.New("corrupt input: last value not power of two")
108			}
109			s.huffWeight[s.symbolLen] = uint8(lastWeight)
110			s.symbolLen++
111			rankStats[lastWeight]++
112		}
113	}
114
115	if (rankStats[1] < 2) || (rankStats[1]&1 != 0) {
116		// by construction : at least 2 elts of rank 1, must be even
117		return s, nil, errors.New("corrupt input: min elt size, even check failed ")
118	}
119
120	// TODO: Choose between single/double symbol decoding
121
122	// Calculate starting value for each rank
123	{
124		var nextRankStart uint32
125		for n := uint8(1); n < s.actualTableLog+1; n++ {
126			current := nextRankStart
127			nextRankStart += rankStats[n] << (n - 1)
128			rankStats[n] = current
129		}
130	}
131
132	// fill DTable (always full size)
133	tSize := 1 << tableLogMax
134	if len(s.dt.single) != tSize {
135		s.dt.single = make([]dEntrySingle, tSize)
136	}
137	for n, w := range s.huffWeight[:s.symbolLen] {
138		if w == 0 {
139			continue
140		}
141		length := (uint32(1) << w) >> 1
142		d := dEntrySingle{
143			entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
144		}
145		single := s.dt.single[rankStats[w] : rankStats[w]+length]
146		for i := range single {
147			single[i] = d
148		}
149		rankStats[w] += length
150	}
151	return s, in, nil
152}
153
154// Decompress1X will decompress a 1X encoded stream.
155// The length of the supplied input must match the end of a block exactly.
156// Before this is called, the table must be initialized with ReadTable unless
157// the encoder re-used the table.
158func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
159	if len(s.dt.single) == 0 {
160		return nil, errors.New("no table loaded")
161	}
162	var br bitReader
163	err = br.init(in)
164	if err != nil {
165		return nil, err
166	}
167	s.Out = s.Out[:0]
168
169	decode := func() byte {
170		val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
171		v := s.dt.single[val]
172		br.bitsRead += uint8(v.entry)
173		return uint8(v.entry >> 8)
174	}
175	hasDec := func(v dEntrySingle) byte {
176		br.bitsRead += uint8(v.entry)
177		return uint8(v.entry >> 8)
178	}
179
180	// Avoid bounds check by always having full sized table.
181	const tlSize = 1 << tableLogMax
182	const tlMask = tlSize - 1
183	dt := s.dt.single[:tlSize]
184
185	// Use temp table to avoid bound checks/append penalty.
186	var tmp = s.huffWeight[:256]
187	var off uint8
188
189	for br.off >= 8 {
190		br.fillFast()
191		tmp[off+0] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
192		tmp[off+1] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
193		br.fillFast()
194		tmp[off+2] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
195		tmp[off+3] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask])
196		off += 4
197		if off == 0 {
198			if len(s.Out)+256 > s.MaxDecodedSize {
199				br.close()
200				return nil, ErrMaxDecodedSizeExceeded
201			}
202			s.Out = append(s.Out, tmp...)
203		}
204	}
205
206	if len(s.Out)+int(off) > s.MaxDecodedSize {
207		br.close()
208		return nil, ErrMaxDecodedSizeExceeded
209	}
210	s.Out = append(s.Out, tmp[:off]...)
211
212	for !br.finished() {
213		br.fill()
214		if len(s.Out) >= s.MaxDecodedSize {
215			br.close()
216			return nil, ErrMaxDecodedSizeExceeded
217		}
218		s.Out = append(s.Out, decode())
219	}
220	return s.Out, br.close()
221}
222
223// Decompress4X will decompress a 4X encoded stream.
224// Before this is called, the table must be initialized with ReadTable unless
225// the encoder re-used the table.
226// The length of the supplied input must match the end of a block exactly.
227// The destination size of the uncompressed data must be known and provided.
228func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
229	if len(s.dt.single) == 0 {
230		return nil, errors.New("no table loaded")
231	}
232	if len(in) < 6+(4*1) {
233		return nil, errors.New("input too small")
234	}
235	if dstSize > s.MaxDecodedSize {
236		return nil, ErrMaxDecodedSizeExceeded
237	}
238	// TODO: We do not detect when we overrun a buffer, except if the last one does.
239
240	var br [4]bitReader
241	start := 6
242	for i := 0; i < 3; i++ {
243		length := int(in[i*2]) | (int(in[i*2+1]) << 8)
244		if start+length >= len(in) {
245			return nil, errors.New("truncated input (or invalid offset)")
246		}
247		err = br[i].init(in[start : start+length])
248		if err != nil {
249			return nil, err
250		}
251		start += length
252	}
253	err = br[3].init(in[start:])
254	if err != nil {
255		return nil, err
256	}
257
258	// Prepare output
259	if cap(s.Out) < dstSize {
260		s.Out = make([]byte, 0, dstSize)
261	}
262	s.Out = s.Out[:dstSize]
263	// destination, offset to match first output
264	dstOut := s.Out
265	dstEvery := (dstSize + 3) / 4
266
267	const tlSize = 1 << tableLogMax
268	const tlMask = tlSize - 1
269	single := s.dt.single[:tlSize]
270
271	decode := func(br *bitReader) byte {
272		val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
273		v := single[val&tlMask]
274		br.bitsRead += uint8(v.entry)
275		return uint8(v.entry >> 8)
276	}
277
278	// Use temp table to avoid bound checks/append penalty.
279	var tmp = s.huffWeight[:256]
280	var off uint8
281	var decoded int
282
283	// Decode 2 values from each decoder/loop.
284	const bufoff = 256 / 4
285bigloop:
286	for {
287		for i := range br {
288			br := &br[i]
289			if br.off < 4 {
290				break bigloop
291			}
292			br.fillFast()
293		}
294
295		{
296			const stream = 0
297			val := br[stream].peekBitsFast(s.actualTableLog)
298			v := single[val&tlMask]
299			br[stream].bitsRead += uint8(v.entry)
300
301			val2 := br[stream].peekBitsFast(s.actualTableLog)
302			v2 := single[val2&tlMask]
303			tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
304			tmp[off+bufoff*stream] = uint8(v.entry >> 8)
305			br[stream].bitsRead += uint8(v2.entry)
306		}
307
308		{
309			const stream = 1
310			val := br[stream].peekBitsFast(s.actualTableLog)
311			v := single[val&tlMask]
312			br[stream].bitsRead += uint8(v.entry)
313
314			val2 := br[stream].peekBitsFast(s.actualTableLog)
315			v2 := single[val2&tlMask]
316			tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
317			tmp[off+bufoff*stream] = uint8(v.entry >> 8)
318			br[stream].bitsRead += uint8(v2.entry)
319		}
320
321		{
322			const stream = 2
323			val := br[stream].peekBitsFast(s.actualTableLog)
324			v := single[val&tlMask]
325			br[stream].bitsRead += uint8(v.entry)
326
327			val2 := br[stream].peekBitsFast(s.actualTableLog)
328			v2 := single[val2&tlMask]
329			tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
330			tmp[off+bufoff*stream] = uint8(v.entry >> 8)
331			br[stream].bitsRead += uint8(v2.entry)
332		}
333
334		{
335			const stream = 3
336			val := br[stream].peekBitsFast(s.actualTableLog)
337			v := single[val&tlMask]
338			br[stream].bitsRead += uint8(v.entry)
339
340			val2 := br[stream].peekBitsFast(s.actualTableLog)
341			v2 := single[val2&tlMask]
342			tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
343			tmp[off+bufoff*stream] = uint8(v.entry >> 8)
344			br[stream].bitsRead += uint8(v2.entry)
345		}
346
347		off += 2
348
349		if off == bufoff {
350			if bufoff > dstEvery {
351				return nil, errors.New("corruption detected: stream overrun 1")
352			}
353			copy(dstOut, tmp[:bufoff])
354			copy(dstOut[dstEvery:], tmp[bufoff:bufoff*2])
355			copy(dstOut[dstEvery*2:], tmp[bufoff*2:bufoff*3])
356			copy(dstOut[dstEvery*3:], tmp[bufoff*3:bufoff*4])
357			off = 0
358			dstOut = dstOut[bufoff:]
359			decoded += 256
360			// There must at least be 3 buffers left.
361			if len(dstOut) < dstEvery*3 {
362				return nil, errors.New("corruption detected: stream overrun 2")
363			}
364		}
365	}
366	if off > 0 {
367		ioff := int(off)
368		if len(dstOut) < dstEvery*3+ioff {
369			return nil, errors.New("corruption detected: stream overrun 3")
370		}
371		copy(dstOut, tmp[:off])
372		copy(dstOut[dstEvery:dstEvery+ioff], tmp[bufoff:bufoff*2])
373		copy(dstOut[dstEvery*2:dstEvery*2+ioff], tmp[bufoff*2:bufoff*3])
374		copy(dstOut[dstEvery*3:dstEvery*3+ioff], tmp[bufoff*3:bufoff*4])
375		decoded += int(off) * 4
376		dstOut = dstOut[off:]
377	}
378
379	// Decode remaining.
380	for i := range br {
381		offset := dstEvery * i
382		br := &br[i]
383		for !br.finished() {
384			br.fill()
385			if offset >= len(dstOut) {
386				return nil, errors.New("corruption detected: stream overrun 4")
387			}
388			dstOut[offset] = decode(br)
389			offset++
390		}
391		decoded += offset - dstEvery*i
392		err = br.close()
393		if err != nil {
394			return nil, err
395		}
396	}
397	if dstSize != decoded {
398		return nil, errors.New("corruption detected: short output block")
399	}
400	return s.Out, nil
401}
402
403// matches will compare a decoding table to a coding table.
404// Errors are written to the writer.
405// Nothing will be written if table is ok.
406func (s *Scratch) matches(ct cTable, w io.Writer) {
407	if s == nil || len(s.dt.single) == 0 {
408		return
409	}
410	dt := s.dt.single[:1<<s.actualTableLog]
411	tablelog := s.actualTableLog
412	ok := 0
413	broken := 0
414	for sym, enc := range ct {
415		errs := 0
416		broken++
417		if enc.nBits == 0 {
418			for _, dec := range dt {
419				if uint8(dec.entry>>8) == byte(sym) {
420					fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
421					errs++
422					break
423				}
424			}
425			if errs == 0 {
426				broken--
427			}
428			continue
429		}
430		// Unused bits in input
431		ub := tablelog - enc.nBits
432		top := enc.val << ub
433		// decoder looks at top bits.
434		dec := dt[top]
435		if uint8(dec.entry) != enc.nBits {
436			fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
437			errs++
438		}
439		if uint8(dec.entry>>8) != uint8(sym) {
440			fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
441			errs++
442		}
443		if errs > 0 {
444			fmt.Fprintf(w, "%d errros in base, stopping\n", errs)
445			continue
446		}
447		// Ensure that all combinations are covered.
448		for i := uint16(0); i < (1 << ub); i++ {
449			vval := top | i
450			dec := dt[vval]
451			if uint8(dec.entry) != enc.nBits {
452				fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
453				errs++
454			}
455			if uint8(dec.entry>>8) != uint8(sym) {
456				fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
457				errs++
458			}
459			if errs > 20 {
460				fmt.Fprintf(w, "%d errros, stopping\n", errs)
461				break
462			}
463		}
464		if errs == 0 {
465			ok++
466			broken--
467		}
468	}
469	if broken > 0 {
470		fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok)
471	}
472}
473