package huff0

import (
	"fmt"
	"runtime"
	"sync"
)

// Compress1X will compress the input.
// The output can be decoded using Decompress1X.
// Supply a Scratch object. The scratch object contains state about re-use,
// So when sharing across independent encodes, be sure to set the re-use policy.
func Compress1X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
	s, err = s.prepare(in)
	if err != nil {
		return nil, false, err
	}
	return compress(in, s, s.compress1X)
}

// Compress4X will compress the input. The input is split into 4 independent blocks
// and compressed similar to Compress1X.
// The output can be decoded using Decompress4X.
// Supply a Scratch object. The scratch object contains state about re-use,
// So when sharing across independent encodes, be sure to set the re-use policy.
func Compress4X(in []byte, s *Scratch) (out []byte, reUsed bool, err error) {
	s, err = s.prepare(in)
	if err != nil {
		return nil, false, err
	}
	if false {
		// TODO: compress4Xp only slightly faster.
		// Parallel path disabled for now; kept for future benchmarking.
		const parallelThreshold = 8 << 10
		if len(in) < parallelThreshold || runtime.GOMAXPROCS(0) == 1 {
			return compress(in, s, s.compress4X)
		}
		return compress(in, s, s.compress4Xp)
	}
	return compress(in, s, s.compress4X)
}

// compress is the shared driver for Compress1X/Compress4X.
// It builds (or re-uses) the Huffman table according to s.Reuse, invokes the
// supplied compressor, and reports ErrIncompressible/ErrUseRLE when encoding
// would not shrink the input below wantSize.
func compress(in []byte, s *Scratch, compressor func(src []byte) ([]byte, error)) (out []byte, reUsed bool, err error) {
	// Nuke previous table if we cannot reuse anyway.
	if s.Reuse == ReusePolicyNone {
		s.prevTable = s.prevTable[:0]
	}

	// Create histogram, if none was provided.
	maxCount := s.maxCount
	var canReuse = false
	if maxCount == 0 {
		maxCount, canReuse = s.countSimple(in)
	} else {
		canReuse = s.canUseTable(s.prevTable)
	}

	// We want the output size to be less than this:
	wantSize := len(in)
	if s.WantLogLess > 0 {
		wantSize -= wantSize >> s.WantLogLess
	}

	// Reset for next run.
	s.clearCount = true
	s.maxCount = 0
	if maxCount >= len(in) {
		if maxCount > len(in) {
			// Histogram cannot exceed input length; caller-supplied count is bogus.
			return nil, false, fmt.Errorf("maxCount (%d) > length (%d)", maxCount, len(in))
		}
		if len(in) == 1 {
			return nil, false, ErrIncompressible
		}
		// One symbol, use RLE
		return nil, false, ErrUseRLE
	}
	if maxCount == 1 || maxCount < (len(in)>>7) {
		// Each symbol present maximum once or too well distributed.
		return nil, false, ErrIncompressible
	}
	if s.Reuse == ReusePolicyMust && !canReuse {
		// We must reuse, but we can't.
		return nil, false, ErrIncompressible
	}
	if (s.Reuse == ReusePolicyPrefer || s.Reuse == ReusePolicyMust) && canReuse {
		// Temporarily swap in the previous table; restore unconditionally after.
		keepTable := s.cTable
		keepTL := s.actualTableLog
		s.cTable = s.prevTable
		s.actualTableLog = s.prevTableLog
		s.Out, err = compressor(in)
		s.cTable = keepTable
		s.actualTableLog = keepTL
		if err == nil && len(s.Out) < wantSize {
			s.OutData = s.Out
			return s.Out, true, nil
		}
		if s.Reuse == ReusePolicyMust {
			return nil, false, ErrIncompressible
		}
		// Do not attempt to re-use later.
		s.prevTable = s.prevTable[:0]
	}

	// Calculate new table.
	err = s.buildCTable()
	if err != nil {
		return nil, false, err
	}

	if false && !s.canUseTable(s.cTable) {
		// Debug-only sanity check, disabled in normal builds.
		panic("invalid table generated")
	}

	if s.Reuse == ReusePolicyAllow && canReuse {
		// Estimate whether re-using the old table beats paying for a new header.
		hSize := len(s.Out)
		oldSize := s.prevTable.estimateSize(s.count[:s.symbolLen])
		newSize := s.cTable.estimateSize(s.count[:s.symbolLen])
		if oldSize <= hSize+newSize || hSize+12 >= wantSize {
			// Retain cTable even if we re-use.
			keepTable := s.cTable
			keepTL := s.actualTableLog

			s.cTable = s.prevTable
			s.actualTableLog = s.prevTableLog
			s.Out, err = compressor(in)

			// Restore ctable.
			s.cTable = keepTable
			s.actualTableLog = keepTL
			if err != nil {
				return nil, false, err
			}
			if len(s.Out) >= wantSize {
				return nil, false, ErrIncompressible
			}
			s.OutData = s.Out
			return s.Out, true, nil
		}
	}

	// Use new table
	err = s.cTable.write(s)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	s.OutTable = s.Out

	// Compress using new table
	s.Out, err = compressor(in)
	if err != nil {
		s.OutTable = nil
		return nil, false, err
	}
	if len(s.Out) >= wantSize {
		s.OutTable = nil
		return nil, false, ErrIncompressible
	}
	// Move current table into previous.
	s.prevTable, s.prevTableLog, s.cTable = s.cTable, s.actualTableLog, s.prevTable[:0]
	// OutData is the payload after the serialized table header.
	s.OutData = s.Out[len(s.OutTable):]
	return s.Out, false, nil
}

