// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"errors"
	"fmt"
	"math"
)

const (
	// For encoding we only support table logs up to maxEncTableLog;
	// optimalTableLog clamps its result to [minEncTablelog, maxEncTableLog].
	maxEncTableLog = 8
	// NOTE(review): the next two constants are derived from maxTableLog
	// (declared elsewhere), not from maxEncTableLog — confirm this is intended.
	maxEncTablesize   = 1 << maxTableLog
	maxEncTableMask   = (1 << maxTableLog) - 1
	minEncTablelog    = 5
	maxEncSymbolValue = maxMatchLengthSymbol
)

// fseEncoder holds the histogram, the normalized counts and the compression
// tables that are built from them for one FSE encoding table.
type fseEncoder struct {
	symbolLen      uint16 // Length of active part of the symbol table.
	actualTableLog uint8  // Selected tablelog.
	ct             cTable // Compression tables.
	maxCount       int    // count of the most probable symbol
	zeroBits       bool   // Set by buildCTable when a symbol's normalized count exceeds half the table size.
	clearCount     bool   // When set and maxCount == 0, prepare zeroes the histogram.
	useRLE         bool   // This encoder is for RLE
	preDefined     bool   // This encoder is predefined.
	reUsed         bool   // Set to know when the encoder has been reused.
	rleVal         uint8  // RLE Symbol
	maxBits        uint8  // Maximum output bits after transform.

	// TODO: Technically zstd should be fine with 64 bytes.
	count [256]uint32 // Histogram: occurrences per symbol.
	norm  [256]int16  // Normalized counts; -1 marks a low probability symbol.
}

// cTable contains tables used for compression.
type cTable struct {
	tableSymbol []byte            // Symbol placed at each table position (filled by buildCTable).
	stateTable  []uint16          // Next state values, sorted by symbol order.
	symbolTT    []symbolTransform // Per-symbol state transforms, indexed by symbol value.
}

// symbolTransform contains the state transform for a symbol.
type symbolTransform struct {
	deltaNbBits    uint32 // Upper 16 bits carry a bit count, lower bits a state offset (see buildCTable).
	deltaFindState int16  // Offset added to the shifted state to index stateTable.
	outBits        uint8  // Extra bits to output for the symbol, assigned by setBits.
}

// String prints values as a human readable string.
func (s symbolTransform) String() string {
	return fmt.Sprintf("{deltabits: %08x, findstate:%d outbits:%d}", s.deltaNbBits, s.deltaFindState, s.outBits)
}

