1// Copyright 2019+ Klaus Post. All rights reserved. 2// License information can be found in the LICENSE file. 3// Based on work by Yann Collet, released under BSD License. 4 5package zstd 6 7import ( 8 "errors" 9 "fmt" 10) 11 12const ( 13 tablelogAbsoluteMax = 9 14) 15 16const ( 17 /*!MEMORY_USAGE : 18 * Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.) 19 * Increasing memory usage improves compression ratio 20 * Reduced memory usage can improve speed, due to cache effect 21 * Recommended max value is 14, for 16KB, which nicely fits into Intel x86 L1 cache */ 22 maxMemoryUsage = 11 23 24 maxTableLog = maxMemoryUsage - 2 25 maxTablesize = 1 << maxTableLog 26 maxTableMask = (1 << maxTableLog) - 1 27 minTablelog = 5 28 maxSymbolValue = 255 29) 30 31// fseDecoder provides temporary storage for compression and decompression. 32type fseDecoder struct { 33 dt [maxTablesize]decSymbol // Decompression table. 34 symbolLen uint16 // Length of active part of the symbol table. 35 actualTableLog uint8 // Selected tablelog. 36 maxBits uint8 // Maximum number of additional bits 37 38 // used for table creation to avoid allocations. 39 stateTable [256]uint16 40 norm [maxSymbolValue + 1]int16 41 preDefined bool 42} 43 44// tableStep returns the next table index. 45func tableStep(tableSize uint32) uint32 { 46 return (tableSize >> 1) + (tableSize >> 3) + 3 47} 48 49// readNCount will read the symbol distribution so decoding tables can be constructed. 50func (s *fseDecoder) readNCount(b *byteReader, maxSymbol uint16) error { 51 var ( 52 charnum uint16 53 previous0 bool 54 ) 55 if b.remain() < 4 { 56 return errors.New("input too small") 57 } 58 bitStream := b.Uint32() 59 nbBits := uint((bitStream & 0xF) + minTablelog) // extract tableLog 60 if nbBits > tablelogAbsoluteMax { 61 println("Invalid tablelog:", nbBits) 62 return errors.New("tableLog too large") 63 } 64 bitStream >>= 4 65 bitCount := uint(4) 66 67 s.actualTableLog = uint8(nbBits) 68 remaining := int32((1 << nbBits) + 1) 69 threshold := int32(1 << nbBits) 70 gotTotal := int32(0) 71 nbBits++ 72 73 for remaining > 1 && charnum <= maxSymbol { 74 if previous0 { 75 //println("prev0") 76 n0 := charnum 77 for (bitStream & 0xFFFF) == 0xFFFF { 78 //println("24 x 0") 79 n0 += 24 80 if r := b.remain(); r > 5 { 81 b.advance(2) 82 bitStream = b.Uint32() >> bitCount 83 } else { 84 // end of bit stream 85 bitStream >>= 16 86 bitCount += 16 87 } 88 } 89 //printf("bitstream: %d, 0b%b", bitStream&3, bitStream) 90 for (bitStream & 3) == 3 { 91 n0 += 3 92 bitStream >>= 2 93 bitCount += 2 94 } 95 n0 += uint16(bitStream & 3) 96 bitCount += 2 97 98 if n0 > maxSymbolValue { 99 return errors.New("maxSymbolValue too small") 100 } 101 //println("inserting ", n0-charnum, "zeroes from idx", charnum, "ending before", n0) 102 for charnum < n0 { 103 s.norm[uint8(charnum)] = 0 104 charnum++ 105 } 106 107 if r := b.remain(); r >= 7 || r+int(bitCount>>3) >= 4 { 108 b.advance(bitCount >> 3) 109 bitCount &= 7 110 bitStream = b.Uint32() >> bitCount 111 } else { 112 bitStream >>= 2 113 } 114 } 115 116 max := (2*threshold - 1) - remaining 117 var count int32 118 119 if int32(bitStream)&(threshold-1) < max { 120 count = int32(bitStream) & (threshold - 1) 121 if debug && nbBits < 1 { 122 panic("nbBits underflow") 123 } 124 bitCount += nbBits - 1 125 } else { 126 count = int32(bitStream) & (2*threshold - 1) 127 if count >= threshold { 128 count -= max 129 } 130 bitCount += nbBits 131 } 132 133 // extra accuracy 134 count-- 135 if count < 0 { 136 // -1 means +1 137 remaining += count 138 gotTotal -= count 139 } else { 140 remaining -= count 141 gotTotal += count 142 } 143 s.norm[charnum&0xff] = int16(count) 144 charnum++ 145 previous0 = count == 0 146 for remaining < threshold { 147 nbBits-- 148 threshold >>= 1 149 } 150 151 //println("b.off:", b.off, "len:", len(b.b), "bc:", bitCount, "remain:", b.remain()) 152 if r := b.remain(); r >= 7 || r+int(bitCount>>3) >= 4 { 153 b.advance(bitCount >> 3) 154 bitCount &= 7 155 } else { 156 bitCount -= (uint)(8 * (len(b.b) - 4 - b.off)) 157 b.off = len(b.b) - 4 158 //println("b.off:", b.off, "len:", len(b.b), "bc:", bitCount, "iend", iend) 159 } 160 bitStream = b.Uint32() >> (bitCount & 31) 161 //printf("bitstream is now: 0b%b", bitStream) 162 } 163 s.symbolLen = charnum 164 if s.symbolLen <= 1 { 165 return fmt.Errorf("symbolLen (%d) too small", s.symbolLen) 166 } 167 if s.symbolLen > maxSymbolValue+1 { 168 return fmt.Errorf("symbolLen (%d) too big", s.symbolLen) 169 } 170 if remaining != 1 { 171 return fmt.Errorf("corruption detected (remaining %d != 1)", remaining) 172 } 173 if bitCount > 32 { 174 return fmt.Errorf("corruption detected (bitCount %d > 32)", bitCount) 175 } 176 if gotTotal != 1<<s.actualTableLog { 177 return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<<s.actualTableLog) 178 } 179 b.advance((bitCount + 7) >> 3) 180 // println(s.norm[:s.symbolLen], s.symbolLen) 181 return s.buildDtable() 182} 183 184// decSymbol contains information about a state entry, 185// Including the state offset base, the output symbol and 186// the number of bits to read for the low part of the destination state. 187type decSymbol struct { 188 newState uint16 189 addBits uint8 // Used for symbols until transformed. 190 nbBits uint8 191 baseline uint32 192} 193 194// decSymbolValue returns the transformed decSymbol for the given symbol. 195func decSymbolValue(symb uint8, t []baseOffset) (decSymbol, error) { 196 if int(symb) >= len(t) { 197 return decSymbol{}, fmt.Errorf("rle symbol %d >= max %d", symb, len(t)) 198 } 199 lu := t[symb] 200 return decSymbol{ 201 addBits: lu.addBits, 202 baseline: lu.baseLine, 203 }, nil 204} 205 206// setRLE will set the decoder til RLE mode. 207func (s *fseDecoder) setRLE(symbol decSymbol) { 208 s.actualTableLog = 0 209 s.maxBits = symbol.addBits 210 s.dt[0] = symbol 211} 212 213// buildDtable will build the decoding table. 214func (s *fseDecoder) buildDtable() error { 215 tableSize := uint32(1 << s.actualTableLog) 216 highThreshold := tableSize - 1 217 symbolNext := s.stateTable[:256] 218 219 // Init, lay down lowprob symbols 220 { 221 for i, v := range s.norm[:s.symbolLen] { 222 if v == -1 { 223 s.dt[highThreshold].addBits = uint8(i) 224 highThreshold-- 225 symbolNext[i] = 1 226 } else { 227 symbolNext[i] = uint16(v) 228 } 229 } 230 } 231 // Spread symbols 232 { 233 tableMask := tableSize - 1 234 step := tableStep(tableSize) 235 position := uint32(0) 236 for ss, v := range s.norm[:s.symbolLen] { 237 for i := 0; i < int(v); i++ { 238 s.dt[position].addBits = uint8(ss) 239 position = (position + step) & tableMask 240 for position > highThreshold { 241 // lowprob area 242 position = (position + step) & tableMask 243 } 244 } 245 } 246 if position != 0 { 247 // position must reach all cells once, otherwise normalizedCounter is incorrect 248 return errors.New("corrupted input (position != 0)") 249 } 250 } 251 252 // Build Decoding table 253 { 254 tableSize := uint16(1 << s.actualTableLog) 255 for u, v := range s.dt[:tableSize] { 256 symbol := v.addBits 257 nextState := symbolNext[symbol] 258 symbolNext[symbol] = nextState + 1 259 nBits := s.actualTableLog - byte(highBits(uint32(nextState))) 260 s.dt[u&maxTableMask].nbBits = nBits 261 newState := (nextState << nBits) - tableSize 262 if newState > tableSize { 263 return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) 264 } 265 if newState == uint16(u) && nBits == 0 { 266 // Seems weird that this is possible with nbits > 0. 267 return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) 268 } 269 s.dt[u&maxTableMask].newState = newState 270 } 271 } 272 return nil 273} 274 275// transform will transform the decoder table into a table usable for 276// decoding without having to apply the transformation while decoding. 277// The state will contain the base value and the number of bits to read. 278func (s *fseDecoder) transform(t []baseOffset) error { 279 tableSize := uint16(1 << s.actualTableLog) 280 s.maxBits = 0 281 for i, v := range s.dt[:tableSize] { 282 if int(v.addBits) >= len(t) { 283 return fmt.Errorf("invalid decoding table entry %d, symbol %d >= max (%d)", i, v.addBits, len(t)) 284 } 285 lu := t[v.addBits] 286 if lu.addBits > s.maxBits { 287 s.maxBits = lu.addBits 288 } 289 s.dt[i&maxTableMask] = decSymbol{ 290 newState: v.newState, 291 nbBits: v.nbBits, 292 addBits: lu.addBits, 293 baseline: lu.baseLine, 294 } 295 } 296 return nil 297} 298 299type fseState struct { 300 // TODO: Check if *[1 << maxTablelog]decSymbol is faster. 301 dt []decSymbol 302 state decSymbol 303} 304 305// Initialize and decodeAsync first state and symbol. 306func (s *fseState) init(br *bitReader, tableLog uint8, dt []decSymbol) { 307 s.dt = dt 308 br.fill() 309 s.state = dt[br.getBits(tableLog)] 310} 311 312// next returns the current symbol and sets the next state. 313// At least tablelog bits must be available in the bit reader. 314func (s *fseState) next(br *bitReader) { 315 lowBits := uint16(br.getBits(s.state.nbBits)) 316 s.state = s.dt[s.state.newState+lowBits] 317} 318 319// finished returns true if all bits have been read from the bitstream 320// and the next state would require reading bits from the input. 321func (s *fseState) finished(br *bitReader) bool { 322 return br.finished() && s.state.nbBits > 0 323} 324 325// final returns the current state symbol without decoding the next. 326func (s *fseState) final() (int, uint8) { 327 return int(s.state.baseline), s.state.addBits 328} 329 330// nextFast returns the next symbol and sets the next state. 331// This can only be used if no symbols are 0 bits. 332// At least tablelog bits must be available in the bit reader. 333func (s *fseState) nextFast(br *bitReader) (uint32, uint8) { 334 lowBits := uint16(br.getBitsFast(s.state.nbBits)) 335 s.state = s.dt[s.state.newState+lowBits] 336 return s.state.baseline, s.state.addBits 337} 338