// compress1X encodes src as a single stream, appending to s.Out.
func (s *Scratch) compress1X(src []byte) ([]byte, error) {
	return s.compress1xDo(s.Out, src)
}

// compress1xDo appends the Huffman-encoded form of src to dst using s.cTable.
// Symbols are emitted back-to-front in groups of 4 (decoder reads in reverse);
// the 0-3 trailing bytes are encoded first.
func (s *Scratch) compress1xDo(dst, src []byte) ([]byte, error) {
	var bw = bitWriter{out: dst}

	// N is length divisible by 4.
	n := len(src)
	n -= n & 3
	cTable := s.cTable[:256]

	// Encode last bytes.
	for i := len(src) & 3; i > 0; i-- {
		bw.encSymbol(cTable, src[n+i-1])
	}
	n -= 4
	if s.actualTableLog <= 8 {
		// <= 8 bits/symbol: 4 symbols fit after a single flush32.
		for ; n >= 0; n -= 4 {
			tmp := src[n : n+4]
			// tmp should be len 4
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[3], tmp[2])
			bw.encTwoSymbols(cTable, tmp[1], tmp[0])
		}
	} else {
		// Wider codes: flush between each symbol pair to avoid overflow.
		for ; n >= 0; n -= 4 {
			tmp := src[n : n+4]
			// tmp should be len 4
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[3], tmp[2])
			bw.flush32()
			bw.encTwoSymbols(cTable, tmp[1], tmp[0])
		}
	}
	err := bw.close()
	return bw.out, err
}

// sixZeros is a reusable placeholder for the three 2-byte length fields
// written ahead of the first three streams in the 4X format.
var sixZeros [6]byte

// compress4X encodes src as 4 independent streams, sequentially.
// The first three streams are preceded by their little-endian 16-bit lengths;
// the last length is implicit.
func (s *Scratch) compress4X(src []byte) ([]byte, error) {
	if len(src) < 12 {
		return nil, ErrIncompressible
	}
	segmentSize := (len(src) + 3) / 4

	// Add placeholder for output length
	offsetIdx := len(s.Out)
	s.Out = append(s.Out, sixZeros[:]...)

	for i := 0; i < 4; i++ {
		toDo := src
		if len(toDo) > segmentSize {
			toDo = toDo[:segmentSize]
		}
		src = src[len(toDo):]

		var err error
		idx := len(s.Out)
		s.Out, err = s.compress1xDo(s.Out, toDo)
		if err != nil {
			return nil, err
		}
		// Write compressed length as little endian before block.
		if i < 3 {
			// Last length is not written.
			length := len(s.Out) - idx
			s.Out[i*2+offsetIdx] = byte(length)
			s.Out[i*2+offsetIdx+1] = byte(length >> 8)
		}
	}

	return s.Out, nil
}