// Histogram allows to populate the histogram and skip that step in the compression,
// It otherwise allows to inspect the histogram when compression is done.
62// To indicate that you have populated the histogram call HistogramFinished 63// with the value of the highest populated symbol, as well as the number of entries 64// in the most populated entry. These are accepted at face value. 65// The returned slice will always be length 256. 66func (s *fseEncoder) Histogram() []uint32 { 67 return s.count[:] 68} 69 70// HistogramFinished can be called to indicate that the histogram has been populated. 71// maxSymbol is the index of the highest set symbol of the next data segment. 72// maxCount is the number of entries in the most populated entry. 73// These are accepted at face value. 74func (s *fseEncoder) HistogramFinished(maxSymbol uint8, maxCount int) { 75 s.maxCount = maxCount 76 s.symbolLen = uint16(maxSymbol) + 1 77 s.clearCount = maxCount != 0 78} 79 80// prepare will prepare and allocate scratch tables used for both compression and decompression. 81func (s *fseEncoder) prepare() (*fseEncoder, error) { 82 if s == nil { 83 s = &fseEncoder{} 84 } 85 s.useRLE = false 86 if s.clearCount && s.maxCount == 0 { 87 for i := range s.count { 88 s.count[i] = 0 89 } 90 s.clearCount = false 91 } 92 return s, nil 93} 94 95// allocCtable will allocate tables needed for compression. 96// If existing tables a re big enough, they are simply re-used. 97func (s *fseEncoder) allocCtable() { 98 tableSize := 1 << s.actualTableLog 99 // get tableSymbol that is big enough. 100 if cap(s.ct.tableSymbol) < int(tableSize) { 101 s.ct.tableSymbol = make([]byte, tableSize) 102 } 103 s.ct.tableSymbol = s.ct.tableSymbol[:tableSize] 104 105 ctSize := tableSize 106 if cap(s.ct.stateTable) < ctSize { 107 s.ct.stateTable = make([]uint16, ctSize) 108 } 109 s.ct.stateTable = s.ct.stateTable[:ctSize] 110 111 if cap(s.ct.symbolTT) < 256 { 112 s.ct.symbolTT = make([]symbolTransform, 256) 113 } 114 s.ct.symbolTT = s.ct.symbolTT[:256] 115} 116 117// buildCTable will populate the compression table so it is ready to be used. 
118func (s *fseEncoder) buildCTable() error { 119 tableSize := uint32(1 << s.actualTableLog) 120 highThreshold := tableSize - 1 121 var cumul [256]int16 122 123 s.allocCtable() 124 tableSymbol := s.ct.tableSymbol[:tableSize] 125 // symbol start positions 126 { 127 cumul[0] = 0 128 for ui, v := range s.norm[:s.symbolLen-1] { 129 u := byte(ui) // one less than reference 130 if v == -1 { 131 // Low proba symbol 132 cumul[u+1] = cumul[u] + 1 133 tableSymbol[highThreshold] = u 134 highThreshold-- 135 } else { 136 cumul[u+1] = cumul[u] + v 137 } 138 } 139 // Encode last symbol separately to avoid overflowing u 140 u := int(s.symbolLen - 1) 141 v := s.norm[s.symbolLen-1] 142 if v == -1 { 143 // Low proba symbol 144 cumul[u+1] = cumul[u] + 1 145 tableSymbol[highThreshold] = byte(u) 146 highThreshold-- 147 } else { 148 cumul[u+1] = cumul[u] + v 149 } 150 if uint32(cumul[s.symbolLen]) != tableSize { 151 return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize) 152 } 153 cumul[s.symbolLen] = int16(tableSize) + 1 154 } 155 // Spread symbols 156 s.zeroBits = false 157 { 158 step := tableStep(tableSize) 159 tableMask := tableSize - 1 160 var position uint32 161 // if any symbol > largeLimit, we may have 0 bits output. 
162 largeLimit := int16(1 << (s.actualTableLog - 1)) 163 for ui, v := range s.norm[:s.symbolLen] { 164 symbol := byte(ui) 165 if v > largeLimit { 166 s.zeroBits = true 167 } 168 for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ { 169 tableSymbol[position] = symbol 170 position = (position + step) & tableMask 171 for position > highThreshold { 172 position = (position + step) & tableMask 173 } /* Low proba area */ 174 } 175 } 176 177 // Check if we have gone through all positions 178 if position != 0 { 179 return errors.New("position!=0") 180 } 181 } 182 183 // Build table 184 table := s.ct.stateTable 185 { 186 tsi := int(tableSize) 187 for u, v := range tableSymbol { 188 // TableU16 : sorted by symbol order; gives next state value 189 table[cumul[v]] = uint16(tsi + u) 190 cumul[v]++ 191 } 192 } 193 194 // Build Symbol Transformation Table 195 { 196 total := int16(0) 197 symbolTT := s.ct.symbolTT[:s.symbolLen] 198 tableLog := s.actualTableLog 199 tl := (uint32(tableLog) << 16) - (1 << tableLog) 200 for i, v := range s.norm[:s.symbolLen] { 201 switch v { 202 case 0: 203 case -1, 1: 204 symbolTT[i].deltaNbBits = tl 205 symbolTT[i].deltaFindState = int16(total - 1) 206 total++ 207 default: 208 maxBitsOut := uint32(tableLog) - highBit(uint32(v-1)) 209 minStatePlus := uint32(v) << maxBitsOut 210 symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus 211 symbolTT[i].deltaFindState = int16(total - v) 212 total += v 213 } 214 } 215 if total != int16(tableSize) { 216 return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize) 217 } 218 } 219 return nil 220} 221 222var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000} 223 224func (s *fseEncoder) setRLE(val byte) { 225 s.allocCtable() 226 s.actualTableLog = 0 227 s.ct.stateTable = s.ct.stateTable[:1] 228 s.ct.symbolTT[val] = symbolTransform{ 229 deltaFindState: 0, 230 deltaNbBits: 0, 231 } 232 if debug { 233 println("setRLE: val", val, "symbolTT", 
s.ct.symbolTT[val]) 234 } 235 s.rleVal = val 236 s.useRLE = true 237} 238 239// setBits will set output bits for the transform. 240// if nil is provided, the number of bits is equal to the index. 241func (s *fseEncoder) setBits(transform []byte) { 242 if s.reUsed || s.preDefined { 243 return 244 } 245 if s.useRLE { 246 if transform == nil { 247 s.ct.symbolTT[s.rleVal].outBits = s.rleVal 248 s.maxBits = s.rleVal 249 return 250 } 251 s.maxBits = transform[s.rleVal] 252 s.ct.symbolTT[s.rleVal].outBits = s.maxBits 253 return 254 } 255 if transform == nil { 256 for i := range s.ct.symbolTT[:s.symbolLen] { 257 s.ct.symbolTT[i].outBits = uint8(i) 258 } 259 s.maxBits = uint8(s.symbolLen - 1) 260 return 261 } 262 s.maxBits = 0 263 for i, v := range transform[:s.symbolLen] { 264 s.ct.symbolTT[i].outBits = v 265 if v > s.maxBits { 266 // We could assume bits always going up, but we play safe. 267 s.maxBits = v 268 } 269 } 270} 271 272// normalizeCount will normalize the count of the symbols so 273// the total is equal to the table size. 274// If successful, compression tables will also be made ready. 
275func (s *fseEncoder) normalizeCount(length int) error { 276 if s.reUsed { 277 return nil 278 } 279 s.optimalTableLog(length) 280 var ( 281 tableLog = s.actualTableLog 282 scale = 62 - uint64(tableLog) 283 step = (1 << 62) / uint64(length) 284 vStep = uint64(1) << (scale - 20) 285 stillToDistribute = int16(1 << tableLog) 286 largest int 287 largestP int16 288 lowThreshold = (uint32)(length >> tableLog) 289 ) 290 if s.maxCount == length { 291 s.useRLE = true 292 return nil 293 } 294 s.useRLE = false 295 for i, cnt := range s.count[:s.symbolLen] { 296 // already handled 297 // if (count[s] == s.length) return 0; /* rle special case */ 298 299 if cnt == 0 { 300 s.norm[i] = 0 301 continue 302 } 303 if cnt <= lowThreshold { 304 s.norm[i] = -1 305 stillToDistribute-- 306 } else { 307 proba := (int16)((uint64(cnt) * step) >> scale) 308 if proba < 8 { 309 restToBeat := vStep * uint64(rtbTable[proba]) 310 v := uint64(cnt)*step - (uint64(proba) << scale) 311 if v > restToBeat { 312 proba++ 313 } 314 } 315 if proba > largestP { 316 largestP = proba 317 largest = i 318 } 319 s.norm[i] = proba 320 stillToDistribute -= proba 321 } 322 } 323 324 if -stillToDistribute >= (s.norm[largest] >> 1) { 325 // corner case, need another normalization method 326 err := s.normalizeCount2(length) 327 if err != nil { 328 return err 329 } 330 if debugAsserts { 331 err = s.validateNorm() 332 if err != nil { 333 return err 334 } 335 } 336 return s.buildCTable() 337 } 338 s.norm[largest] += stillToDistribute 339 if debugAsserts { 340 err := s.validateNorm() 341 if err != nil { 342 return err 343 } 344 } 345 return s.buildCTable() 346} 347 348// Secondary normalization method. 349// To be used when primary method fails. 
func (s *fseEncoder) normalizeCount2(length int) error {
	// Marker stored in s.norm for symbols whose weight is decided in the
	// final proportional pass at the end of this function.
	const notYetAssigned = -2
	var (
		distributed  uint32                                  // table slots handed out so far
		total        = uint32(length)                        // input count not yet covered by assigned symbols
		tableLog     = s.actualTableLog
		lowThreshold = uint32(total >> tableLog)             // at or below: symbol gets norm -1
		lowOne       = uint32((total * 3) >> (tableLog + 1)) // at or below: symbol gets norm 1
	)
	// First pass: settle absent and rare symbols, mark the rest as pending.
	for i, cnt := range s.count[:s.symbolLen] {
		if cnt == 0 {
			s.norm[i] = 0
			continue
		}
		if cnt <= lowThreshold {
			s.norm[i] = -1
			distributed++
			total -= cnt
			continue
		}
		if cnt <= lowOne {
			s.norm[i] = 1
			distributed++
			total -= cnt
			continue
		}
		s.norm[i] = notYetAssigned
	}
	toDistribute := (1 << tableLog) - distributed

	if (total / toDistribute) > lowOne {
		// risk of rounding to zero
		// Raise the threshold and give one slot to every pending symbol below it.
		lowOne = uint32((total * 3) / (toDistribute * 2))
		for i, cnt := range s.count[:s.symbolLen] {
			if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) {
				s.norm[i] = 1
				distributed++
				total -= cnt
				continue
			}
		}
		toDistribute = (1 << tableLog) - distributed
	}
	if distributed == uint32(s.symbolLen)+1 {
		// all values are pretty poor;
		// probably incompressible data (should have already been detected);
		// find max, then give all remaining points to max
		var maxV int
		var maxC uint32
		for i, cnt := range s.count[:s.symbolLen] {
			if cnt > maxC {
				maxV = i
				maxC = cnt
			}
		}
		s.norm[maxV] += int16(toDistribute)
		return nil
	}

	if total == 0 {
		// all of the symbols were low enough for the lowOne or lowThreshold
		// Hand out the remaining slots round-robin to symbols with a positive norm.
		for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) {
			if s.norm[i] > 0 {
				toDistribute--
				s.norm[i]++
			}
		}
		return nil
	}

	// Final pass: split the remaining slots between the pending symbols in
	// proportion to their counts, using fixed point with vStepLog fractional bits.
	var (
		vStepLog = 62 - uint64(tableLog)
		mid      = uint64((1 << (vStepLog - 1)) - 1)
		rStep    = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total) // scale on remaining
		tmpTotal = mid
	)
	for i, cnt := range s.count[:s.symbolLen] {
		if s.norm[i] == notYetAssigned {
			var (
				end    = tmpTotal + uint64(cnt)*rStep
				sStart = uint32(tmpTotal >> vStepLog)
				sEnd   = uint32(end >> vStepLog)
				weight = sEnd - sStart
			)
			// Every pending symbol must end up with at least one slot.
			if weight < 1 {
				return errors.New("weight < 1")
			}
			s.norm[i] = int16(weight)
			tmpTotal = end
		}
	}
	return nil
}

// optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog
// for an input of the given length and the current s.symbolLen.
// The result is always within [minEncTablelog, maxEncTableLog].
func (s *fseEncoder) optimalTableLog(length int) {
	tableLog := uint8(maxEncTableLog)
	// Lower bound: the smaller of the values derived from the input length
	// and from the number of symbols.
	minBitsSrc := highBit(uint32(length)) + 1
	minBitsSymbols := highBit(uint32(s.symbolLen-1)) + 2
	minBits := uint8(minBitsSymbols)
	if minBitsSrc < minBitsSymbols {
		minBits = uint8(minBitsSrc)
	}

	// Upper bound derived from the input length.
	// NOTE(review): this is uint8 arithmetic, so the subtraction wraps around
	// if highBit returns less than 2 — presumably harmless because the
	// wrapped value is not below tableLog; confirm against highBit.
	maxBitsSrc := uint8(highBit(uint32(length-1))) - 2
	if maxBitsSrc < tableLog {
		// Accuracy can be reduced
		tableLog = maxBitsSrc
	}
	if minBits > tableLog {
		tableLog = minBits
	}
	// Need a minimum to safely represent all symbol values
	if tableLog < minEncTablelog {
		tableLog = minEncTablelog
	}
	if tableLog > maxEncTableLog {
		tableLog = maxEncTableLog
	}
	s.actualTableLog = tableLog
}

