// Copyright 2015, Joe Tsai. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.

// Package prefix implements bit readers and writers that use prefix encoding.
package prefix

import (
	"fmt"
	"sort"

	"github.com/dsnet/compress/internal"
	"github.com/dsnet/compress/internal/errors"
)

// errorf builds an errors.Error whose Code is c, whose Pkg is "prefix", and
// whose Msg is the result of formatting f with the arguments a using
// fmt.Sprintf. The error is returned to the caller; nothing is raised here.
func errorf(c int, f string, a ...interface{}) error {
	return errors.Error{Code: c, Pkg: "prefix", Msg: fmt.Sprintf(f, a...)}
}

// panicf builds an error exactly as errorf does and hands it to errors.Panic.
// NOTE(review): errors.Panic is defined outside this file; presumably it
// panics with the given error — confirm against the internal/errors package.
func panicf(c int, f string, a ...interface{}) {
	errors.Panic(errorf(c, f, a...))
}

const (
	countBits = 5  // Number of bits to store the bit-length of the code
	valueBits = 27 // Number of bits to store the code value

	// countMask has the low countBits bits set, i.e. it is (1<<countBits)-1.
	countMask = (1 << countBits) - 1
)

// PrefixCode is a representation of a prefix code, which is conceptually a
// mapping from some arbitrary symbol to some bit-string.
//
// The Sym and Cnt fields are typically provided by the user,
// while the Len and Val fields are generated by this package.
36type PrefixCode struct { 37 Sym uint32 // The symbol being mapped 38 Cnt uint32 // The number times this symbol is used 39 Len uint32 // Bit-length of the prefix code 40 Val uint32 // Value of the prefix code (must be in 0..(1<<Len)-1) 41} 42type PrefixCodes []PrefixCode 43 44type prefixCodesBySymbol []PrefixCode 45 46func (c prefixCodesBySymbol) Len() int { return len(c) } 47func (c prefixCodesBySymbol) Less(i, j int) bool { return c[i].Sym < c[j].Sym } 48func (c prefixCodesBySymbol) Swap(i, j int) { c[i], c[j] = c[j], c[i] } 49 50type prefixCodesByCount []PrefixCode 51 52func (c prefixCodesByCount) Len() int { return len(c) } 53func (c prefixCodesByCount) Less(i, j int) bool { 54 return c[i].Cnt < c[j].Cnt || (c[i].Cnt == c[j].Cnt && c[i].Sym < c[j].Sym) 55} 56func (c prefixCodesByCount) Swap(i, j int) { c[i], c[j] = c[j], c[i] } 57 58func (pc PrefixCodes) SortBySymbol() { sort.Sort(prefixCodesBySymbol(pc)) } 59func (pc PrefixCodes) SortByCount() { sort.Sort(prefixCodesByCount(pc)) } 60 61// Length computes the total bit-length using the Len and Cnt fields. 62func (pc PrefixCodes) Length() (nb uint) { 63 for _, c := range pc { 64 nb += uint(c.Len * c.Cnt) 65 } 66 return nb 67} 68 69// checkLengths reports whether the codes form a complete prefix tree. 70func (pc PrefixCodes) checkLengths() bool { 71 sum := 1 << valueBits 72 for _, c := range pc { 73 sum -= (1 << valueBits) >> uint(c.Len) 74 } 75 return sum == 0 || len(pc) == 0 76} 77 78// checkPrefixes reports whether all codes have non-overlapping prefixes. 79func (pc PrefixCodes) checkPrefixes() bool { 80 for i, c1 := range pc { 81 for j, c2 := range pc { 82 mask := uint32(1)<<c1.Len - 1 83 if i != j && c1.Len <= c2.Len && c1.Val&mask == c2.Val&mask { 84 return false 85 } 86 } 87 } 88 return true 89} 90 91// checkCanonical reports whether all codes are canonical. 92// That is, they have the following properties: 93// 94// 1. All codes of a given bit-length are consecutive values. 95// 2. 
Shorter codes lexicographically precede longer codes. 96// 97// The codes must have unique symbols and be sorted by the symbol 98// The Len and Val fields in each code must be populated. 99func (pc PrefixCodes) checkCanonical() bool { 100 // Rule 1. 101 var vals [valueBits + 1]PrefixCode 102 for _, c := range pc { 103 if c.Len > 0 { 104 c.Val = internal.ReverseUint32N(c.Val, uint(c.Len)) 105 if vals[c.Len].Cnt > 0 && vals[c.Len].Val+1 != c.Val { 106 return false 107 } 108 vals[c.Len].Val = c.Val 109 vals[c.Len].Cnt++ 110 } 111 } 112 113 // Rule 2. 114 var last PrefixCode 115 for _, v := range vals { 116 if v.Cnt > 0 { 117 curVal := v.Val - v.Cnt + 1 118 if last.Cnt != 0 && last.Val >= curVal { 119 return false 120 } 121 last = v 122 } 123 } 124 return true 125} 126 127// GenerateLengths assigns non-zero bit-lengths to all codes. Codes with high 128// frequency counts will be assigned shorter codes to reduce bit entropy. 129// This function is used primarily by compressors. 130// 131// The input codes must have the Cnt field populated, be sorted by count. 132// Even if a code has a count of 0, a non-zero bit-length will be assigned. 133// 134// The result will have the Len field populated. The algorithm used guarantees 135// that Len <= maxBits and that it is a complete prefix tree. The resulting 136// codes will remain sorted by count. 137func GenerateLengths(codes PrefixCodes, maxBits uint) error { 138 if len(codes) <= 1 { 139 if len(codes) == 1 { 140 codes[0].Len = 0 141 } 142 return nil 143 } 144 145 // Verify that the codes are in ascending order by count. 146 cntLast := codes[0].Cnt 147 for _, c := range codes[1:] { 148 if c.Cnt < cntLast { 149 return errorf(errors.Invalid, "non-monotonically increasing symbol counts") 150 } 151 cntLast = c.Cnt 152 } 153 154 // Construct a Huffman tree used to generate the bit-lengths. 155 // 156 // The Huffman tree is a binary tree where each symbol lies as a leaf node 157 // on this tree. 
The length of the prefix code to assign is the depth of 158 // that leaf from the root. The Huffman algorithm, which runs in O(n), 159 // is used to generate the tree. It assumes that codes are sorted in 160 // increasing order of frequency. 161 // 162 // The algorithm is as follows: 163 // 1. Start with two queues, F and Q, where F contains all of the starting 164 // symbols sorted such that symbols with lowest counts come first. 165 // 2. While len(F)+len(Q) > 1: 166 // 2a. Dequeue the node from F or Q that has the lowest weight as N0. 167 // 2b. Dequeue the node from F or Q that has the lowest weight as N1. 168 // 2c. Create a new node N that has N0 and N1 as its children. 169 // 2d. Enqueue N into the back of Q. 170 // 3. The tree's root node is Q[0]. 171 type node struct { 172 cnt uint32 173 174 // n0 or c0 represent the left child of this node. 175 // Since Go does not have unions, only one of these will be set. 176 // Similarly, n1 or c1 represent the right child of this node. 177 // 178 // If n0 or n1 is set, then it represents a "pointer" to another 179 // node in the Huffman tree. Since Go's pointer analysis cannot reason 180 // that these node pointers do not escape (golang.org/issue/13493), 181 // we use an index to a node in the nodes slice as a pseudo-pointer. 182 // 183 // If c0 or c1 is set, then it represents a leaf "node" in the 184 // Huffman tree. The leaves are the PrefixCode values themselves. 185 n0, n1 int // Index to child nodes 186 c0, c1 *PrefixCode 187 } 188 var nodeIdx int 189 var nodeArr [1024]node // Large enough to handle most cases on the stack 190 nodes := nodeArr[:] 191 if len(nodes) < len(codes) { 192 nodes = make([]node, len(codes)) // Number of internal nodes < number of leaves 193 } 194 freqs, queue := codes, nodes[:0] 195 for len(freqs)+len(queue) > 1 { 196 // These are the two smallest nodes at the front of freqs and queue. 
197 var n node 198 if len(queue) == 0 || (len(freqs) > 0 && freqs[0].Cnt <= queue[0].cnt) { 199 n.c0, freqs = &freqs[0], freqs[1:] 200 n.cnt += n.c0.Cnt 201 } else { 202 n.cnt += queue[0].cnt 203 n.n0 = nodeIdx // nodeIdx is same as &queue[0] - &nodes[0] 204 nodeIdx++ 205 queue = queue[1:] 206 } 207 if len(queue) == 0 || (len(freqs) > 0 && freqs[0].Cnt <= queue[0].cnt) { 208 n.c1, freqs = &freqs[0], freqs[1:] 209 n.cnt += n.c1.Cnt 210 } else { 211 n.cnt += queue[0].cnt 212 n.n1 = nodeIdx // nodeIdx is same as &queue[0] - &nodes[0] 213 nodeIdx++ 214 queue = queue[1:] 215 } 216 queue = append(queue, n) 217 } 218 rootIdx := nodeIdx 219 220 // Search the whole binary tree, noting when we hit each leaf node. 221 // We do not care about the exact Huffman tree structure, but rather we only 222 // care about depth of each of the leaf nodes. That is, the depth determines 223 // how long each symbol is in bits. 224 // 225 // Since the number of leaves is n, there is at most n internal nodes. 226 // Thus, this algorithm runs in O(n). 227 var fixBits bool 228 var explore func(int, uint) 229 explore = func(rootIdx int, level uint) { 230 root := &nodes[rootIdx] 231 232 // Explore left branch. 233 if root.c0 == nil { 234 explore(root.n0, level+1) 235 } else { 236 fixBits = fixBits || (level > maxBits) 237 root.c0.Len = uint32(level) 238 } 239 240 // Explore right branch. 241 if root.c1 == nil { 242 explore(root.n1, level+1) 243 } else { 244 fixBits = fixBits || (level > maxBits) 245 root.c1.Len = uint32(level) 246 } 247 } 248 explore(rootIdx, 1) 249 250 // Fix the bit-lengths if we violate the maxBits requirement. 251 if fixBits { 252 // Create histogram for number of symbols with each bit-length. 
253 var symBitsArr [valueBits + 1]uint32 254 symBits := symBitsArr[:] // symBits[nb] indicates number of symbols using nb bits 255 for _, c := range codes { 256 for int(c.Len) >= len(symBits) { 257 symBits = append(symBits, 0) 258 } 259 symBits[c.Len]++ 260 } 261 262 // Fudge the tree such that the largest bit-length is <= maxBits. 263 // This is accomplish by effectively doing a tree rotation. That is, we 264 // increase the bit-length of some higher frequency code, so that the 265 // bit-lengths of lower frequency codes can be decreased. 266 // 267 // Visually, this looks like the following transform: 268 // 269 // Level Before After 270 // __ ___ 271 // / \ / \ 272 // n-1 X / \ /\ /\ 273 // n X /\ X X X X 274 // n+1 X X 275 // 276 var treeRotate func(uint) 277 treeRotate = func(nb uint) { 278 if symBits[nb-1] == 0 { 279 treeRotate(nb - 1) 280 } 281 symBits[nb-1] -= 1 // Push this node to the level below 282 symBits[nb] += 3 // This level gets one node from above, two from below 283 symBits[nb+1] -= 2 // Push two nodes to the level above 284 } 285 for i := uint(len(symBits)) - 1; i > maxBits; i-- { 286 for symBits[i] > 0 { 287 treeRotate(i - 1) 288 } 289 } 290 291 // Assign bit-lengths to each code. Since codes is sorted in increasing 292 // order of frequency, that means that the most frequently used symbols 293 // should have the shortest bit-lengths. Thus, we copy symbols to codes 294 // from the back of codes first. 295 cs := codes 296 for nb, cnt := range symBits { 297 if cnt > 0 { 298 pos := len(cs) - int(cnt) 299 cs2 := cs[pos:] 300 for i := range cs2 { 301 cs2[i].Len = uint32(nb) 302 } 303 cs = cs[:pos] 304 } 305 } 306 if len(cs) != 0 { 307 panic("not all codes were used up") 308 } 309 } 310 311 if internal.Debug && !codes.checkLengths() { 312 panic("incomplete prefix tree detected") 313 } 314 return nil 315} 316 317// GeneratePrefixes assigns a prefix value to all codes according to the 318// bit-lengths. 
This function is used by both compressors and decompressors. 319// 320// The input codes must have the Sym and Len fields populated and be 321// sorted by symbol. The bit-lengths of each code must be properly allocated, 322// such that it forms a complete tree. 323// 324// The result will have the Val field populated and will produce a canonical 325// prefix tree. The resulting codes will remain sorted by symbol. 326func GeneratePrefixes(codes PrefixCodes) error { 327 if len(codes) <= 1 { 328 if len(codes) == 1 { 329 if codes[0].Len != 0 { 330 return errorf(errors.Invalid, "degenerate prefix tree with one node") 331 } 332 codes[0].Val = 0 333 } 334 return nil 335 } 336 337 // Compute basic statistics on the symbols. 338 var bitCnts [valueBits + 1]uint 339 c0 := codes[0] 340 bitCnts[c0.Len]++ 341 minBits, maxBits, symLast := c0.Len, c0.Len, c0.Sym 342 for _, c := range codes[1:] { 343 if c.Sym <= symLast { 344 return errorf(errors.Invalid, "non-unique or non-monotonically increasing symbols") 345 } 346 if minBits > c.Len { 347 minBits = c.Len 348 } 349 if maxBits < c.Len { 350 maxBits = c.Len 351 } 352 bitCnts[c.Len]++ // Histogram of bit counts 353 symLast = c.Sym // Keep track of last symbol 354 } 355 if minBits == 0 { 356 return errorf(errors.Invalid, "invalid prefix bit-length") 357 } 358 359 // Compute the next code for a symbol of a given bit length. 360 var nextCodes [valueBits + 1]uint 361 var code uint 362 for i := minBits; i <= maxBits; i++ { 363 code <<= 1 364 nextCodes[i] = code 365 code += bitCnts[i] 366 } 367 if code != 1<<maxBits { 368 return errorf(errors.Invalid, "degenerate prefix tree") 369 } 370 371 // Assign the code to each symbol. 
372 for i, c := range codes { 373 codes[i].Val = internal.ReverseUint32N(uint32(nextCodes[c.Len]), uint(c.Len)) 374 nextCodes[c.Len]++ 375 } 376 377 if internal.Debug && !codes.checkPrefixes() { 378 panic("overlapping prefixes detected") 379 } 380 if internal.Debug && !codes.checkCanonical() { 381 panic("non-canonical prefixes detected") 382 } 383 return nil 384} 385 386func allocUint32s(s []uint32, n int) []uint32 { 387 if cap(s) >= n { 388 return s[:n] 389 } 390 return make([]uint32, n, n*3/2) 391} 392 393func extendSliceUint32s(s [][]uint32, n int) [][]uint32 { 394 if cap(s) >= n { 395 return s[:n] 396 } 397 ss := make([][]uint32, n, n*3/2) 398 copy(ss, s[:cap(s)]) 399 return ss 400} 401