// compress4Xp will compress 4 streams using separate goroutines.
241func (s *Scratch) compress4Xp(src []byte) ([]byte, error) { 242 if len(src) < 12 { 243 return nil, ErrIncompressible 244 } 245 // Add placeholder for output length 246 s.Out = s.Out[:6] 247 248 segmentSize := (len(src) + 3) / 4 249 var wg sync.WaitGroup 250 var errs [4]error 251 wg.Add(4) 252 for i := 0; i < 4; i++ { 253 toDo := src 254 if len(toDo) > segmentSize { 255 toDo = toDo[:segmentSize] 256 } 257 src = src[len(toDo):] 258 259 // Separate goroutine for each block. 260 go func(i int) { 261 s.tmpOut[i], errs[i] = s.compress1xDo(s.tmpOut[i][:0], toDo) 262 wg.Done() 263 }(i) 264 } 265 wg.Wait() 266 for i := 0; i < 4; i++ { 267 if errs[i] != nil { 268 return nil, errs[i] 269 } 270 o := s.tmpOut[i] 271 // Write compressed length as little endian before block. 272 if i < 3 { 273 // Last length is not written. 274 s.Out[i*2] = byte(len(o)) 275 s.Out[i*2+1] = byte(len(o) >> 8) 276 } 277 278 // Write output. 279 s.Out = append(s.Out, o...) 280 } 281 return s.Out, nil 282} 283 284// countSimple will create a simple histogram in s.count. 285// Returns the biggest count. 286// Does not update s.clearCount. 
287func (s *Scratch) countSimple(in []byte) (max int, reuse bool) { 288 reuse = true 289 for _, v := range in { 290 s.count[v]++ 291 } 292 m := uint32(0) 293 if len(s.prevTable) > 0 { 294 for i, v := range s.count[:] { 295 if v > m { 296 m = v 297 } 298 if v > 0 { 299 s.symbolLen = uint16(i) + 1 300 if i >= len(s.prevTable) { 301 reuse = false 302 } else { 303 if s.prevTable[i].nBits == 0 { 304 reuse = false 305 } 306 } 307 } 308 } 309 return int(m), reuse 310 } 311 for i, v := range s.count[:] { 312 if v > m { 313 m = v 314 } 315 if v > 0 { 316 s.symbolLen = uint16(i) + 1 317 } 318 } 319 return int(m), false 320} 321 322func (s *Scratch) canUseTable(c cTable) bool { 323 if len(c) < int(s.symbolLen) { 324 return false 325 } 326 for i, v := range s.count[:s.symbolLen] { 327 if v != 0 && c[i].nBits == 0 { 328 return false 329 } 330 } 331 return true 332} 333 334func (s *Scratch) validateTable(c cTable) bool { 335 if len(c) < int(s.symbolLen) { 336 return false 337 } 338 for i, v := range s.count[:s.symbolLen] { 339 if v != 0 { 340 if c[i].nBits == 0 { 341 return false 342 } 343 if c[i].nBits > s.actualTableLog { 344 return false 345 } 346 } 347 } 348 return true 349} 350 351// minTableLog provides the minimum logSize to safely represent a distribution. 
// minTableLog provides the minimum logSize to safely represent a distribution.
func (s *Scratch) minTableLog() uint8 {
	minBitsSrc := highBit32(uint32(s.br.remain())) + 1
	minBitsSymbols := highBit32(uint32(s.symbolLen-1)) + 2
	if minBitsSrc < minBitsSymbols {
		return uint8(minBitsSrc)
	}
	return uint8(minBitsSymbols)
}

// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
func (s *Scratch) optimalTableLog() {
	tableLog := s.TableLog
	minBits := s.minTableLog()
	maxBitsSrc := uint8(highBit32(uint32(s.br.remain()-1))) - 1
	if maxBitsSrc < tableLog {
		// Accuracy can be reduced
		tableLog = maxBitsSrc
	}
	if minBits > tableLog {
		tableLog = minBits
	}
	// Need a minimum to safely represent all symbol values
	if tableLog < minTablelog {
		tableLog = minTablelog
	}
	if tableLog > tableLogMax {
		tableLog = tableLogMax
	}
	s.actualTableLog = tableLog
}

// cTableEntry is one symbol's Huffman code: the code value and its bit length.
type cTableEntry struct {
	val   uint16
	nBits uint8
	// We have 8 bits extra
}

const huffNodesMask = huffNodesLen - 1

// buildCTable constructs the canonical Huffman code table in s.cTable from
// the histogram in s.count, limiting code lengths via setMaxHeight.
func (s *Scratch) buildCTable() error {
	s.optimalTableLog()
	s.huffSort()
	if cap(s.cTable) < maxSymbolValue+1 {
		s.cTable = make([]cTableEntry, s.symbolLen, maxSymbolValue+1)
	} else {
		s.cTable = s.cTable[:s.symbolLen]
		for i := range s.cTable {
			s.cTable[i] = cTableEntry{}
		}
	}

	// Leaves occupy [0, symbolLen); internal nodes are appended from startNode.
	var startNode = int16(s.symbolLen)
	nonNullRank := s.symbolLen - 1

	nodeNb := startNode
	huffNode := s.nodes[1 : huffNodesLen+1]

	// This overlays the slice above, but allows "-1" index lookups.
	// Different from reference implementation.
	huffNode0 := s.nodes[0 : huffNodesLen+1]

	// Skip trailing zero-count entries left by huffSort.
	for huffNode[nonNullRank].count == 0 {
		nonNullRank--
	}

	// Merge the two smallest leaves to seed the internal-node region.
	lowS := int16(nonNullRank)
	nodeRoot := nodeNb + lowS - 1
	lowN := nodeNb
	huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
	huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
	nodeNb++
	lowS -= 2
	for n := nodeNb; n <= nodeRoot; n++ {
		huffNode[n].count = 1 << 30
	}
	// fake entry, strong barrier
	huffNode0[0].count = 1 << 31

	// create parents
	for nodeNb <= nodeRoot {
		var n1, n2 int16
		// Pick the two smallest remaining nodes (leaf vs. internal).
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n1 = lowS
			lowS--
		} else {
			n1 = lowN
			lowN++
		}
		if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
			n2 = lowS
			lowS--
		} else {
			n2 = lowN
			lowN++
		}

		huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
		huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
		nodeNb++
	}

	// distribute weights (unlimited tree height)
	huffNode[nodeRoot].nbBits = 0
	for n := nodeRoot - 1; n >= startNode; n-- {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	for n := uint16(0); n <= nonNullRank; n++ {
		huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
	}
	// Enforce the maximum code length; this may redistribute bit lengths.
	s.actualTableLog = s.setMaxHeight(int(nonNullRank))
	maxNbBits := s.actualTableLog

	// fill result into tree (val, nbBits)
	if maxNbBits > tableLogMax {
		return fmt.Errorf("internal error: maxNbBits (%d) > tableLogMax (%d)", maxNbBits, tableLogMax)
	}
	var nbPerRank [tableLogMax + 1]uint16
	var valPerRank [16]uint16
	for _, v := range huffNode[:nonNullRank+1] {
		nbPerRank[v.nbBits]++
	}
	// determine starting value per rank
	{
		min := uint16(0)
		for n := maxNbBits; n > 0; n-- {
			// get starting value within each rank
			valPerRank[n] = min
			min += nbPerRank[n]
			min >>= 1
		}
	}

	// push nbBits per symbol, symbol order
	for _, v := range huffNode[:nonNullRank+1] {
		s.cTable[v.symbol].nBits = v.nbBits
	}

	// assign value within rank, symbol order
	t := s.cTable[:s.symbolLen]
	for n, val := range t {
		nbits := val.nBits & 15
		v := valPerRank[nbits]
		t[n].val = v
		valPerRank[nbits] = v + 1
	}

	return nil
}