// validateNorm validates the normalized histogram table.
func (s *fseEncoder) validateNorm() (err error) {
	// Sum of absolute values: a norm of -1 occupies one table slot.
	var total int
	for _, v := range s.norm[:s.symbolLen] {
		if v >= 0 {
			total += int(v)
		} else {
			total -= int(v)
		}
	}
	// On failure, dump the selected table log and every count/norm pair to stdout.
	defer func() {
		if err == nil {
			return
		}
		fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
		for i, v := range s.norm[:s.symbolLen] {
			fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
		}
	}()
	// The normalized counts must fill the table exactly.
	if total != (1 << s.actualTableLog) {
		return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
	}
	// No symbol beyond symbolLen may have a non-zero count.
	for i, v := range s.count[s.symbolLen:] {
		if v != 0 {
			return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)
		}
	}
	return nil
}

// writeCount will write the normalized histogram count to header.
// This is read back by readNCount.
// The header is appended to out and the extended slice is returned.
// RLE encoders append only the RLE symbol; predefined and reused encoders
// append nothing.
func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
	if s.useRLE {
		return append(out, s.rleVal), nil
	}
	if s.preDefined || s.reUsed {
		// Never write predefined.
		return out, nil
	}

	var (
		tableLog  = s.actualTableLog
		tableSize = 1 << tableLog
		previous0 bool   // the previously written symbol had a normalized count of 0
		charnum   uint16 // next symbol to write

		// maximum header size plus 2 extra bytes for final output if bitCount == 0.
		maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3 + 2

		// Write Table Size
		bitStream = uint32(tableLog - minEncTablelog)
		bitCount  = uint(4)
		remaining = int16(tableSize + 1) /* +1 for extra accuracy */
		threshold = int16(tableSize)
		nbBits    = uint(tableLog + 1)
		outP      = len(out) // write position in out
	)
	// Make sure there is room for the largest possible header, then extend
	// out so it can be written by index.
	if cap(out) < outP+maxHeaderSize {
		out = append(out, make([]byte, maxHeaderSize*3)...)
		out = out[:len(out)-maxHeaderSize*3]
	}
	out = out[:outP+maxHeaderSize]

	// stops at 1
	for remaining > 1 {
		if previous0 {
			// Encode the length of the run of zero counts that follows.
			start := charnum
			for s.norm[charnum] == 0 {
				charnum++
			}
			// Each full group of 24 zeros is written as 16 set bits.
			for charnum >= start+24 {
				start += 24
				bitStream += uint32(0xFFFF) << bitCount
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
			}
			// Each group of 3 zeros is written as the 2-bit value 3.
			for charnum >= start+3 {
				start += 3
				bitStream += 3 << bitCount
				bitCount += 2
			}
			// Remaining 0-2 zeros as a final 2-bit value.
			bitStream += uint32(charnum-start) << bitCount
			bitCount += 2
			if bitCount > 16 {
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
				bitCount -= 16
			}
		}

		count := s.norm[charnum]
		charnum++
		max := (2*threshold - 1) - remaining
		// A negative count (-1) still consumes one slot.
		if count < 0 {
			remaining += count
		} else {
			remaining -= count
		}
		count++ // +1 for extra accuracy
		if count >= threshold {
			count += max // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
		}
		bitStream += uint32(count) << bitCount
		bitCount += nbBits
		// Values below max need one bit less.
		if count < max {
			bitCount--
		}

		previous0 = count == 1
		if remaining < 1 {
			return nil, errors.New("internal error: remaining < 1")
		}
		// Shrink the bit width as the remaining total drops.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}

		// Flush 16 bits once more than 16 are pending.
		if bitCount > 16 {
			out[outP] = byte(bitStream)
			out[outP+1] = byte(bitStream >> 8)
			outP += 2
			bitStream >>= 16
			bitCount -= 16
		}
	}

	if outP+2 > len(out) {
		return nil, fmt.Errorf("internal error: %d > %d, maxheader: %d, sl: %d, tl: %d, normcount: %v", outP+2, len(out), maxHeaderSize, s.symbolLen, int(tableLog), s.norm[:s.symbolLen])
	}
	// Write the pending bits; only the bytes actually used are kept.
	out[outP] = byte(bitStream)
	out[outP+1] = byte(bitStream >> 8)
	outP += int((bitCount + 7) / 8)

	if charnum > s.symbolLen {
		return nil, errors.New("internal error: charnum > s.symbolLen")
	}
	return out[:outP], nil
}

