1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"errors"
9	"fmt"
10	"math"
11)
12
const (
	// For encoding we only support table logs up to maxEncTableLog, i.e.
	// tables of at most 1<<8 = 256 states (optimalTableLog clamps to it).
	maxEncTableLog    = 8
	// NOTE(review): the next two constants are derived from maxTableLog
	// (defined elsewhere), not from maxEncTableLog — confirm that is intended.
	maxEncTablesize   = 1 << maxTableLog
	maxEncTableMask   = (1 << maxTableLog) - 1
	// Smallest table log optimalTableLog will select.
	minEncTablelog    = 5
	// Presumably the largest symbol value the encoder handles — confirm.
	maxEncSymbolValue = maxMatchLengthSymbol
)
21
// fseEncoder holds the state for building and applying one FSE compression
// table: the symbol histogram, its normalized form and the derived tables.
type fseEncoder struct {
	symbolLen      uint16 // Length of active part of the symbol table.
	actualTableLog uint8  // Selected tablelog.
	ct             cTable // Compression tables.
	maxCount       int    // count of the most probable symbol
	zeroBits       bool   // set when some symbol has prob > 50%, so it may be coded with 0 bits.
	clearCount     bool   // counts should be cleared by the next prepare (done once maxCount is 0)
	useRLE         bool   // This encoder is for RLE
	preDefined     bool   // This encoder is predefined.
	reUsed         bool   // Set to know when the encoder has been reused.
	rleVal         uint8  // RLE Symbol
	maxBits        uint8  // Maximum output bits after transform.

	// TODO: Technically zstd should be fine with 64 bytes.
	count [256]uint32 // symbol histogram (exposed through Histogram)
	norm  [256]int16  // normalized counts; -1 marks a low-probability symbol
}
40
// cTable contains tables used for compression.
// tableSymbol and stateTable hold 1<<actualTableLog entries; symbolTT is
// indexed directly by symbol value and holds 256 entries (see allocCtable).
type cTable struct {
	tableSymbol []byte            // symbol placed at each table position by the spread
	stateTable  []uint16          // next-state values, grouped by symbol
	symbolTT    []symbolTransform // per-symbol state transforms
}
47
// symbolTransform contains the state transform for a symbol.
type symbolTransform struct {
	deltaNbBits    uint32 // added to the state; bits 16+ of the sum give the number of bits to output
	deltaFindState int16  // offset added to the shifted state to index stateTable
	outBits        uint8  // output bit count set by setBits (from the transform, or the symbol index)
}
54
55// String prints values as a human readable string.
56func (s symbolTransform) String() string {
57	return fmt.Sprintf("{deltabits: %08x, findstate:%d outbits:%d}", s.deltaNbBits, s.deltaFindState, s.outBits)
58}
59
60// Histogram allows to populate the histogram and skip that step in the compression,
61// It otherwise allows to inspect the histogram when compression is done.
62// To indicate that you have populated the histogram call HistogramFinished
63// with the value of the highest populated symbol, as well as the number of entries
64// in the most populated entry. These are accepted at face value.
65// The returned slice will always be length 256.
66func (s *fseEncoder) Histogram() []uint32 {
67	return s.count[:]
68}
69
70// HistogramFinished can be called to indicate that the histogram has been populated.
71// maxSymbol is the index of the highest set symbol of the next data segment.
72// maxCount is the number of entries in the most populated entry.
73// These are accepted at face value.
74func (s *fseEncoder) HistogramFinished(maxSymbol uint8, maxCount int) {
75	s.maxCount = maxCount
76	s.symbolLen = uint16(maxSymbol) + 1
77	s.clearCount = maxCount != 0
78}
79
80// prepare will prepare and allocate scratch tables used for both compression and decompression.
81func (s *fseEncoder) prepare() (*fseEncoder, error) {
82	if s == nil {
83		s = &fseEncoder{}
84	}
85	s.useRLE = false
86	if s.clearCount && s.maxCount == 0 {
87		for i := range s.count {
88			s.count[i] = 0
89		}
90		s.clearCount = false
91	}
92	return s, nil
93}
94
95// allocCtable will allocate tables needed for compression.
96// If existing tables a re big enough, they are simply re-used.
97func (s *fseEncoder) allocCtable() {
98	tableSize := 1 << s.actualTableLog
99	// get tableSymbol that is big enough.
100	if cap(s.ct.tableSymbol) < int(tableSize) {
101		s.ct.tableSymbol = make([]byte, tableSize)
102	}
103	s.ct.tableSymbol = s.ct.tableSymbol[:tableSize]
104
105	ctSize := tableSize
106	if cap(s.ct.stateTable) < ctSize {
107		s.ct.stateTable = make([]uint16, ctSize)
108	}
109	s.ct.stateTable = s.ct.stateTable[:ctSize]
110
111	if cap(s.ct.symbolTT) < 256 {
112		s.ct.symbolTT = make([]symbolTransform, 256)
113	}
114	s.ct.symbolTT = s.ct.symbolTT[:256]
115}
116
117// buildCTable will populate the compression table so it is ready to be used.
118func (s *fseEncoder) buildCTable() error {
119	tableSize := uint32(1 << s.actualTableLog)
120	highThreshold := tableSize - 1
121	var cumul [256]int16
122
123	s.allocCtable()
124	tableSymbol := s.ct.tableSymbol[:tableSize]
125	// symbol start positions
126	{
127		cumul[0] = 0
128		for ui, v := range s.norm[:s.symbolLen-1] {
129			u := byte(ui) // one less than reference
130			if v == -1 {
131				// Low proba symbol
132				cumul[u+1] = cumul[u] + 1
133				tableSymbol[highThreshold] = u
134				highThreshold--
135			} else {
136				cumul[u+1] = cumul[u] + v
137			}
138		}
139		// Encode last symbol separately to avoid overflowing u
140		u := int(s.symbolLen - 1)
141		v := s.norm[s.symbolLen-1]
142		if v == -1 {
143			// Low proba symbol
144			cumul[u+1] = cumul[u] + 1
145			tableSymbol[highThreshold] = byte(u)
146			highThreshold--
147		} else {
148			cumul[u+1] = cumul[u] + v
149		}
150		if uint32(cumul[s.symbolLen]) != tableSize {
151			return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize)
152		}
153		cumul[s.symbolLen] = int16(tableSize) + 1
154	}
155	// Spread symbols
156	s.zeroBits = false
157	{
158		step := tableStep(tableSize)
159		tableMask := tableSize - 1
160		var position uint32
161		// if any symbol > largeLimit, we may have 0 bits output.
162		largeLimit := int16(1 << (s.actualTableLog - 1))
163		for ui, v := range s.norm[:s.symbolLen] {
164			symbol := byte(ui)
165			if v > largeLimit {
166				s.zeroBits = true
167			}
168			for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ {
169				tableSymbol[position] = symbol
170				position = (position + step) & tableMask
171				for position > highThreshold {
172					position = (position + step) & tableMask
173				} /* Low proba area */
174			}
175		}
176
177		// Check if we have gone through all positions
178		if position != 0 {
179			return errors.New("position!=0")
180		}
181	}
182
183	// Build table
184	table := s.ct.stateTable
185	{
186		tsi := int(tableSize)
187		for u, v := range tableSymbol {
188			// TableU16 : sorted by symbol order; gives next state value
189			table[cumul[v]] = uint16(tsi + u)
190			cumul[v]++
191		}
192	}
193
194	// Build Symbol Transformation Table
195	{
196		total := int16(0)
197		symbolTT := s.ct.symbolTT[:s.symbolLen]
198		tableLog := s.actualTableLog
199		tl := (uint32(tableLog) << 16) - (1 << tableLog)
200		for i, v := range s.norm[:s.symbolLen] {
201			switch v {
202			case 0:
203			case -1, 1:
204				symbolTT[i].deltaNbBits = tl
205				symbolTT[i].deltaFindState = int16(total - 1)
206				total++
207			default:
208				maxBitsOut := uint32(tableLog) - highBit(uint32(v-1))
209				minStatePlus := uint32(v) << maxBitsOut
210				symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus
211				symbolTT[i].deltaFindState = int16(total - v)
212				total += v
213			}
214		}
215		if total != int16(tableSize) {
216			return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize)
217		}
218	}
219	return nil
220}
221
// rtbTable holds the "rest to beat" thresholds used by normalizeCount: when a
// symbol's scaled probability is below 8, its fractional remainder must exceed
// rtbTable[proba] (scaled by vStep) for the probability to be rounded up.
var rtbTable = [8]uint32{
	0, 473195, 504333, 520860,
	550000, 700000, 750000, 830000,
}
223
224func (s *fseEncoder) setRLE(val byte) {
225	s.allocCtable()
226	s.actualTableLog = 0
227	s.ct.stateTable = s.ct.stateTable[:1]
228	s.ct.symbolTT[val] = symbolTransform{
229		deltaFindState: 0,
230		deltaNbBits:    0,
231	}
232	if debug {
233		println("setRLE: val", val, "symbolTT", s.ct.symbolTT[val])
234	}
235	s.rleVal = val
236	s.useRLE = true
237}
238
239// setBits will set output bits for the transform.
240// if nil is provided, the number of bits is equal to the index.
241func (s *fseEncoder) setBits(transform []byte) {
242	if s.reUsed || s.preDefined {
243		return
244	}
245	if s.useRLE {
246		if transform == nil {
247			s.ct.symbolTT[s.rleVal].outBits = s.rleVal
248			s.maxBits = s.rleVal
249			return
250		}
251		s.maxBits = transform[s.rleVal]
252		s.ct.symbolTT[s.rleVal].outBits = s.maxBits
253		return
254	}
255	if transform == nil {
256		for i := range s.ct.symbolTT[:s.symbolLen] {
257			s.ct.symbolTT[i].outBits = uint8(i)
258		}
259		s.maxBits = uint8(s.symbolLen - 1)
260		return
261	}
262	s.maxBits = 0
263	for i, v := range transform[:s.symbolLen] {
264		s.ct.symbolTT[i].outBits = v
265		if v > s.maxBits {
266			// We could assume bits always going up, but we play safe.
267			s.maxBits = v
268		}
269	}
270}
271
// normalizeCount will normalize the count of the symbols so
// the total is equal to the table size.
// If successful, compression tables will also be made ready.
//
// length is used as a divisor when computing the probability scale, so it
// must be non-zero; it is presumably the sum of s.count (confirm with callers).
// Encoders marked reUsed are left untouched. When s.maxCount == length the
// encoder is switched to RLE mode and no table is built here. Otherwise s.norm
// is filled for symbols [0, s.symbolLen) and buildCTable is called. If simple
// rounding leaves too large a deficit for the most probable symbol,
// normalizeCount2 is used instead.
func (s *fseEncoder) normalizeCount(length int) error {
	if s.reUsed {
		return nil
	}
	s.optimalTableLog(length)
	// Fixed-point setup: (cnt*step)>>scale is cnt's share of the 1<<tableLog
	// table slots. Counts at or below lowThreshold get the special -1 weight.
	var (
		tableLog          = s.actualTableLog
		scale             = 62 - uint64(tableLog)
		step              = (1 << 62) / uint64(length)
		vStep             = uint64(1) << (scale - 20)
		stillToDistribute = int16(1 << tableLog)
		largest           int
		largestP          int16
		lowThreshold      = (uint32)(length >> tableLog)
	)
	if s.maxCount == length {
		// A single symbol accounts for every occurrence (assuming length is
		// the histogram total), so use RLE instead of a table.
		s.useRLE = true
		return nil
	}
	s.useRLE = false
	for i, cnt := range s.count[:s.symbolLen] {
		// already handled
		// if (count[s] == s.length) return 0;   /* rle special case */

		if cnt == 0 {
			s.norm[i] = 0
			continue
		}
		if cnt <= lowThreshold {
			// Too rare for a full share: mark as "less than one" (-1); it
			// still consumes one table slot.
			s.norm[i] = -1
			stillToDistribute--
		} else {
			proba := (int16)((uint64(cnt) * step) >> scale)
			if proba < 8 {
				// Round small probabilities up when the fractional remainder
				// exceeds the rtbTable threshold for this value.
				restToBeat := vStep * uint64(rtbTable[proba])
				v := uint64(cnt)*step - (uint64(proba) << scale)
				if v > restToBeat {
					proba++
				}
			}
			if proba > largestP {
				largestP = proba
				largest = i
			}
			s.norm[i] = proba
			stillToDistribute -= proba
		}
	}

	// stillToDistribute is the rounding surplus (or deficit, when negative).
	// It goes to the largest symbol unless the deficit would remove at least
	// half of that symbol's weight; in that case use the secondary method.
	if -stillToDistribute >= (s.norm[largest] >> 1) {
		// corner case, need another normalization method
		err := s.normalizeCount2(length)
		if err != nil {
			return err
		}
		if debug {
			err = s.validateNorm()
			if err != nil {
				return err
			}
		}
		return s.buildCTable()
	}
	s.norm[largest] += stillToDistribute
	if debug {
		err := s.validateNorm()
		if err != nil {
			return err
		}
	}
	return s.buildCTable()
}
347
// Secondary normalization method.
// To be used when primary method fails.
//
// Symbols with count <= total>>tableLog get the "less than one" weight -1.
// Symbols with count <= (3*total)>>(tableLog+1) get weight 1. The remaining
// table slots are shared among the other symbols in proportion to their
// counts, using fixed-point arithmetic. Returns an error if a proportionally
// assigned symbol would receive a weight below 1.
func (s *fseEncoder) normalizeCount2(length int) error {
	const notYetAssigned = -2
	var (
		distributed  uint32
		total        = uint32(length)
		tableLog     = s.actualTableLog
		lowThreshold = uint32(total >> tableLog)
		lowOne       = uint32((total * 3) >> (tableLog + 1))
	)
	for i, cnt := range s.count[:s.symbolLen] {
		if cnt == 0 {
			s.norm[i] = 0
			continue
		}
		if cnt <= lowThreshold {
			s.norm[i] = -1
			distributed++
			total -= cnt
			continue
		}
		if cnt <= lowOne {
			s.norm[i] = 1
			distributed++
			total -= cnt
			continue
		}
		s.norm[i] = notYetAssigned
	}
	toDistribute := (1 << tableLog) - distributed

	if (total / toDistribute) > lowOne {
		// risk of rounding to zero
		// Raise the weight-1 cutoff to 1.5x the average count per remaining slot.
		lowOne = uint32((total * 3) / (toDistribute * 2))
		for i, cnt := range s.count[:s.symbolLen] {
			if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) {
				s.norm[i] = 1
				distributed++
				total -= cnt
				continue
			}
		}
		toDistribute = (1 << tableLog) - distributed
	}
	// NOTE(review): distributed is incremented at most once per symbol in
	// [0, s.symbolLen), so it can never equal s.symbolLen+1 and this branch is
	// unreachable. The zstd C reference appears to compare against
	// maxSymbolValue+1 (== s.symbolLen here) — verify. When every symbol was
	// distributed, total is 0 (assuming length is the sum of s.count), so the
	// round-robin branch below runs instead. Changing this alters the emitted
	// tables, so confirm before fixing.
	if distributed == uint32(s.symbolLen)+1 {
		// all values are pretty poor;
		//   probably incompressible data (should have already been detected);
		//   find max, then give all remaining points to max
		var maxV int
		var maxC uint32
		for i, cnt := range s.count[:s.symbolLen] {
			if cnt > maxC {
				maxV = i
				maxC = cnt
			}
		}
		s.norm[maxV] += int16(toDistribute)
		return nil
	}

	if total == 0 {
		// all of the symbols were low enough for the lowOne or lowThreshold
		// Hand out the remaining slots round-robin to symbols with a positive weight.
		for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) {
			if s.norm[i] > 0 {
				toDistribute--
				s.norm[i]++
			}
		}
		return nil
	}

	// Share the remaining slots proportionally among unassigned symbols.
	// tmpTotal starts at mid (about half a slot) so boundaries round to nearest.
	var (
		vStepLog = 62 - uint64(tableLog)
		mid      = uint64((1 << (vStepLog - 1)) - 1)
		rStep    = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total) // scale on remaining
		tmpTotal = mid
	)
	for i, cnt := range s.count[:s.symbolLen] {
		if s.norm[i] == notYetAssigned {
			var (
				end    = tmpTotal + uint64(cnt)*rStep
				sStart = uint32(tmpTotal >> vStepLog)
				sEnd   = uint32(end >> vStepLog)
				weight = sEnd - sStart
			)
			if weight < 1 {
				return errors.New("weight < 1")
			}
			s.norm[i] = int16(weight)
			tmpTotal = end
		}
	}
	return nil
}
443
444// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
445func (s *fseEncoder) optimalTableLog(length int) {
446	tableLog := uint8(maxEncTableLog)
447	minBitsSrc := highBit(uint32(length)) + 1
448	minBitsSymbols := highBit(uint32(s.symbolLen-1)) + 2
449	minBits := uint8(minBitsSymbols)
450	if minBitsSrc < minBitsSymbols {
451		minBits = uint8(minBitsSrc)
452	}
453
454	maxBitsSrc := uint8(highBit(uint32(length-1))) - 2
455	if maxBitsSrc < tableLog {
456		// Accuracy can be reduced
457		tableLog = maxBitsSrc
458	}
459	if minBits > tableLog {
460		tableLog = minBits
461	}
462	// Need a minimum to safely represent all symbol values
463	if tableLog < minEncTablelog {
464		tableLog = minEncTablelog
465	}
466	if tableLog > maxEncTableLog {
467		tableLog = maxEncTableLog
468	}
469	s.actualTableLog = tableLog
470}
471
472// validateNorm validates the normalized histogram table.
473func (s *fseEncoder) validateNorm() (err error) {
474	var total int
475	for _, v := range s.norm[:s.symbolLen] {
476		if v >= 0 {
477			total += int(v)
478		} else {
479			total -= int(v)
480		}
481	}
482	defer func() {
483		if err == nil {
484			return
485		}
486		fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
487		for i, v := range s.norm[:s.symbolLen] {
488			fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
489		}
490	}()
491	if total != (1 << s.actualTableLog) {
492		return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
493	}
494	for i, v := range s.count[s.symbolLen:] {
495		if v != 0 {
496			return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)
497		}
498	}
499	return nil
500}
501
502// writeCount will write the normalized histogram count to header.
503// This is read back by readNCount.
504func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
505	var (
506		tableLog  = s.actualTableLog
507		tableSize = 1 << tableLog
508		previous0 bool
509		charnum   uint16
510
511		maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3
512
513		// Write Table Size
514		bitStream = uint32(tableLog - minEncTablelog)
515		bitCount  = uint(4)
516		remaining = int16(tableSize + 1) /* +1 for extra accuracy */
517		threshold = int16(tableSize)
518		nbBits    = uint(tableLog + 1)
519	)
520	if s.useRLE {
521		return append(out, s.rleVal), nil
522	}
523	if s.preDefined || s.reUsed {
524		// Never write predefined.
525		return out, nil
526	}
527	outP := len(out)
528	out = out[:outP+maxHeaderSize]
529
530	// stops at 1
531	for remaining > 1 {
532		if previous0 {
533			start := charnum
534			for s.norm[charnum] == 0 {
535				charnum++
536			}
537			for charnum >= start+24 {
538				start += 24
539				bitStream += uint32(0xFFFF) << bitCount
540				out[outP] = byte(bitStream)
541				out[outP+1] = byte(bitStream >> 8)
542				outP += 2
543				bitStream >>= 16
544			}
545			for charnum >= start+3 {
546				start += 3
547				bitStream += 3 << bitCount
548				bitCount += 2
549			}
550			bitStream += uint32(charnum-start) << bitCount
551			bitCount += 2
552			if bitCount > 16 {
553				out[outP] = byte(bitStream)
554				out[outP+1] = byte(bitStream >> 8)
555				outP += 2
556				bitStream >>= 16
557				bitCount -= 16
558			}
559		}
560
561		count := s.norm[charnum]
562		charnum++
563		max := (2*threshold - 1) - remaining
564		if count < 0 {
565			remaining += count
566		} else {
567			remaining -= count
568		}
569		count++ // +1 for extra accuracy
570		if count >= threshold {
571			count += max // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
572		}
573		bitStream += uint32(count) << bitCount
574		bitCount += nbBits
575		if count < max {
576			bitCount--
577		}
578
579		previous0 = count == 1
580		if remaining < 1 {
581			return nil, errors.New("internal error: remaining < 1")
582		}
583		for remaining < threshold {
584			nbBits--
585			threshold >>= 1
586		}
587
588		if bitCount > 16 {
589			out[outP] = byte(bitStream)
590			out[outP+1] = byte(bitStream >> 8)
591			outP += 2
592			bitStream >>= 16
593			bitCount -= 16
594		}
595	}
596
597	out[outP] = byte(bitStream)
598	out[outP+1] = byte(bitStream >> 8)
599	outP += int((bitCount + 7) / 8)
600
601	if uint16(charnum) > s.symbolLen {
602		return nil, errors.New("internal error: charnum > s.symbolLen")
603	}
604	return out[:outP], nil
605}
606
607// Approximate symbol cost, as fractional value, using fixed-point format (accuracyLog fractional bits)
608// note 1 : assume symbolValue is valid (<= maxSymbolValue)
609// note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits *
610func (s *fseEncoder) bitCost(symbolValue uint8, accuracyLog uint32) uint32 {
611	minNbBits := s.ct.symbolTT[symbolValue].deltaNbBits >> 16
612	threshold := (minNbBits + 1) << 16
613	if debug {
614		if !(s.actualTableLog < 16) {
615			panic("!s.actualTableLog < 16")
616		}
617		// ensure enough room for renormalization double shift
618		if !(uint8(accuracyLog) < 31-s.actualTableLog) {
619			panic("!uint8(accuracyLog) < 31-s.actualTableLog")
620		}
621	}
622	tableSize := uint32(1) << s.actualTableLog
623	deltaFromThreshold := threshold - (s.ct.symbolTT[symbolValue].deltaNbBits + tableSize)
624	// linear interpolation (very approximate)
625	normalizedDeltaFromThreshold := (deltaFromThreshold << accuracyLog) >> s.actualTableLog
626	bitMultiplier := uint32(1) << accuracyLog
627	if debug {
628		if s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold {
629			panic("s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold")
630		}
631		if normalizedDeltaFromThreshold > bitMultiplier {
632			panic("normalizedDeltaFromThreshold > bitMultiplier")
633		}
634	}
635	return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold
636}
637
638// Returns the cost in bits of encoding the distribution in count using ctable.
639// Histogram should only be up to the last non-zero symbol.
640// Returns an -1 if ctable cannot represent all the symbols in count.
641func (s *fseEncoder) approxSize(hist []uint32) uint32 {
642	if int(s.symbolLen) < len(hist) {
643		// More symbols than we have.
644		return math.MaxUint32
645	}
646	if s.useRLE {
647		// We will never reuse RLE encoders.
648		return math.MaxUint32
649	}
650	const kAccuracyLog = 8
651	badCost := (uint32(s.actualTableLog) + 1) << kAccuracyLog
652	var cost uint32
653	for i, v := range hist {
654		if v == 0 {
655			continue
656		}
657		if s.norm[i] == 0 {
658			return math.MaxUint32
659		}
660		bitCost := s.bitCost(uint8(i), kAccuracyLog)
661		if bitCost > badCost {
662			return math.MaxUint32
663		}
664		cost += v * bitCost
665	}
666	return cost >> kAccuracyLog
667}
668
669// maxHeaderSize returns the maximum header size in bits.
670// This is not exact size, but we want a penalty for new tables anyway.
671func (s *fseEncoder) maxHeaderSize() uint32 {
672	if s.preDefined {
673		return 0
674	}
675	if s.useRLE {
676		return 8
677	}
678	return (((uint32(s.symbolLen) * uint32(s.actualTableLog)) >> 3) + 3) * 8
679}
680
// cState contains the compression state of a stream.
type cState struct {
	bw         *bitWriter // destination for the state bits emitted by encode/flush
	stateTable []uint16   // next-state table, shared with the cTable passed to init
	state      uint16     // current encoder state
}
687
688// init will initialize the compression state to the first symbol of the stream.
689func (c *cState) init(bw *bitWriter, ct *cTable, first symbolTransform) {
690	c.bw = bw
691	c.stateTable = ct.stateTable
692	if len(c.stateTable) == 1 {
693		// RLE
694		c.stateTable[0] = uint16(0)
695		c.state = 0
696		return
697	}
698	nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16
699	im := int32((nbBitsOut << 16) - first.deltaNbBits)
700	lu := (im >> nbBitsOut) + int32(first.deltaFindState)
701	c.state = c.stateTable[lu]
702	return
703}
704
705// encode the output symbol provided and write it to the bitstream.
706func (c *cState) encode(symbolTT symbolTransform) {
707	nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16
708	dstState := int32(c.state>>(nbBitsOut&15)) + int32(symbolTT.deltaFindState)
709	c.bw.addBits16NC(c.state, uint8(nbBitsOut))
710	c.state = c.stateTable[dstState]
711}
712
713// flush will write the tablelog to the output and flush the remaining full bytes.
714func (c *cState) flush(tableLog uint8) {
715	c.bw.flush32()
716	c.bw.addBits16NC(c.state, tableLog)
717}
718