1// Package huff0 provides fast huffman encoding as used in zstd.
2//
3// See README.md at https://github.com/klauspost/compress/tree/master/huff0 for details.
4package huff0
5
6import (
7	"errors"
8	"fmt"
9	"math"
10	"math/bits"
11
12	"github.com/klauspost/compress/fse"
13)
14
const (
	// maxSymbolValue is the highest symbol value that can occur;
	// input symbols are bytes, so 0-255.
	maxSymbolValue = 255

	// zstandard limits tablelog to 11, see:
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#huffman-tree-description
	tableLogMax     = 11
	tableLogDefault = 11
	minTablelog     = 5

	// huffNodesLen bounds the scratch node buffer used while building
	// the Huffman tree (allocated as huffNodesLen+1 in prepare).
	huffNodesLen = 512

	// BlockSizeMax is maximum input size for a single block uncompressed.
	BlockSizeMax = 128 << 10
)
28
var (
	// ErrIncompressible is returned when input is judged to be too hard to compress.
	ErrIncompressible = errors.New("input is not compressible")

	// ErrUseRLE is returned from the compressor when the input is a single byte value repeated.
	ErrUseRLE = errors.New("input is single value repeated")

	// ErrTooBig is returned if input is too large for a single block.
	ErrTooBig = errors.New("input too big")
)
39
// ReusePolicy controls whether the compressor may reuse the table from a
// previous block instead of generating a new one.
type ReusePolicy uint8