// bitCost returns the approximate cost of symbolValue as a fractional value,
// using fixed-point format (accuracyLog fractional bits).
// note 1 : assume symbolValue is valid (<= maxSymbolValue)
// note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits *
func (s *fseEncoder) bitCost(symbolValue uint8, accuracyLog uint32) uint32 {
	minNbBits := s.ct.symbolTT[symbolValue].deltaNbBits >> 16
	threshold := (minNbBits + 1) << 16
	if debugAsserts {
		if !(s.actualTableLog < 16) {
			panic("!s.actualTableLog < 16")
		}
		// ensure enough room for renormalization double shift
		if !(uint8(accuracyLog) < 31-s.actualTableLog) {
			panic("!uint8(accuracyLog) < 31-s.actualTableLog")
		}
	}
	tableSize := uint32(1) << s.actualTableLog
	deltaFromThreshold := threshold - (s.ct.symbolTT[symbolValue].deltaNbBits + tableSize)
	// linear interpolation (very approximate)
	normalizedDeltaFromThreshold := (deltaFromThreshold << accuracyLog) >> s.actualTableLog
	bitMultiplier := uint32(1) << accuracyLog
	if debugAsserts {
		if s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold {
			panic("s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold")
		}
		if normalizedDeltaFromThreshold > bitMultiplier {
			panic("normalizedDeltaFromThreshold > bitMultiplier")
		}
	}
	return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold
}