// huffSort will sort symbols, decreasing order.
func (s *Scratch) huffSort() {
	type rankPos struct {
		base    uint32
		current uint32
	}

	// Clear nodes
	nodes := s.nodes[:huffNodesLen+1]
	s.nodes = nodes
	nodes = nodes[1 : huffNodesLen+1]

	// Sort into buckets based on length of symbol count.
	var rank [32]rankPos
	for _, v := range s.count[:s.symbolLen] {
		r := highBit32(v+1) & 31
		rank[r].base++
	}
	// maxBitLength is log2(BlockSizeMax) + 1
	const maxBitLength = 18 + 1
	// Prefix-sum to get each bucket's starting position.
	for n := maxBitLength; n > 0; n-- {
		rank[n-1].base += rank[n].base
	}
	for n := range rank[:maxBitLength] {
		rank[n].current = rank[n].base
	}
	// Insertion-sort each symbol within its bucket, by descending count.
	for n, c := range s.count[:s.symbolLen] {
		r := (highBit32(c+1) + 1) & 31
		pos := rank[r].current
		rank[r].current++
		prev := nodes[(pos-1)&huffNodesMask]
		for pos > rank[r].base && c > prev.count {
			nodes[pos&huffNodesMask] = prev
			pos--
			prev = nodes[(pos-1)&huffNodesMask]
		}
		nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
	}
}