const (
	// ReusePolicyAllow will allow reuse if it produces smaller output.
	ReusePolicyAllow ReusePolicy = iota

	// ReusePolicyPrefer will re-use aggressively if possible.
	// This will not check if a new table will produce smaller output,
	// except if the current table is impossible to use or
	// compressed output is bigger than input.
	ReusePolicyPrefer

	// ReusePolicyNone will disable re-use of tables.
	// This is slightly faster than ReusePolicyAllow but may produce larger output.
	ReusePolicyNone
)
56
// Scratch provides temporary storage for compression and decompression.
// It is intended to be reused across blocks so buffers and tables can be
// recycled; the exported fields allow per-block parameter overrides.
type Scratch struct {
	// count is the per-symbol histogram of the current input.
	count [maxSymbolValue + 1]uint32

	// Per block parameters.
	// These can be used to override compression parameters of the block.
	// Do not touch, unless you know what you are doing.

	// Out is output buffer.
	// If the scratch is re-used before the caller is done processing the output,
	// set this field to nil.
	// Otherwise the output buffer will be re-used for next Compression/Decompression step
	// and allocation will be avoided.
	Out []byte

	// OutTable will contain the table data only, if a new table has been generated.
	// Slice of the returned data.
	OutTable []byte

	// OutData will contain the compressed data.
	// Slice of the returned data.
	OutData []byte

	// MaxSymbolValue will override the maximum symbol value of the next block.
	MaxSymbolValue uint8

	// TableLog will attempt to override the tablelog for the next block.
	// Must be <= 11.
	TableLog uint8

	// Reuse will specify the reuse policy
	Reuse ReusePolicy

	br             byteReader
	symbolLen      uint16 // Length of active part of the symbol table.
	maxCount       int    // count of the most probable symbol
	clearCount     bool   // clear count
	actualTableLog uint8  // Selected tablelog.
	prevTable      cTable // Table used for previous compression.
	cTable         cTable // compression table
	dt             dTable // decompression table
	nodes          []nodeElt // scratch buffer for Huffman tree construction (see prepare).
	tmpOut         [4][]byte // per-stream output buffers; presumably for 4-stream coding — usage not visible in this chunk.
	fse            *fse.Scratch // reusable FSE scratch for table (de)serialization.
	huffWeight     [maxSymbolValue + 1]byte // per-symbol weights emitted by cTable.write.
}
102
103func (s *Scratch) prepare(in []byte) (*Scratch, error) {
104	if len(in) > BlockSizeMax {
105		return nil, ErrTooBig
106	}
107	if s == nil {
108		s = &Scratch{}
109	}
110	if s.MaxSymbolValue == 0 {
111		s.MaxSymbolValue = maxSymbolValue
112	}
113	if s.TableLog == 0 {
114		s.TableLog = tableLogDefault
115	}
116	if s.TableLog > tableLogMax {
117		return nil, fmt.Errorf("tableLog (%d) > maxTableLog (%d)", s.TableLog, tableLogMax)
118	}
119	if s.clearCount && s.maxCount == 0 {
120		for i := range s.count {
121			s.count[i] = 0
122		}
123		s.clearCount = false
124	}
125	if cap(s.Out) == 0 {
126		s.Out = make([]byte, 0, len(in))
127	}
128	s.Out = s.Out[:0]
129
130	s.OutTable = nil
131	s.OutData = nil
132	if cap(s.nodes) < huffNodesLen+1 {
133		s.nodes = make([]nodeElt, 0, huffNodesLen+1)
134	}
135	s.nodes = s.nodes[:0]
136	if s.fse == nil {
137		s.fse = &fse.Scratch{}
138	}
139	s.br.init(in)
140
141	return s, nil
142}
143
// cTable is a Huffman compression table, indexed by symbol value.
// Each entry carries at least the code length nBits for its symbol
// (see usage in write and estimateSize).
type cTable []cTableEntry
145
// write appends a serialized description of the table c to s.Out, in the
// zstd Huffman tree format (see the spec link at the top of the file):
// code lengths are converted to weights, then either FSE-compressed
// (header byte = compressed length, < 128) or, when FSE does not pay off,
// packed raw as 4-bit weights (header byte = 128 | (count-1)).
func (c cTable) write(s *Scratch) error {
	var (
		// precomputed conversion table
		bitsToWeight [tableLogMax + 1]byte
		huffLog      = s.actualTableLog
		// last weight is not saved.
		maxSymbolValue = uint8(s.symbolLen - 1)
		huffWeight     = s.huffWeight[:256]
	)
	const (
		maxFSETableLog = 6
	)
	// convert to weight
	// weight(nBits) = huffLog + 1 - nBits; weight 0 is kept for unused symbols.
	bitsToWeight[0] = 0
	for n := uint8(1); n < huffLog+1; n++ {
		bitsToWeight[n] = huffLog + 1 - n
	}

	// Acquire histogram for FSE.
	// Only the first 16 buckets are used, since weights fit in 4 bits.
	hist := s.fse.Histogram()
	hist = hist[:256]
	for i := range hist[:16] {
		hist[i] = 0
	}
	for n := uint8(0); n < maxSymbolValue; n++ {
		v := bitsToWeight[c[n].nBits] & 15
		huffWeight[n] = v
		hist[v]++
	}

	// FSE compress if feasible.
	if maxSymbolValue >= 2 {
		huffMaxCnt := uint32(0)
		huffMax := uint8(0)
		// Find the highest used weight and the largest bucket count
		// to seed the FSE histogram state.
		for i, v := range hist[:16] {
			if v == 0 {
				continue
			}
			huffMax = byte(i)
			if v > huffMaxCnt {
				huffMaxCnt = v
			}
		}
		s.fse.HistogramFinished(huffMax, int(huffMaxCnt))
		s.fse.TableLog = maxFSETableLog
		b, err := fse.Compress(huffWeight[:maxSymbolValue], s.fse)
		// Use the FSE-compressed form only if it beats the 4-bit raw
		// packing (which costs symbolLen/2 bytes).
		if err == nil && len(b) < int(s.symbolLen>>1) {
			s.Out = append(s.Out, uint8(len(b)))
			s.Out = append(s.Out, b...)
			return nil
		}
	}
	// write raw values as 4-bits (max : 15)
	if maxSymbolValue > (256 - 128) {
		// should not happen : likely means source cannot be compressed
		return ErrIncompressible
	}
	op := s.Out
	// special case, pack weights 4 bits/weight.
	// Header byte >= 128 signals raw packing; low bits encode count-1.
	op = append(op, 128|(maxSymbolValue-1))
	// be sure it doesn't cause msan issue in final combination
	huffWeight[maxSymbolValue] = 0
	for n := uint16(0); n < uint16(maxSymbolValue); n += 2 {
		op = append(op, (huffWeight[n]<<4)|huffWeight[n+1])
	}
	s.Out = op
	return nil
}
214
215// estimateSize returns the estimated size in bytes of the input represented in the
216// histogram supplied.
217func (c cTable) estimateSize(hist []uint32) int {
218	nbBits := uint32(7)
219	for i, v := range c[:len(hist)] {
220		nbBits += uint32(v.nBits) * hist[i]
221	}
222	return int(nbBits >> 3)
223}
224
225// minSize returns the minimum possible size considering the shannon limit.
226func (s *Scratch) minSize(total int) int {
227	nbBits := float64(7)
228	fTotal := float64(total)
229	for _, v := range s.count[:s.symbolLen] {
230		n := float64(v)
231		if n > 0 {
232			nbBits += math.Log2(fTotal/n) * n
233		}
234	}
235	return int(nbBits) >> 3
236}
237
// highBit32 returns the index of the highest set bit in val,
// i.e. floor(log2(val)). val must be > 0; for val == 0 the
// subtraction wraps and the result is math.MaxUint32.
func highBit32(val uint32) uint32 {
	return uint32(bits.Len32(val) - 1)
}
241