// approxSize returns the approximate cost in bits of encoding the
// distribution in hist using the current table.
// Histogram should only be up to the last non-zero symbol.
// Returns math.MaxUint32 if the table cannot represent all the symbols in
// hist, if the encoder is RLE, or if a symbol is more expensive than
// actualTableLog+1 bits.
func (s *fseEncoder) approxSize(hist []uint32) uint32 {
	if int(s.symbolLen) < len(hist) {
		// More symbols than we have.
		return math.MaxUint32
	}
	if s.useRLE {
		// We will never reuse RLE encoders.
		return math.MaxUint32
	}
	const kAccuracyLog = 8
	badCost := (uint32(s.actualTableLog) + 1) << kAccuracyLog
	var cost uint32
	for i, v := range hist {
		if v == 0 {
			continue
		}
		// Symbol is present in hist but has no slot in this table.
		if s.norm[i] == 0 {
			return math.MaxUint32
		}
		bitCost := s.bitCost(uint8(i), kAccuracyLog)
		if bitCost > badCost {
			return math.MaxUint32
		}
		cost += v * bitCost
	}
	// Drop the fractional bits.
	return cost >> kAccuracyLog
}

// maxHeaderSize returns the maximum header size in bits.
// This is not exact size, but we want a penalty for new tables anyway.
// Predefined encoders report 0 and RLE encoders 8.
func (s *fseEncoder) maxHeaderSize() uint32 {
	if s.preDefined {
		return 0
	}
	if s.useRLE {
		return 8
	}
	return (((uint32(s.symbolLen) * uint32(s.actualTableLog)) >> 3) + 3) * 8
}

// cState contains the compression state of a stream.
type cState struct {
	bw         *bitWriter // destination for the encoded bits
	stateTable []uint16   // next-state table, taken from the cTable passed to init
	state      uint16     // current state
}

// init will initialize the compression state to the first symbol of the stream.
698func (c *cState) init(bw *bitWriter, ct *cTable, first symbolTransform) { 699 c.bw = bw 700 c.stateTable = ct.stateTable 701 if len(c.stateTable) == 1 { 702 // RLE 703 c.stateTable[0] = uint16(0) 704 c.state = 0 705 return 706 } 707 nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16 708 im := int32((nbBitsOut << 16) - first.deltaNbBits) 709 lu := (im >> nbBitsOut) + int32(first.deltaFindState) 710 c.state = c.stateTable[lu] 711 return 712} 713 714// encode the output symbol provided and write it to the bitstream. 715func (c *cState) encode(symbolTT symbolTransform) { 716 nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16 717 dstState := int32(c.state>>(nbBitsOut&15)) + int32(symbolTT.deltaFindState) 718 c.bw.addBits16NC(c.state, uint8(nbBitsOut)) 719 c.state = c.stateTable[dstState] 720} 721 722// flush will write the tablelog to the output and flush the remaining full bytes. 723func (c *cState) flush(tableLog uint8) { 724 c.bw.flush32() 725 c.bw.addBits16NC(c.state, tableLog) 726} 727