1package fse 2 3import ( 4 "errors" 5 "fmt" 6) 7 8const ( 9 tablelogAbsoluteMax = 15 10) 11 12// Decompress a block of data. 13// You can provide a scratch buffer to avoid allocations. 14// If nil is provided a temporary one will be allocated. 15// It is possible, but by no way guaranteed that corrupt data will 16// return an error. 17// It is up to the caller to verify integrity of the returned data. 18// Use a predefined Scrach to set maximum acceptable output size. 19func Decompress(b []byte, s *Scratch) ([]byte, error) { 20 s, err := s.prepare(b) 21 if err != nil { 22 return nil, err 23 } 24 s.Out = s.Out[:0] 25 err = s.readNCount() 26 if err != nil { 27 return nil, err 28 } 29 err = s.buildDtable() 30 if err != nil { 31 return nil, err 32 } 33 err = s.decompress() 34 if err != nil { 35 return nil, err 36 } 37 38 return s.Out, nil 39} 40 41// readNCount will read the symbol distribution so decoding tables can be constructed. 42func (s *Scratch) readNCount() error { 43 var ( 44 charnum uint16 45 previous0 bool 46 b = &s.br 47 ) 48 iend := b.remain() 49 if iend < 4 { 50 return errors.New("input too small") 51 } 52 bitStream := b.Uint32() 53 nbBits := uint((bitStream & 0xF) + minTablelog) // extract tableLog 54 if nbBits > tablelogAbsoluteMax { 55 return errors.New("tableLog too large") 56 } 57 bitStream >>= 4 58 bitCount := uint(4) 59 60 s.actualTableLog = uint8(nbBits) 61 remaining := int32((1 << nbBits) + 1) 62 threshold := int32(1 << nbBits) 63 gotTotal := int32(0) 64 nbBits++ 65 66 for remaining > 1 { 67 if previous0 { 68 n0 := charnum 69 for (bitStream & 0xFFFF) == 0xFFFF { 70 n0 += 24 71 if b.off < iend-5 { 72 b.advance(2) 73 bitStream = b.Uint32() >> bitCount 74 } else { 75 bitStream >>= 16 76 bitCount += 16 77 } 78 } 79 for (bitStream & 3) == 3 { 80 n0 += 3 81 bitStream >>= 2 82 bitCount += 2 83 } 84 n0 += uint16(bitStream & 3) 85 bitCount += 2 86 if n0 > maxSymbolValue { 87 return errors.New("maxSymbolValue too small") 88 } 89 for charnum < n0 { 90 s.norm[charnum&0xff] = 0 91 charnum++ 92 } 93 94 if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 { 95 b.advance(bitCount >> 3) 96 bitCount &= 7 97 bitStream = b.Uint32() >> bitCount 98 } else { 99 bitStream >>= 2 100 } 101 } 102 103 max := (2*(threshold) - 1) - (remaining) 104 var count int32 105 106 if (int32(bitStream) & (threshold - 1)) < max { 107 count = int32(bitStream) & (threshold - 1) 108 bitCount += nbBits - 1 109 } else { 110 count = int32(bitStream) & (2*threshold - 1) 111 if count >= threshold { 112 count -= max 113 } 114 bitCount += nbBits 115 } 116 117 count-- // extra accuracy 118 if count < 0 { 119 // -1 means +1 120 remaining += count 121 gotTotal -= count 122 } else { 123 remaining -= count 124 gotTotal += count 125 } 126 s.norm[charnum&0xff] = int16(count) 127 charnum++ 128 previous0 = count == 0 129 for remaining < threshold { 130 nbBits-- 131 threshold >>= 1 132 } 133 if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 { 134 b.advance(bitCount >> 3) 135 bitCount &= 7 136 } else { 137 bitCount -= (uint)(8 * (len(b.b) - 4 - b.off)) 138 b.off = len(b.b) - 4 139 } 140 bitStream = b.Uint32() >> (bitCount & 31) 141 } 142 s.symbolLen = charnum 143 144 if s.symbolLen <= 1 { 145 return fmt.Errorf("symbolLen (%d) too small", s.symbolLen) 146 } 147 if s.symbolLen > maxSymbolValue+1 { 148 return fmt.Errorf("symbolLen (%d) too big", s.symbolLen) 149 } 150 if remaining != 1 { 151 return fmt.Errorf("corruption detected (remaining %d != 1)", remaining) 152 } 153 if bitCount > 32 { 154 return fmt.Errorf("corruption detected (bitCount %d > 32)", bitCount) 155 } 156 if gotTotal != 1<<s.actualTableLog { 157 return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<<s.actualTableLog) 158 } 159 b.advance((bitCount + 7) >> 3) 160 return nil 161} 162 163// decSymbol contains information about a state entry, 164// Including the state offset base, the output symbol and 165// the number of bits to read for the low part of the destination state. 166type decSymbol struct { 167 newState uint16 168 symbol uint8 169 nbBits uint8 170} 171 172// allocDtable will allocate decoding tables if they are not big enough. 173func (s *Scratch) allocDtable() { 174 tableSize := 1 << s.actualTableLog 175 if cap(s.decTable) < int(tableSize) { 176 s.decTable = make([]decSymbol, tableSize) 177 } 178 s.decTable = s.decTable[:tableSize] 179 180 if cap(s.ct.tableSymbol) < 256 { 181 s.ct.tableSymbol = make([]byte, 256) 182 } 183 s.ct.tableSymbol = s.ct.tableSymbol[:256] 184 185 if cap(s.ct.stateTable) < 256 { 186 s.ct.stateTable = make([]uint16, 256) 187 } 188 s.ct.stateTable = s.ct.stateTable[:256] 189} 190 191// buildDtable will build the decoding table. 192func (s *Scratch) buildDtable() error { 193 tableSize := uint32(1 << s.actualTableLog) 194 highThreshold := tableSize - 1 195 s.allocDtable() 196 symbolNext := s.ct.stateTable[:256] 197 198 // Init, lay down lowprob symbols 199 s.zeroBits = false 200 { 201 largeLimit := int16(1 << (s.actualTableLog - 1)) 202 for i, v := range s.norm[:s.symbolLen] { 203 if v == -1 { 204 s.decTable[highThreshold].symbol = uint8(i) 205 highThreshold-- 206 symbolNext[i] = 1 207 } else { 208 if v >= largeLimit { 209 s.zeroBits = true 210 } 211 symbolNext[i] = uint16(v) 212 } 213 } 214 } 215 // Spread symbols 216 { 217 tableMask := tableSize - 1 218 step := tableStep(tableSize) 219 position := uint32(0) 220 for ss, v := range s.norm[:s.symbolLen] { 221 for i := 0; i < int(v); i++ { 222 s.decTable[position].symbol = uint8(ss) 223 position = (position + step) & tableMask 224 for position > highThreshold { 225 // lowprob area 226 position = (position + step) & tableMask 227 } 228 } 229 } 230 if position != 0 { 231 // position must reach all cells once, otherwise normalizedCounter is incorrect 232 return errors.New("corrupted input (position != 0)") 233 } 234 } 235 236 // Build Decoding table 237 { 238 tableSize := uint16(1 << s.actualTableLog) 239 for u, v := range s.decTable { 240 symbol := v.symbol 241 nextState := symbolNext[symbol] 242 symbolNext[symbol] = nextState + 1 243 nBits := s.actualTableLog - byte(highBits(uint32(nextState))) 244 s.decTable[u].nbBits = nBits 245 newState := (nextState << nBits) - tableSize 246 if newState >= tableSize { 247 return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) 248 } 249 if newState == uint16(u) && nBits == 0 { 250 // Seems weird that this is possible with nbits > 0. 251 return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) 252 } 253 s.decTable[u].newState = newState 254 } 255 } 256 return nil 257} 258 259// decompress will decompress the bitstream. 260// If the buffer is over-read an error is returned. 261func (s *Scratch) decompress() error { 262 br := &s.bits 263 br.init(s.br.unread()) 264 265 var s1, s2 decoder 266 // Initialize and decode first state and symbol. 267 s1.init(br, s.decTable, s.actualTableLog) 268 s2.init(br, s.decTable, s.actualTableLog) 269 270 // Use temp table to avoid bound checks/append penalty. 271 var tmp = s.ct.tableSymbol[:256] 272 var off uint8 273 274 // Main part 275 if !s.zeroBits { 276 for br.off >= 8 { 277 br.fillFast() 278 tmp[off+0] = s1.nextFast() 279 tmp[off+1] = s2.nextFast() 280 br.fillFast() 281 tmp[off+2] = s1.nextFast() 282 tmp[off+3] = s2.nextFast() 283 off += 4 284 // When off is 0, we have overflowed and should write. 285 if off == 0 { 286 s.Out = append(s.Out, tmp...) 287 if len(s.Out) >= s.DecompressLimit { 288 return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit) 289 } 290 } 291 } 292 } else { 293 for br.off >= 8 { 294 br.fillFast() 295 tmp[off+0] = s1.next() 296 tmp[off+1] = s2.next() 297 br.fillFast() 298 tmp[off+2] = s1.next() 299 tmp[off+3] = s2.next() 300 off += 4 301 if off == 0 { 302 s.Out = append(s.Out, tmp...) 303 // When off is 0, we have overflowed and should write. 304 if len(s.Out) >= s.DecompressLimit { 305 return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit) 306 } 307 } 308 } 309 } 310 s.Out = append(s.Out, tmp[:off]...) 311 312 // Final bits, a bit more expensive check 313 for { 314 if s1.finished() { 315 s.Out = append(s.Out, s1.final(), s2.final()) 316 break 317 } 318 br.fill() 319 s.Out = append(s.Out, s1.next()) 320 if s2.finished() { 321 s.Out = append(s.Out, s2.final(), s1.final()) 322 break 323 } 324 s.Out = append(s.Out, s2.next()) 325 if len(s.Out) >= s.DecompressLimit { 326 return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit) 327 } 328 } 329 return br.close() 330} 331 332// decoder keeps track of the current state and updates it from the bitstream. 333type decoder struct { 334 state uint16 335 br *bitReader 336 dt []decSymbol 337} 338 339// init will initialize the decoder and read the first state from the stream. 340func (d *decoder) init(in *bitReader, dt []decSymbol, tableLog uint8) { 341 d.dt = dt 342 d.br = in 343 d.state = uint16(in.getBits(tableLog)) 344} 345 346// next returns the next symbol and sets the next state. 347// At least tablelog bits must be available in the bit reader. 348func (d *decoder) next() uint8 { 349 n := &d.dt[d.state] 350 lowBits := d.br.getBits(n.nbBits) 351 d.state = n.newState + lowBits 352 return n.symbol 353} 354 355// finished returns true if all bits have been read from the bitstream 356// and the next state would require reading bits from the input. 357func (d *decoder) finished() bool { 358 return d.br.finished() && d.dt[d.state].nbBits > 0 359} 360 361// final returns the current state symbol without decoding the next. 362func (d *decoder) final() uint8 { 363 return d.dt[d.state].symbol 364} 365 366// nextFast returns the next symbol and sets the next state. 367// This can only be used if no symbols are 0 bits. 368// At least tablelog bits must be available in the bit reader. 369func (d *decoder) nextFast() uint8 { 370 n := d.dt[d.state] 371 lowBits := d.br.getBitsFast(n.nbBits) 372 d.state = n.newState + lowBits 373 return n.symbol 374} 375