1package huff0
2
3import (
4	"fmt"
5	"runtime"
6	"sync"
7)
8
9// Compress1X will compress the input.
10// The output can be decoded using Decompress1X.
11// Supply a Scratch object. The scratch object contains state about re-use,
12// So when sharing across independent encodes, be sure to set the re-use policy.
13func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
14	s, err = s.prepare(in)
15	if err != nil {
16		return nil, false, err
17	}
18	return compress(in, s, s.compress1X)
19}
20
21// Compress4X will compress the input. The input is split into 4 independent blocks
22// and compressed similar to Compress1X.
23// The output can be decoded using Decompress4X.
24// Supply a Scratch object. The scratch object contains state about re-use,
25// So when sharing across independent encodes, be sure to set the re-use policy.
26func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
27	s, err = s.prepare(in)
28	if err != nil {
29		return nil, false, err
30	}
31	if false {
32		// TODO: compress4Xp only slightly faster.
33		const parallelThreshold = 8 << 10
34		if len(in) < parallelThreshold || runtime.GOMAXPROCS(0) == 1 {
35			return compress(in, s, s.compress4X)
36		}
37		return compress(in, s, s.compress4Xp)
38	}
39	return compress(in, s, s.compress4X)
40}
41
// compress is the shared driver behind Compress1X and Compress4X.
// It manages the histogram, applies the configured reuse policy, builds a new
// table when needed, and delegates the actual entropy coding to compressor.
// It returns the compressed output, whether the previous table was reused,
// and an error; ErrIncompressible and ErrUseRLE tell the caller to store the
// block differently instead.
func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
	// Nuke previous table if we cannot reuse anyway.
	if s.Reuse == ReusePolicyNone {
		s.prevTable = s.prevTable[:0]
	}

	// Create histogram, if none was provided.
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		canReuse = s.canUseTable(s.prevTable)
	}

	// We want the output size to be less than this:
	wantSize := len(in)
	if s.WantLogLess > 0 {
		// Require output at least len(in)>>WantLogLess bytes smaller than input.
		wantSize -= wantSize >> s.WantLogLess
	}

	// Reset for next run.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		if maxCount > len(in) {
			// A single symbol count cannot exceed the input length.
			return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return nil, false, ErrIncompressible
		}
		// One symbol, use RLE
		return nil, false, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// Each symbol present maximum once or too well distributed.
		return nil, false, ErrIncompressible
	}
	if s.Reuse == ReusePolicyMust && !canReuse {
		// We must reuse, but we can't.
		return nil, false, ErrIncompressible
	}
	if (s.Reuse == ReusePolicyPrefer || s.Reuse == ReusePolicyMust) && canReuse {
		// Try the previous table first by temporarily swapping it in,
		// restoring the current table afterwards.
		keepTable := s.cTable
		keepTL := s.actualTableLog
		s.cTable = s.prevTable
		s.actualTableLog = s.prevTableLog
		s.Out, err = compressor(in)
		s.cTable = keepTable
		s.actualTableLog = keepTL
		if err == nil && len(s.Out) < wantSize {
			s.OutData = s.Out
			return s.Out, true, nil
		}
		if s.Reuse == ReusePolicyMust {
			return nil, false, ErrIncompressible
		}
		// Do not attempt to re-use later.
		s.prevTable = s.prevTable[:0]
	}

	// Calculate new table.
	err = s.buildCTable()
	if err != nil {
		return nil, false, err
	}

	// Sanity check of the generated table; disabled in normal builds.
	if false && !s.canUseTable(s.cTable) {
		panic("invalid table generated")
	}

	if s.Reuse == ReusePolicyAllow && canReuse {
		// Compare estimated encoded sizes with the old and the new table;
		// reuse the old table when it is not expected to be worse than
		// new table + table header.
		hSize := len(s.Out)
		oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
		newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
		if oldSize <= hSize+newSize || hSize+12 >= wantSize {
			// Retain cTable even if we re-use.
			keepTable := s.cTable
			keepTL := s.actualTableLog

			s.cTable = s.prevTable
			s.actualTableLog = s.prevTableLog
			s.Out, err = compressor(in)

			// Restore ctable.
			s.cTable = keepTable
			s.actualTableLog = keepTL
			if err != nil {
				return nil, false, err
			}
			if len(s.Out) >= wantSize {
				return nil, false, ErrIncompressible
			}
			s.OutData = s.Out
			return s.Out, true, nil
		}
	}

	// Use new table
	err = s.cTable.write(s)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	s.OutTable = s.Out

	// Compress using new table
	s.Out, err = compressor(in)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	if len(s.Out) >= wantSize {
		s.OutTable = nil
		return nil, false, ErrIncompressible
	}
	// Move current table into previous.
	s.prevTable, s.prevTableLog, s.cTable = s.cTable, s.actualTableLog, s.prevTable[:0]
	// The payload starts right after the serialized table.
	s.OutData = s.Out[len(s.OutTable):]
	return s.Out, false, nil
}
163
164func (s *Scratch) compress1X(src []byte) ([]byte, error) {
165	return s.compress1xDo(s.Out, src)
166}
167
// compress1xDo encodes src as one Huffman bit stream using s.cTable,
// appending the encoded bytes to dst.
// Symbols are emitted starting from the end of src: the 1-3 trailing bytes
// beyond a multiple of 4 are encoded first, then full groups of 4 are
// encoded back-to-front.
func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) {
	var bw = bitWriter{out: dst}

	// N is length divisible by 4.
	n := len(src)
	n -= n & 3
	cTable := s.cTable[:256]

	// Encode last bytes.
	for i := len(src) & 3; i > 0; i-- {
		bw.encSymbol(cTable, src[n+i-1])
	}
	n -= 4
	if s.actualTableLog <= 8 {
		// Codes are at most 8 bits each, so four symbols (<=32 bits)
		// fit between 32-bit flushes.
		for ; n >= 0; n -= 4 {
			tmp := src[n : n+4]
			// tmp should be len 4
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[3], tmp[2])
			bw.encTwoSymbols(cTable, tmp[1], tmp[0])
		}
	} else {
		// Codes may be longer than 8 bits; flush after every pair of
		// symbols to keep the bit buffer from overflowing.
		for ; n >= 0; n -= 4 {
			tmp := src[n : n+4]
			// tmp should be len 4
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[3], tmp[2])
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[1], tmp[0])
		}
	}
	err := bw.close()
	return bw.out, err
}
202
203var sixZeros [6]byte
204
205func (s *Scratch) compress4X(src []byte) ([]byte, error) {
206	if len(src) < 12 {
207		return nil, ErrIncompressible
208	}
209	segmentSize := (len(src) + 3) / 4
210
211	// Add placeholder for output length
212	offsetIdx := len(s.Out)
213	s.Out = append(s.Out, sixZeros[:]...)
214
215	for i := 0; i < 4; i++ {
216		toDo := src
217		if len(toDo) > segmentSize {
218			toDo = toDo[:segmentSize]
219		}
220		src = src[len(toDo):]
221
222		var err error
223		idx := len(s.Out)
224		s.Out, err = s.compress1xDo(s.Out, toDo)
225		if err != nil {
226			return nil, err
227		}
228		// Write compressed length as little endian before block.
229		if i < 3 {
230			// Last length is not written.
231			length := len(s.Out) - idx
232			s.Out[i*2+offsetIdx] = byte(length)
233			s.Out[i*2+offsetIdx+1] = byte(length >> 8)
234		}
235	}
236
237	return s.Out, nil
238}
239
// compress4Xp will compress 4 streams using separate goroutines.
// Each stream is encoded into its own s.tmpOut buffer and the results are
// assembled into s.Out afterwards. Unlike compress4X this assumes it owns
// s.Out from the start (cap(s.Out) must be at least 6).
func (s *Scratch) compress4Xp(src []byte) ([]byte, error) {
	if len(src) < 12 {
		return nil, ErrIncompressible
	}
	// Add placeholder for output length
	s.Out = s.Out[:6]

	segmentSize := (len(src) + 3) / 4
	var wg sync.WaitGroup
	var errs [4]error
	wg.Add(4)
	for i := 0; i < 4; i++ {
		toDo := src
		if len(toDo) > segmentSize {
			toDo = toDo[:segmentSize]
		}
		src = src[len(toDo):]

		// Separate goroutine for each block.
		// toDo is declared per iteration, so each goroutine captures its
		// own segment; i is passed explicitly.
		go func(i int) {
			s.tmpOut[i], errs[i] = s.compress1xDo(s.tmpOut[i][:0], toDo)
			wg.Done()
		}(i)
	}
	wg.Wait()
	for i := 0; i < 4; i++ {
		if errs[i] != nil {
			return nil, errs[i]
		}
		o := s.tmpOut[i]
		// Write compressed length as little endian before block.
		if i < 3 {
			// Last length is not written.
			s.Out[i*2] = byte(len(o))
			s.Out[i*2+1] = byte(len(o) >> 8)
		}

		// Write output.
		s.Out = append(s.Out, o...)
	}
	return s.Out, nil
}
283
284// countSimple will create a simple histogram in s.count.
285// Returns the biggest count.
286// Does not update s.clearCount.
287func (s *Scratch) countSimple(in []byte) (max int, reuse bool) {
288	reuse = true
289	for _, v := range in {
290		s.count[v]++
291	}
292	m := uint32(0)
293	if len(s.prevTable) > 0 {
294		for i, v := range s.count[:] {
295			if v > m {
296				m = v
297			}
298			if v > 0 {
299				s.symbolLen = uint16(i) + 1
300				if i >= len(s.prevTable) {
301					reuse = false
302				} else {
303					if s.prevTable[i].nBits == 0 {
304						reuse = false
305					}
306				}
307			}
308		}
309		return int(m), reuse
310	}
311	for i, v := range s.count[:] {
312		if v > m {
313			m = v
314		}
315		if v > 0 {
316			s.symbolLen = uint16(i) + 1
317		}
318	}
319	return int(m), false
320}
321
322func (s *Scratch) canUseTable(c cTable) bool {
323	if len(c) < int(s.symbolLen) {
324		return false
325	}
326	for i, v := range s.count[:s.symbolLen] {
327		if v != 0 && c[i].nBits == 0 {
328			return false
329		}
330	}
331	return true
332}
333
334func (s *Scratch) validateTable(c cTable) bool {
335	if len(c) < int(s.symbolLen) {
336		return false
337	}
338	for i, v := range s.count[:s.symbolLen] {
339		if v != 0 {
340			if c[i].nBits == 0 {
341				return false
342			}
343			if c[i].nBits > s.actualTableLog {
344				return false
345			}
346		}
347	}
348	return true
349}
350
351// minTableLog provides the minimum logSize to safely represent a distribution.
352func (s *Scratch) minTableLog() uint8 {
353	minBitsSrc := highBit32(uint32(s.br.remain())) + 1
354	minBitsSymbols := highBit32(uint32(s.symbolLen-1)) + 2
355	if minBitsSrc < minBitsSymbols {
356		return uint8(minBitsSrc)
357	}
358	return uint8(minBitsSymbols)
359}
360
361// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
362func (s *Scratch) optimalTableLog() {
363	tableLog := s.TableLog
364	minBits := s.minTableLog()
365	maxBitsSrc := uint8(highBit32(uint32(s.br.remain()-1))) - 1
366	if maxBitsSrc < tableLog {
367		// Accuracy can be reduced
368		tableLog = maxBitsSrc
369	}
370	if minBits > tableLog {
371		tableLog = minBits
372	}
373	// Need a minimum to safely represent all symbol values
374	if tableLog < minTablelog {
375		tableLog = minTablelog
376	}
377	if tableLog > tableLogMax {
378		tableLog = tableLogMax
379	}
380	s.actualTableLog = tableLog
381}
382
// cTableEntry is a single compression table entry: a code value and its
// length in bits.
type cTableEntry struct {
	val   uint16 // code value, assigned within the symbol's bit-length rank
	nBits uint8  // code length in bits; 0 means the symbol has no code
	// We have 8 bits extra
}
388
389const huffNodesMask = huffNodesLen - 1
390
// buildCTable constructs the compression table in s.cTable from the histogram
// in s.count, with code lengths capped at s.actualTableLog by setMaxHeight.
// It relies on huffSort having placed symbols in s.nodes in decreasing count
// order; codes are assigned canonically (per bit-length rank, symbol order).
func (s *Scratch) buildCTable() error {
	s.optimalTableLog()
	s.huffSort()
	if cap(s.cTable) < maxSymbolValue+1 {
		s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
	} else {
		// Reuse existing storage, clearing old entries.
		s.cTable = s.cTable[:s.symbolLen]
		for i := range s.cTable {
			s.cTable[i] = cTableEntry{}
		}
	}

	var startNode = int16(s.symbolLen)
	nonNullRank := s.symbolLen - 1

	nodeNb := startNode
	huffNode := s.nodes[1 : huffNodesLen+1]

	// This overlays the slice above, but allows "-1" index lookups.
	// Different from reference implementation.
	huffNode0 := s.nodes[0 : huffNodesLen+1]

	// Skip trailing symbols with zero count.
	for huffNode[nonNullRank].count == 0 {
		nonNullRank--
	}

	// Combine the two smallest leaves into the first internal node.
	lowS := int16(nonNullRank)
	nodeRoot := nodeNb + lowS - 1
	lowN := nodeNb
	huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
	huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
	nodeNb++
	lowS -= 2
	for n := nodeNb; n <= nodeRoot; n++ {
		huffNode[n].count = 1 << 30
	}
	// fake entry, strong barrier
	huffNode0[0].count = 1 << 31

	// create parents: repeatedly merge the two lowest-count candidates,
	// taking either the next leaf (lowS, moving down) or the next internal
	// node (lowN, moving up).
	for nodeNb <= nodeRoot {
		var n1, n2 int16
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n1 = lowS
			lowS--
		} else {
			n1 = lowN
			lowN++
		}
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n2 = lowS
			lowS--
		} else {
			n2 = lowN
			lowN++
		}

		huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
		huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
		nodeNb++
	}

	// distribute weights (unlimited tree height)
	huffNode[nodeRoot].nbBits = 0
	for n := nodeRoot - 1; n >= startNode; n-- {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	for n := uint16(0); n <= nonNullRank; n++ {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	// Enforce the maximum code length, redistributing bits where needed.
	s.actualTableLog = s.setMaxHeight(int(nonNullRank))
	maxNbBits := s.actualTableLog

	// fill result into tree (val, nbBits)
	if maxNbBits > tableLogMax {
		return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
	}
	var nbPerRank [tableLogMax + 1]uint16
	var valPerRank [16]uint16
	for _, v := range huffNode[:nonNullRank+1] {
		nbPerRank[v.nbBits]++
	}
	// determine starting value per rank
	{
		min := uint16(0)
		for n := maxNbBits; n > 0; n-- {
			// get starting value within each rank
			valPerRank[n] = min
			min += nbPerRank[n]
			min >>= 1
		}
	}

	// push nbBits per symbol, symbol order
	for _, v := range huffNode[:nonNullRank+1] {
		s.cTable[v.symbol].nBits = v.nbBits
	}

	// assign value within rank, symbol order
	t := s.cTable[:s.symbolLen]
	for n, val := range t {
		nbits := val.nBits & 15
		v := valPerRank[nbits]
		t[n].val = v
		valPerRank[nbits] = v + 1
	}

	return nil
}
500
// huffSort will sort symbols, decreasing order.
// Symbols are written into s.nodes (offset by one) using a bucket sort on
// log2(count+1) with an insertion sort inside each bucket, so the nodes end
// up ordered by decreasing count.
func (s *Scratch) huffSort() {
	type rankPos struct {
		base    uint32 // first output position of this bucket
		current uint32 // next free position within the bucket
	}

	// Clear nodes
	nodes := s.nodes[:huffNodesLen+1]
	s.nodes = nodes
	nodes = nodes[1 : huffNodesLen+1]

	// Sort into buckets based on length of symbol count.
	var rank [32]rankPos
	for _, v := range s.count[:s.symbolLen] {
		r := highBit32(v+1) & 31
		rank[r].base++
	}
	// maxBitLength is log2(BlockSizeMax) + 1
	const maxBitLength = 18 + 1
	// Accumulate high-to-low so rank[n].base becomes the number of symbols
	// in higher buckets, i.e. the start offset for bucket n-1.
	for n := maxBitLength; n > 0; n-- {
		rank[n-1].base += rank[n].base
	}
	for n := range rank[:maxBitLength] {
		rank[n].current = rank[n].base
	}
	for n, c := range s.count[:s.symbolLen] {
		// Bucket index is offset by one because of the accumulation above.
		r := (highBit32(c+1) + 1) & 31
		pos := rank[r].current
		rank[r].current++
		prev := nodes[(pos-1)&huffNodesMask]
		// Insertion sort within the bucket: shift smaller counts down.
		for pos > rank[r].base && c > prev.count {
			nodes[pos&huffNodesMask] = prev
			pos--
			prev = nodes[(pos-1)&huffNodesMask]
		}
		nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
	}
}
540
// setMaxHeight limits all code lengths in s.nodes to s.actualTableLog,
// following the reference HUF_setMaxHeight: codes longer than the limit are
// truncated, and the accumulated "cost" is repaid by lengthening other codes,
// preferring the cheapest candidates per rank.
// lastNonNull is the index of the smallest symbol with a nonzero count.
// Returns the largest number of bits actually in use afterwards.
func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
	maxNbBits := s.actualTableLog
	huffNode := s.nodes[1 : huffNodesLen+1]
	//huffNode = huffNode[: huffNodesLen]

	largestBits := huffNode[lastNonNull].nbBits

	// early exit : no elt > maxNbBits
	if largestBits <= maxNbBits {
		return largestBits
	}
	totalCost := int(0)
	baseCost := int(1) << (largestBits - maxNbBits)
	n := uint32(lastNonNull)

	// Truncate every over-long code to maxNbBits, accumulating the cost.
	for huffNode[n].nbBits > maxNbBits {
		totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
		huffNode[n].nbBits = maxNbBits
		n--
	}
	// n stops at huffNode[n].nbBits <= maxNbBits

	for huffNode[n].nbBits == maxNbBits {
		n--
	}
	// n end at index of smallest symbol using < maxNbBits

	// renorm totalCost
	totalCost >>= largestBits - maxNbBits /* note : totalCost is necessarily a multiple of baseCost */

	// repay normalized cost
	{
		const noSymbol = 0xF0F0F0F0
		// rankLast[k] is the position of the last (smallest) symbol whose
		// code uses maxNbBits-k bits.
		var rankLast [tableLogMax + 2]uint32

		for i := range rankLast[:] {
			rankLast[i] = noSymbol
		}

		// Get pos of last (smallest) symbol per rank
		{
			currentNbBits := maxNbBits
			for pos := int(n); pos >= 0; pos-- {
				if huffNode[pos].nbBits >= currentNbBits {
					continue
				}
				currentNbBits = huffNode[pos].nbBits // < maxNbBits
				rankLast[maxNbBits-currentNbBits] = uint32(pos)
			}
		}

		for totalCost > 0 {
			nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1

			// Pick the rank whose lengthening is cheapest relative to the
			// frequency of the symbols involved.
			for ; nBitsToDecrease > 1; nBitsToDecrease-- {
				highPos := rankLast[nBitsToDecrease]
				lowPos := rankLast[nBitsToDecrease-1]
				if highPos == noSymbol {
					continue
				}
				if lowPos == noSymbol {
					break
				}
				highTotal := huffNode[highPos].count
				lowTotal := 2 * huffNode[lowPos].count
				if highTotal <= lowTotal {
					break
				}
			}
			// only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
			// HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
			// FIXME: try to remove
			for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
				nBitsToDecrease++
			}
			totalCost -= 1 << (nBitsToDecrease - 1)
			if rankLast[nBitsToDecrease-1] == noSymbol {
				// this rank is no longer empty
				rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
			}
			huffNode[rankLast[nBitsToDecrease]].nbBits++
			if rankLast[nBitsToDecrease] == 0 {
				/* special case, reached largest symbol */
				rankLast[nBitsToDecrease] = noSymbol
			} else {
				rankLast[nBitsToDecrease]--
				if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
					rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */
				}
			}
		}

		for totalCost < 0 { /* Sometimes, cost correction overshoot */
			if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */
				for huffNode[n].nbBits == maxNbBits {
					n--
				}
				huffNode[n+1].nbBits--
				rankLast[1] = n + 1
				totalCost++
				continue
			}
			huffNode[rankLast[1]+1].nbBits--
			rankLast[1]++
			totalCost++
		}
	}
	return maxNbBits
}
650
// nodeElt is a Huffman tree node used while building the table.
type nodeElt struct {
	count  uint32 // symbol frequency; combined frequency for internal nodes
	parent uint16 // index of the parent node in the nodes buffer
	symbol byte   // symbol value (meaningful for leaves)
	nbBits uint8  // assigned code length in bits
}
657