// setMaxHeight caps all code lengths at s.actualTableLog and repays the
// resulting cost by lengthening the cheapest codes, following the reference
// HUF_setMaxHeight algorithm. Returns the resulting largest bit length.
func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
	maxNbBits := s.actualTableLog
	huffNode := s.nodes[1 : huffNodesLen+1]
	//huffNode = huffNode[: huffNodesLen]

	largestBits := huffNode[lastNonNull].nbBits

	// early exit : no elt > maxNbBits
	if largestBits <= maxNbBits {
		return largestBits
	}
	totalCost := int(0)
	baseCost := int(1) << (largestBits - maxNbBits)
	n := uint32(lastNonNull)

	// Truncate all over-long codes to maxNbBits, accumulating the cost.
	for huffNode[n].nbBits > maxNbBits {
		totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
		huffNode[n].nbBits = maxNbBits
		n--
	}
	// n stops at huffNode[n].nbBits <= maxNbBits

	for huffNode[n].nbBits == maxNbBits {
		n--
	}
	// n end at index of smallest symbol using < maxNbBits

	// renorm totalCost
	totalCost >>= largestBits - maxNbBits /* note : totalCost is necessarily a multiple of baseCost */

	// repay normalized cost
	{
		const noSymbol = 0xF0F0F0F0
		var rankLast [tableLogMax + 2]uint32

		for i := range rankLast[:] {
			rankLast[i] = noSymbol
		}

		// Get pos of last (smallest) symbol per rank
		{
			currentNbBits := maxNbBits
			for pos := int(n); pos >= 0; pos-- {
				if huffNode[pos].nbBits >= currentNbBits {
					continue
				}
				currentNbBits = huffNode[pos].nbBits // < maxNbBits
				rankLast[maxNbBits-currentNbBits] = uint32(pos)
			}
		}

		for totalCost > 0 {
			nBitsToDecrease := uint8(highBit32(uint32(totalCost))) + 1

			// Search for a rank whose demotion is cheaper than demoting
			// two symbols of the next-lower rank.
			for ; nBitsToDecrease > 1; nBitsToDecrease-- {
				highPos := rankLast[nBitsToDecrease]
				lowPos := rankLast[nBitsToDecrease-1]
				if highPos == noSymbol {
					continue
				}
				if lowPos == noSymbol {
					break
				}
				highTotal := huffNode[highPos].count
				lowTotal := 2 * huffNode[lowPos].count
				if highTotal <= lowTotal {
					break
				}
			}
			// only triggered when no more rank 1 symbol left => find closest one (note : there is necessarily at least one !)
			// HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
			// FIXME: try to remove
			for (nBitsToDecrease <= tableLogMax) && (rankLast[nBitsToDecrease] == noSymbol) {
				nBitsToDecrease++
			}
			totalCost -= 1 << (nBitsToDecrease - 1)
			if rankLast[nBitsToDecrease-1] == noSymbol {
				// this rank is no longer empty
				rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
			}
			huffNode[rankLast[nBitsToDecrease]].nbBits++
			if rankLast[nBitsToDecrease] == 0 {
				/* special case, reached largest symbol */
				rankLast[nBitsToDecrease] = noSymbol
			} else {
				rankLast[nBitsToDecrease]--
				if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
					rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */
				}
			}
		}

		for totalCost < 0 { /* Sometimes, cost correction overshoot */
			if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */
				for huffNode[n].nbBits == maxNbBits {
					n--
				}
				huffNode[n+1].nbBits--
				rankLast[1] = n + 1
				totalCost++
				continue
			}
			huffNode[rankLast[1]+1].nbBits--
			rankLast[1]++
			totalCost++
		}
	}
	return maxNbBits
}

// nodeElt is a Huffman tree node: leaf count, parent index, the symbol
// (leaves only) and the assigned code length.
type nodeElt struct {
	count  uint32
	parent uint16
	symbol byte
	nbBits uint8
}