1// Copyright 2019+ Klaus Post. All rights reserved. 2// License information can be found in the LICENSE file. 3// Based on work by Yann Collet, released under BSD License. 4 5package zstd 6 7import ( 8 "errors" 9 "fmt" 10) 11 12const ( 13 tablelogAbsoluteMax = 9 14) 15 16const ( 17 /*!MEMORY_USAGE : 18 * Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.) 19 * Increasing memory usage improves compression ratio 20 * Reduced memory usage can improve speed, due to cache effect 21 * Recommended max value is 14, for 16KB, which nicely fits into Intel x86 L1 cache */ 22 maxMemoryUsage = tablelogAbsoluteMax + 2 23 24 maxTableLog = maxMemoryUsage - 2 25 maxTablesize = 1 << maxTableLog 26 maxTableMask = (1 << maxTableLog) - 1 27 minTablelog = 5 28 maxSymbolValue = 255 29) 30 31// fseDecoder provides temporary storage for compression and decompression. 32type fseDecoder struct { 33 dt [maxTablesize]decSymbol // Decompression table. 34 symbolLen uint16 // Length of active part of the symbol table. 35 actualTableLog uint8 // Selected tablelog. 36 maxBits uint8 // Maximum number of additional bits 37 38 // used for table creation to avoid allocations. 39 stateTable [256]uint16 40 norm [maxSymbolValue + 1]int16 41 preDefined bool 42} 43 44// tableStep returns the next table index. 45func tableStep(tableSize uint32) uint32 { 46 return (tableSize >> 1) + (tableSize >> 3) + 3 47} 48 49// readNCount will read the symbol distribution so decoding tables can be constructed. 50func (s *fseDecoder) readNCount(b *byteReader, maxSymbol uint16) error { 51 var ( 52 charnum uint16 53 previous0 bool 54 ) 55 if b.remain() < 4 { 56 return errors.New("input too small") 57 } 58 bitStream := b.Uint32NC() 59 nbBits := uint((bitStream & 0xF) + minTablelog) // extract tableLog 60 if nbBits > tablelogAbsoluteMax { 61 println("Invalid tablelog:", nbBits) 62 return errors.New("tableLog too large") 63 } 64 bitStream >>= 4 65 bitCount := uint(4) 66 67 s.actualTableLog = uint8(nbBits) 68 remaining := int32((1 << nbBits) + 1) 69 threshold := int32(1 << nbBits) 70 gotTotal := int32(0) 71 nbBits++ 72 73 for remaining > 1 && charnum <= maxSymbol { 74 if previous0 { 75 //println("prev0") 76 n0 := charnum 77 for (bitStream & 0xFFFF) == 0xFFFF { 78 //println("24 x 0") 79 n0 += 24 80 if r := b.remain(); r > 5 { 81 b.advance(2) 82 // The check above should make sure we can read 32 bits 83 bitStream = b.Uint32NC() >> bitCount 84 } else { 85 // end of bit stream 86 bitStream >>= 16 87 bitCount += 16 88 } 89 } 90 //printf("bitstream: %d, 0b%b", bitStream&3, bitStream) 91 for (bitStream & 3) == 3 { 92 n0 += 3 93 bitStream >>= 2 94 bitCount += 2 95 } 96 n0 += uint16(bitStream & 3) 97 bitCount += 2 98 99 if n0 > maxSymbolValue { 100 return errors.New("maxSymbolValue too small") 101 } 102 //println("inserting ", n0-charnum, "zeroes from idx", charnum, "ending before", n0) 103 for charnum < n0 { 104 s.norm[uint8(charnum)] = 0 105 charnum++ 106 } 107 108 if r := b.remain(); r >= 7 || r-int(bitCount>>3) >= 4 { 109 b.advance(bitCount >> 3) 110 bitCount &= 7 111 // The check above should make sure we can read 32 bits 112 bitStream = b.Uint32NC() >> bitCount 113 } else { 114 bitStream >>= 2 115 } 116 } 117 118 max := (2*threshold - 1) - remaining 119 var count int32 120 121 if int32(bitStream)&(threshold-1) < max { 122 count = int32(bitStream) & (threshold - 1) 123 if debugAsserts && nbBits < 1 { 124 panic("nbBits underflow") 125 } 126 bitCount += nbBits - 1 127 } else { 128 count = int32(bitStream) & (2*threshold - 1) 129 if count >= threshold { 130 count -= max 131 } 132 bitCount += nbBits 133 } 134 135 // extra accuracy 136 count-- 137 if count < 0 { 138 // -1 means +1 139 remaining += count 140 gotTotal -= count 141 } else { 142 remaining -= count 143 gotTotal += count 144 } 145 s.norm[charnum&0xff] = int16(count) 146 charnum++ 147 previous0 = count == 0 148 for remaining < threshold { 149 nbBits-- 150 threshold >>= 1 151 } 152 153 if r := b.remain(); r >= 7 || r-int(bitCount>>3) >= 4 { 154 b.advance(bitCount >> 3) 155 bitCount &= 7 156 // The check above should make sure we can read 32 bits 157 bitStream = b.Uint32NC() >> (bitCount & 31) 158 } else { 159 bitCount -= (uint)(8 * (len(b.b) - 4 - b.off)) 160 b.off = len(b.b) - 4 161 bitStream = b.Uint32() >> (bitCount & 31) 162 } 163 } 164 s.symbolLen = charnum 165 if s.symbolLen <= 1 { 166 return fmt.Errorf("symbolLen (%d) too small", s.symbolLen) 167 } 168 if s.symbolLen > maxSymbolValue+1 { 169 return fmt.Errorf("symbolLen (%d) too big", s.symbolLen) 170 } 171 if remaining != 1 { 172 return fmt.Errorf("corruption detected (remaining %d != 1)", remaining) 173 } 174 if bitCount > 32 { 175 return fmt.Errorf("corruption detected (bitCount %d > 32)", bitCount) 176 } 177 if gotTotal != 1<<s.actualTableLog { 178 return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<<s.actualTableLog) 179 } 180 b.advance((bitCount + 7) >> 3) 181 // println(s.norm[:s.symbolLen], s.symbolLen) 182 return s.buildDtable() 183} 184 185// decSymbol contains information about a state entry, 186// Including the state offset base, the output symbol and 187// the number of bits to read for the low part of the destination state. 188// Using a composite uint64 is faster than a struct with separate members. 189type decSymbol uint64 190 191func newDecSymbol(nbits, addBits uint8, newState uint16, baseline uint32) decSymbol { 192 return decSymbol(nbits) | (decSymbol(addBits) << 8) | (decSymbol(newState) << 16) | (decSymbol(baseline) << 32) 193} 194 195func (d decSymbol) nbBits() uint8 { 196 return uint8(d) 197} 198 199func (d decSymbol) addBits() uint8 { 200 return uint8(d >> 8) 201} 202 203func (d decSymbol) newState() uint16 { 204 return uint16(d >> 16) 205} 206 207func (d decSymbol) baseline() uint32 { 208 return uint32(d >> 32) 209} 210 211func (d decSymbol) baselineInt() int { 212 return int(d >> 32) 213} 214 215func (d *decSymbol) set(nbits, addBits uint8, newState uint16, baseline uint32) { 216 *d = decSymbol(nbits) | (decSymbol(addBits) << 8) | (decSymbol(newState) << 16) | (decSymbol(baseline) << 32) 217} 218 219func (d *decSymbol) setNBits(nBits uint8) { 220 const mask = 0xffffffffffffff00 221 *d = (*d & mask) | decSymbol(nBits) 222} 223 224func (d *decSymbol) setAddBits(addBits uint8) { 225 const mask = 0xffffffffffff00ff 226 *d = (*d & mask) | (decSymbol(addBits) << 8) 227} 228 229func (d *decSymbol) setNewState(state uint16) { 230 const mask = 0xffffffff0000ffff 231 *d = (*d & mask) | decSymbol(state)<<16 232} 233 234func (d *decSymbol) setBaseline(baseline uint32) { 235 const mask = 0xffffffff 236 *d = (*d & mask) | decSymbol(baseline)<<32 237} 238 239func (d *decSymbol) setExt(addBits uint8, baseline uint32) { 240 const mask = 0xffff00ff 241 *d = (*d & mask) | (decSymbol(addBits) << 8) | (decSymbol(baseline) << 32) 242} 243 244// decSymbolValue returns the transformed decSymbol for the given symbol. 245func decSymbolValue(symb uint8, t []baseOffset) (decSymbol, error) { 246 if int(symb) >= len(t) { 247 return 0, fmt.Errorf("rle symbol %d >= max %d", symb, len(t)) 248 } 249 lu := t[symb] 250 return newDecSymbol(0, lu.addBits, 0, lu.baseLine), nil 251} 252 253// setRLE will set the decoder til RLE mode. 254func (s *fseDecoder) setRLE(symbol decSymbol) { 255 s.actualTableLog = 0 256 s.maxBits = symbol.addBits() 257 s.dt[0] = symbol 258} 259 260// buildDtable will build the decoding table. 261func (s *fseDecoder) buildDtable() error { 262 tableSize := uint32(1 << s.actualTableLog) 263 highThreshold := tableSize - 1 264 symbolNext := s.stateTable[:256] 265 266 // Init, lay down lowprob symbols 267 { 268 for i, v := range s.norm[:s.symbolLen] { 269 if v == -1 { 270 s.dt[highThreshold].setAddBits(uint8(i)) 271 highThreshold-- 272 symbolNext[i] = 1 273 } else { 274 symbolNext[i] = uint16(v) 275 } 276 } 277 } 278 // Spread symbols 279 { 280 tableMask := tableSize - 1 281 step := tableStep(tableSize) 282 position := uint32(0) 283 for ss, v := range s.norm[:s.symbolLen] { 284 for i := 0; i < int(v); i++ { 285 s.dt[position].setAddBits(uint8(ss)) 286 position = (position + step) & tableMask 287 for position > highThreshold { 288 // lowprob area 289 position = (position + step) & tableMask 290 } 291 } 292 } 293 if position != 0 { 294 // position must reach all cells once, otherwise normalizedCounter is incorrect 295 return errors.New("corrupted input (position != 0)") 296 } 297 } 298 299 // Build Decoding table 300 { 301 tableSize := uint16(1 << s.actualTableLog) 302 for u, v := range s.dt[:tableSize] { 303 symbol := v.addBits() 304 nextState := symbolNext[symbol] 305 symbolNext[symbol] = nextState + 1 306 nBits := s.actualTableLog - byte(highBits(uint32(nextState))) 307 s.dt[u&maxTableMask].setNBits(nBits) 308 newState := (nextState << nBits) - tableSize 309 if newState > tableSize { 310 return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) 311 } 312 if newState == uint16(u) && nBits == 0 { 313 // Seems weird that this is possible with nbits > 0. 314 return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) 315 } 316 s.dt[u&maxTableMask].setNewState(newState) 317 } 318 } 319 return nil 320} 321 322// transform will transform the decoder table into a table usable for 323// decoding without having to apply the transformation while decoding. 324// The state will contain the base value and the number of bits to read. 325func (s *fseDecoder) transform(t []baseOffset) error { 326 tableSize := uint16(1 << s.actualTableLog) 327 s.maxBits = 0 328 for i, v := range s.dt[:tableSize] { 329 add := v.addBits() 330 if int(add) >= len(t) { 331 return fmt.Errorf("invalid decoding table entry %d, symbol %d >= max (%d)", i, v.addBits(), len(t)) 332 } 333 lu := t[add] 334 if lu.addBits > s.maxBits { 335 s.maxBits = lu.addBits 336 } 337 v.setExt(lu.addBits, lu.baseLine) 338 s.dt[i] = v 339 } 340 return nil 341} 342 343type fseState struct { 344 dt []decSymbol 345 state decSymbol 346} 347 348// Initialize and decodeAsync first state and symbol. 349func (s *fseState) init(br *bitReader, tableLog uint8, dt []decSymbol) { 350 s.dt = dt 351 br.fill() 352 s.state = dt[br.getBits(tableLog)] 353} 354 355// next returns the current symbol and sets the next state. 356// At least tablelog bits must be available in the bit reader. 357func (s *fseState) next(br *bitReader) { 358 lowBits := uint16(br.getBits(s.state.nbBits())) 359 s.state = s.dt[s.state.newState()+lowBits] 360} 361 362// finished returns true if all bits have been read from the bitstream 363// and the next state would require reading bits from the input. 364func (s *fseState) finished(br *bitReader) bool { 365 return br.finished() && s.state.nbBits() > 0 366} 367 368// final returns the current state symbol without decoding the next. 369func (s *fseState) final() (int, uint8) { 370 return s.state.baselineInt(), s.state.addBits() 371} 372 373// final returns the current state symbol without decoding the next. 374func (s decSymbol) final() (int, uint8) { 375 return s.baselineInt(), s.addBits() 376} 377 378// nextFast returns the next symbol and sets the next state. 379// This can only be used if no symbols are 0 bits. 380// At least tablelog bits must be available in the bit reader. 381func (s *fseState) nextFast(br *bitReader) (uint32, uint8) { 382 lowBits := uint16(br.getBitsFast(s.state.nbBits())) 383 s.state = s.dt[s.state.newState()+lowBits] 384 return s.state.baseline(), s.state.addBits() 385} 386