1// Copyright 2015, Joe Tsai. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE.md file.
4
5// Package prefix implements bit readers and writers that use prefix encoding.
6package prefix
7
8import (
9	"fmt"
10	"sort"
11
12	"github.com/dsnet/compress/internal"
13	"github.com/dsnet/compress/internal/errors"
14)
15
16func errorf(c int, f string, a ...interface{}) error {
17	return errors.Error{Code: c, Pkg: "prefix", Msg: fmt.Sprintf(f, a...)}
18}
19
20func panicf(c int, f string, a ...interface{}) {
21	errors.Panic(errorf(c, f, a...))
22}
23
const (
	// countBits is the number of bits used to store the bit-length of a code.
	countBits = 5
	// valueBits is the number of bits used to store the value of a code.
	valueBits = 27

	// countMask has the low countBits bits set.
	countMask = 1<<countBits - 1
)
30
// PrefixCode is a representation of a prefix code, which is conceptually a
// mapping from some arbitrary symbol to some bit-string.
//
// The Sym and Cnt fields are typically provided by the user,
// while the Len and Val fields are generated by this package
// (see GenerateLengths and GeneratePrefixes).
type PrefixCode struct {
	Sym uint32 // The symbol being mapped
	Cnt uint32 // The number of times this symbol is used
	Len uint32 // Bit-length of the prefix code
	Val uint32 // Value of the prefix code (must be in 0..(1<<Len)-1)
}
// PrefixCodes is a list of prefix codes. Its methods sort the list by symbol
// or by count, compute the total encoded bit-length, and validate the
// generated lengths and values.
type PrefixCodes []PrefixCode
43
44type prefixCodesBySymbol []PrefixCode
45
46func (c prefixCodesBySymbol) Len() int           { return len(c) }
47func (c prefixCodesBySymbol) Less(i, j int) bool { return c[i].Sym < c[j].Sym }
48func (c prefixCodesBySymbol) Swap(i, j int)      { c[i], c[j] = c[j], c[i] }
49
50type prefixCodesByCount []PrefixCode
51
52func (c prefixCodesByCount) Len() int { return len(c) }
53func (c prefixCodesByCount) Less(i, j int) bool {
54	return c[i].Cnt < c[j].Cnt || (c[i].Cnt == c[j].Cnt && c[i].Sym < c[j].Sym)
55}
56func (c prefixCodesByCount) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
57
58func (pc PrefixCodes) SortBySymbol() { sort.Sort(prefixCodesBySymbol(pc)) }
59func (pc PrefixCodes) SortByCount()  { sort.Sort(prefixCodesByCount(pc)) }
60
61// Length computes the total bit-length using the Len and Cnt fields.
62func (pc PrefixCodes) Length() (nb uint) {
63	for _, c := range pc {
64		nb += uint(c.Len * c.Cnt)
65	}
66	return nb
67}
68
69// checkLengths reports whether the codes form a complete prefix tree.
70func (pc PrefixCodes) checkLengths() bool {
71	sum := 1 << valueBits
72	for _, c := range pc {
73		sum -= (1 << valueBits) >> uint(c.Len)
74	}
75	return sum == 0 || len(pc) == 0
76}
77
78// checkPrefixes reports whether all codes have non-overlapping prefixes.
79func (pc PrefixCodes) checkPrefixes() bool {
80	for i, c1 := range pc {
81		for j, c2 := range pc {
82			mask := uint32(1)<<c1.Len - 1
83			if i != j && c1.Len <= c2.Len && c1.Val&mask == c2.Val&mask {
84				return false
85			}
86		}
87	}
88	return true
89}
90
91// checkCanonical reports whether all codes are canonical.
92// That is, they have the following properties:
93//
94//	1. All codes of a given bit-length are consecutive values.
95//	2. Shorter codes lexicographically precede longer codes.
96//
97// The codes must have unique symbols and be sorted by the symbol
98// The Len and Val fields in each code must be populated.
99func (pc PrefixCodes) checkCanonical() bool {
100	// Rule 1.
101	var vals [valueBits + 1]PrefixCode
102	for _, c := range pc {
103		if c.Len > 0 {
104			c.Val = internal.ReverseUint32N(c.Val, uint(c.Len))
105			if vals[c.Len].Cnt > 0 && vals[c.Len].Val+1 != c.Val {
106				return false
107			}
108			vals[c.Len].Val = c.Val
109			vals[c.Len].Cnt++
110		}
111	}
112
113	// Rule 2.
114	var last PrefixCode
115	for _, v := range vals {
116		if v.Cnt > 0 {
117			curVal := v.Val - v.Cnt + 1
118			if last.Cnt != 0 && last.Val >= curVal {
119				return false
120			}
121			last = v
122		}
123	}
124	return true
125}
126
127// GenerateLengths assigns non-zero bit-lengths to all codes. Codes with high
128// frequency counts will be assigned shorter codes to reduce bit entropy.
129// This function is used primarily by compressors.
130//
131// The input codes must have the Cnt field populated, be sorted by count.
132// Even if a code has a count of 0, a non-zero bit-length will be assigned.
133//
134// The result will have the Len field populated. The algorithm used guarantees
135// that Len <= maxBits and that it is a complete prefix tree. The resulting
136// codes will remain sorted by count.
137func GenerateLengths(codes PrefixCodes, maxBits uint) error {
138	if len(codes) <= 1 {
139		if len(codes) == 1 {
140			codes[0].Len = 0
141		}
142		return nil
143	}
144
145	// Verify that the codes are in ascending order by count.
146	cntLast := codes[0].Cnt
147	for _, c := range codes[1:] {
148		if c.Cnt < cntLast {
149			return errorf(errors.Invalid, "non-monotonically increasing symbol counts")
150		}
151		cntLast = c.Cnt
152	}
153
154	// Construct a Huffman tree used to generate the bit-lengths.
155	//
156	// The Huffman tree is a binary tree where each symbol lies as a leaf node
157	// on this tree. The length of the prefix code to assign is the depth of
158	// that leaf from the root. The Huffman algorithm, which runs in O(n),
159	// is used to generate the tree. It assumes that codes are sorted in
160	// increasing order of frequency.
161	//
162	// The algorithm is as follows:
163	//	1. Start with two queues, F and Q, where F contains all of the starting
164	//	symbols sorted such that symbols with lowest counts come first.
165	//	2. While len(F)+len(Q) > 1:
166	//		2a. Dequeue the node from F or Q that has the lowest weight as N0.
167	//		2b. Dequeue the node from F or Q that has the lowest weight as N1.
168	//		2c. Create a new node N that has N0 and N1 as its children.
169	//		2d. Enqueue N into the back of Q.
170	//	3. The tree's root node is Q[0].
171	type node struct {
172		cnt uint32
173
174		// n0 or c0 represent the left child of this node.
175		// Since Go does not have unions, only one of these will be set.
176		// Similarly, n1 or c1 represent the right child of this node.
177		//
178		// If n0 or n1 is set, then it represents a "pointer" to another
179		// node in the Huffman tree. Since Go's pointer analysis cannot reason
180		// that these node pointers do not escape (golang.org/issue/13493),
181		// we use an index to a node in the nodes slice as a pseudo-pointer.
182		//
183		// If c0 or c1 is set, then it represents a leaf "node" in the
184		// Huffman tree. The leaves are the PrefixCode values themselves.
185		n0, n1 int // Index to child nodes
186		c0, c1 *PrefixCode
187	}
188	var nodeIdx int
189	var nodeArr [1024]node // Large enough to handle most cases on the stack
190	nodes := nodeArr[:]
191	if len(nodes) < len(codes) {
192		nodes = make([]node, len(codes)) // Number of internal nodes < number of leaves
193	}
194	freqs, queue := codes, nodes[:0]
195	for len(freqs)+len(queue) > 1 {
196		// These are the two smallest nodes at the front of freqs and queue.
197		var n node
198		if len(queue) == 0 || (len(freqs) > 0 && freqs[0].Cnt <= queue[0].cnt) {
199			n.c0, freqs = &freqs[0], freqs[1:]
200			n.cnt += n.c0.Cnt
201		} else {
202			n.cnt += queue[0].cnt
203			n.n0 = nodeIdx // nodeIdx is same as &queue[0] - &nodes[0]
204			nodeIdx++
205			queue = queue[1:]
206		}
207		if len(queue) == 0 || (len(freqs) > 0 && freqs[0].Cnt <= queue[0].cnt) {
208			n.c1, freqs = &freqs[0], freqs[1:]
209			n.cnt += n.c1.Cnt
210		} else {
211			n.cnt += queue[0].cnt
212			n.n1 = nodeIdx // nodeIdx is same as &queue[0] - &nodes[0]
213			nodeIdx++
214			queue = queue[1:]
215		}
216		queue = append(queue, n)
217	}
218	rootIdx := nodeIdx
219
220	// Search the whole binary tree, noting when we hit each leaf node.
221	// We do not care about the exact Huffman tree structure, but rather we only
222	// care about depth of each of the leaf nodes. That is, the depth determines
223	// how long each symbol is in bits.
224	//
225	// Since the number of leaves is n, there is at most n internal nodes.
226	// Thus, this algorithm runs in O(n).
227	var fixBits bool
228	var explore func(int, uint)
229	explore = func(rootIdx int, level uint) {
230		root := &nodes[rootIdx]
231
232		// Explore left branch.
233		if root.c0 == nil {
234			explore(root.n0, level+1)
235		} else {
236			fixBits = fixBits || (level > maxBits)
237			root.c0.Len = uint32(level)
238		}
239
240		// Explore right branch.
241		if root.c1 == nil {
242			explore(root.n1, level+1)
243		} else {
244			fixBits = fixBits || (level > maxBits)
245			root.c1.Len = uint32(level)
246		}
247	}
248	explore(rootIdx, 1)
249
250	// Fix the bit-lengths if we violate the maxBits requirement.
251	if fixBits {
252		// Create histogram for number of symbols with each bit-length.
253		var symBitsArr [valueBits + 1]uint32
254		symBits := symBitsArr[:] // symBits[nb] indicates number of symbols using nb bits
255		for _, c := range codes {
256			for int(c.Len) >= len(symBits) {
257				symBits = append(symBits, 0)
258			}
259			symBits[c.Len]++
260		}
261
262		// Fudge the tree such that the largest bit-length is <= maxBits.
263		// This is accomplish by effectively doing a tree rotation. That is, we
264		// increase the bit-length of some higher frequency code, so that the
265		// bit-lengths of lower frequency codes can be decreased.
266		//
267		// Visually, this looks like the following transform:
268		//
269		//	Level   Before       After
270		//	          __          ___
271		//	         /  \        /   \
272		//	 n-1    X  / \      /\   /\
273		//	 n        X  /\    X  X X  X
274		//	 n+1        X  X
275		//
276		var treeRotate func(uint)
277		treeRotate = func(nb uint) {
278			if symBits[nb-1] == 0 {
279				treeRotate(nb - 1)
280			}
281			symBits[nb-1] -= 1 // Push this node to the level below
282			symBits[nb] += 3   // This level gets one node from above, two from below
283			symBits[nb+1] -= 2 // Push two nodes to the level above
284		}
285		for i := uint(len(symBits)) - 1; i > maxBits; i-- {
286			for symBits[i] > 0 {
287				treeRotate(i - 1)
288			}
289		}
290
291		// Assign bit-lengths to each code. Since codes is sorted in increasing
292		// order of frequency, that means that the most frequently used symbols
293		// should have the shortest bit-lengths. Thus, we copy symbols to codes
294		// from the back of codes first.
295		cs := codes
296		for nb, cnt := range symBits {
297			if cnt > 0 {
298				pos := len(cs) - int(cnt)
299				cs2 := cs[pos:]
300				for i := range cs2 {
301					cs2[i].Len = uint32(nb)
302				}
303				cs = cs[:pos]
304			}
305		}
306		if len(cs) != 0 {
307			panic("not all codes were used up")
308		}
309	}
310
311	if internal.Debug && !codes.checkLengths() {
312		panic("incomplete prefix tree detected")
313	}
314	return nil
315}
316
317// GeneratePrefixes assigns a prefix value to all codes according to the
318// bit-lengths. This function is used by both compressors and decompressors.
319//
320// The input codes must have the Sym and Len fields populated and be
321// sorted by symbol. The bit-lengths of each code must be properly allocated,
322// such that it forms a complete tree.
323//
324// The result will have the Val field populated and will produce a canonical
325// prefix tree. The resulting codes will remain sorted by symbol.
326func GeneratePrefixes(codes PrefixCodes) error {
327	if len(codes) <= 1 {
328		if len(codes) == 1 {
329			if codes[0].Len != 0 {
330				return errorf(errors.Invalid, "degenerate prefix tree with one node")
331			}
332			codes[0].Val = 0
333		}
334		return nil
335	}
336
337	// Compute basic statistics on the symbols.
338	var bitCnts [valueBits + 1]uint
339	c0 := codes[0]
340	bitCnts[c0.Len]++
341	minBits, maxBits, symLast := c0.Len, c0.Len, c0.Sym
342	for _, c := range codes[1:] {
343		if c.Sym <= symLast {
344			return errorf(errors.Invalid, "non-unique or non-monotonically increasing symbols")
345		}
346		if minBits > c.Len {
347			minBits = c.Len
348		}
349		if maxBits < c.Len {
350			maxBits = c.Len
351		}
352		bitCnts[c.Len]++ // Histogram of bit counts
353		symLast = c.Sym  // Keep track of last symbol
354	}
355	if minBits == 0 {
356		return errorf(errors.Invalid, "invalid prefix bit-length")
357	}
358
359	// Compute the next code for a symbol of a given bit length.
360	var nextCodes [valueBits + 1]uint
361	var code uint
362	for i := minBits; i <= maxBits; i++ {
363		code <<= 1
364		nextCodes[i] = code
365		code += bitCnts[i]
366	}
367	if code != 1<<maxBits {
368		return errorf(errors.Invalid, "degenerate prefix tree")
369	}
370
371	// Assign the code to each symbol.
372	for i, c := range codes {
373		codes[i].Val = internal.ReverseUint32N(uint32(nextCodes[c.Len]), uint(c.Len))
374		nextCodes[c.Len]++
375	}
376
377	if internal.Debug && !codes.checkPrefixes() {
378		panic("overlapping prefixes detected")
379	}
380	if internal.Debug && !codes.checkCanonical() {
381		panic("non-canonical prefixes detected")
382	}
383	return nil
384}
385
// allocUint32s returns a slice of length n. When s already has capacity for
// n elements it is resliced and returned, keeping whatever it held.
// Otherwise a new zeroed slice is allocated with 50% spare capacity, and the
// contents of s are not carried over.
func allocUint32s(s []uint32, n int) []uint32 {
	if n > cap(s) {
		return make([]uint32, n, n*3/2)
	}
	return s[:n]
}
392
// extendSliceUint32s returns a slice of slices of length n. When s already
// has capacity for n elements it is simply resliced. Otherwise a new slice
// with 50% spare capacity is allocated, and every element within the full
// capacity of s (not just its length) is copied across.
func extendSliceUint32s(s [][]uint32, n int) [][]uint32 {
	if n <= cap(s) {
		return s[:n]
	}
	grown := make([][]uint32, n, n*3/2)
	copy(grown, s[:cap(s)])
	return grown
}
401