1package huff0 2 3import ( 4 "errors" 5 "fmt" 6 "io" 7 8 "github.com/klauspost/compress/fse" 9) 10 11type dTable struct { 12 single []dEntrySingle 13 double []dEntryDouble 14} 15 16// single-symbols decoding 17type dEntrySingle struct { 18 entry uint16 19} 20 21// double-symbols decoding 22type dEntryDouble struct { 23 seq uint16 24 nBits uint8 25 len uint8 26} 27 28// ReadTable will read a table from the input. 29// The size of the input may be larger than the table definition. 30// Any content remaining after the table definition will be returned. 31// If no Scratch is provided a new one is allocated. 32// The returned Scratch can be used for decoding input using this table. 33func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { 34 s, err = s.prepare(in) 35 if err != nil { 36 return s, nil, err 37 } 38 if len(in) <= 1 { 39 return s, nil, errors.New("input too small for table") 40 } 41 iSize := in[0] 42 in = in[1:] 43 if iSize >= 128 { 44 // Uncompressed 45 oSize := iSize - 127 46 iSize = (oSize + 1) / 2 47 if int(iSize) > len(in) { 48 return s, nil, errors.New("input too small for table") 49 } 50 for n := uint8(0); n < oSize; n += 2 { 51 v := in[n/2] 52 s.huffWeight[n] = v >> 4 53 s.huffWeight[n+1] = v & 15 54 } 55 s.symbolLen = uint16(oSize) 56 in = in[iSize:] 57 } else { 58 if len(in) <= int(iSize) { 59 return s, nil, errors.New("input too small for table") 60 } 61 // FSE compressed weights 62 s.fse.DecompressLimit = 255 63 hw := s.huffWeight[:] 64 s.fse.Out = hw 65 b, err := fse.Decompress(in[:iSize], s.fse) 66 s.fse.Out = nil 67 if err != nil { 68 return s, nil, err 69 } 70 if len(b) > 255 { 71 return s, nil, errors.New("corrupt input: output table too large") 72 } 73 s.symbolLen = uint16(len(b)) 74 in = in[iSize:] 75 } 76 77 // collect weight stats 78 var rankStats [16]uint32 79 weightTotal := uint32(0) 80 for _, v := range s.huffWeight[:s.symbolLen] { 81 if v > tableLogMax { 82 return s, nil, errors.New("corrupt input: weight too large") 83 } 84 v2 := v & 15 85 rankStats[v2]++ 86 weightTotal += (1 << v2) >> 1 87 } 88 if weightTotal == 0 { 89 return s, nil, errors.New("corrupt input: weights zero") 90 } 91 92 // get last non-null symbol weight (implied, total must be 2^n) 93 { 94 tableLog := highBit32(weightTotal) + 1 95 if tableLog > tableLogMax { 96 return s, nil, errors.New("corrupt input: tableLog too big") 97 } 98 s.actualTableLog = uint8(tableLog) 99 // determine last weight 100 { 101 total := uint32(1) << tableLog 102 rest := total - weightTotal 103 verif := uint32(1) << highBit32(rest) 104 lastWeight := highBit32(rest) + 1 105 if verif != rest { 106 // last value must be a clean power of 2 107 return s, nil, errors.New("corrupt input: last value not power of two") 108 } 109 s.huffWeight[s.symbolLen] = uint8(lastWeight) 110 s.symbolLen++ 111 rankStats[lastWeight]++ 112 } 113 } 114 115 if (rankStats[1] < 2) || (rankStats[1]&1 != 0) { 116 // by construction : at least 2 elts of rank 1, must be even 117 return s, nil, errors.New("corrupt input: min elt size, even check failed ") 118 } 119 120 // TODO: Choose between single/double symbol decoding 121 122 // Calculate starting value for each rank 123 { 124 var nextRankStart uint32 125 for n := uint8(1); n < s.actualTableLog+1; n++ { 126 current := nextRankStart 127 nextRankStart += rankStats[n] << (n - 1) 128 rankStats[n] = current 129 } 130 } 131 132 // fill DTable (always full size) 133 tSize := 1 << tableLogMax 134 if len(s.dt.single) != tSize { 135 s.dt.single = make([]dEntrySingle, tSize) 136 } 137 for n, w := range s.huffWeight[:s.symbolLen] { 138 if w == 0 { 139 continue 140 } 141 length := (uint32(1) << w) >> 1 142 d := dEntrySingle{ 143 entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8), 144 } 145 single := s.dt.single[rankStats[w] : rankStats[w]+length] 146 for i := range single { 147 single[i] = d 148 } 149 rankStats[w] += length 150 } 151 return s, in, nil 152} 153 154// Decompress1X will decompress a 1X encoded stream. 155// The length of the supplied input must match the end of a block exactly. 156// Before this is called, the table must be initialized with ReadTable unless 157// the encoder re-used the table. 158func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) { 159 if len(s.dt.single) == 0 { 160 return nil, errors.New("no table loaded") 161 } 162 var br bitReader 163 err = br.init(in) 164 if err != nil { 165 return nil, err 166 } 167 s.Out = s.Out[:0] 168 169 decode := func() byte { 170 val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ 171 v := s.dt.single[val] 172 br.bitsRead += uint8(v.entry) 173 return uint8(v.entry >> 8) 174 } 175 hasDec := func(v dEntrySingle) byte { 176 br.bitsRead += uint8(v.entry) 177 return uint8(v.entry >> 8) 178 } 179 180 // Avoid bounds check by always having full sized table. 181 const tlSize = 1 << tableLogMax 182 const tlMask = tlSize - 1 183 dt := s.dt.single[:tlSize] 184 185 // Use temp table to avoid bound checks/append penalty. 186 var tmp = s.huffWeight[:256] 187 var off uint8 188 189 for br.off >= 8 { 190 br.fillFast() 191 tmp[off+0] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) 192 tmp[off+1] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) 193 br.fillFast() 194 tmp[off+2] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) 195 tmp[off+3] = hasDec(dt[br.peekBitsFast(s.actualTableLog)&tlMask]) 196 off += 4 197 if off == 0 { 198 if len(s.Out)+256 > s.MaxDecodedSize { 199 br.close() 200 return nil, ErrMaxDecodedSizeExceeded 201 } 202 s.Out = append(s.Out, tmp...) 203 } 204 } 205 206 if len(s.Out)+int(off) > s.MaxDecodedSize { 207 br.close() 208 return nil, ErrMaxDecodedSizeExceeded 209 } 210 s.Out = append(s.Out, tmp[:off]...) 211 212 for !br.finished() { 213 br.fill() 214 if len(s.Out) >= s.MaxDecodedSize { 215 br.close() 216 return nil, ErrMaxDecodedSizeExceeded 217 } 218 s.Out = append(s.Out, decode()) 219 } 220 return s.Out, br.close() 221} 222 223// Decompress4X will decompress a 4X encoded stream. 224// Before this is called, the table must be initialized with ReadTable unless 225// the encoder re-used the table. 226// The length of the supplied input must match the end of a block exactly. 227// The destination size of the uncompressed data must be known and provided. 228func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) { 229 if len(s.dt.single) == 0 { 230 return nil, errors.New("no table loaded") 231 } 232 if len(in) < 6+(4*1) { 233 return nil, errors.New("input too small") 234 } 235 if dstSize > s.MaxDecodedSize { 236 return nil, ErrMaxDecodedSizeExceeded 237 } 238 // TODO: We do not detect when we overrun a buffer, except if the last one does. 239 240 var br [4]bitReader 241 start := 6 242 for i := 0; i < 3; i++ { 243 length := int(in[i*2]) | (int(in[i*2+1]) << 8) 244 if start+length >= len(in) { 245 return nil, errors.New("truncated input (or invalid offset)") 246 } 247 err = br[i].init(in[start : start+length]) 248 if err != nil { 249 return nil, err 250 } 251 start += length 252 } 253 err = br[3].init(in[start:]) 254 if err != nil { 255 return nil, err 256 } 257 258 // Prepare output 259 if cap(s.Out) < dstSize { 260 s.Out = make([]byte, 0, dstSize) 261 } 262 s.Out = s.Out[:dstSize] 263 // destination, offset to match first output 264 dstOut := s.Out 265 dstEvery := (dstSize + 3) / 4 266 267 const tlSize = 1 << tableLogMax 268 const tlMask = tlSize - 1 269 single := s.dt.single[:tlSize] 270 271 decode := func(br *bitReader) byte { 272 val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ 273 v := single[val&tlMask] 274 br.bitsRead += uint8(v.entry) 275 return uint8(v.entry >> 8) 276 } 277 278 // Use temp table to avoid bound checks/append penalty. 279 var tmp = s.huffWeight[:256] 280 var off uint8 281 var decoded int 282 283 // Decode 2 values from each decoder/loop. 284 const bufoff = 256 / 4 285bigloop: 286 for { 287 for i := range br { 288 br := &br[i] 289 if br.off < 4 { 290 break bigloop 291 } 292 br.fillFast() 293 } 294 295 { 296 const stream = 0 297 val := br[stream].peekBitsFast(s.actualTableLog) 298 v := single[val&tlMask] 299 br[stream].bitsRead += uint8(v.entry) 300 301 val2 := br[stream].peekBitsFast(s.actualTableLog) 302 v2 := single[val2&tlMask] 303 tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) 304 tmp[off+bufoff*stream] = uint8(v.entry >> 8) 305 br[stream].bitsRead += uint8(v2.entry) 306 } 307 308 { 309 const stream = 1 310 val := br[stream].peekBitsFast(s.actualTableLog) 311 v := single[val&tlMask] 312 br[stream].bitsRead += uint8(v.entry) 313 314 val2 := br[stream].peekBitsFast(s.actualTableLog) 315 v2 := single[val2&tlMask] 316 tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) 317 tmp[off+bufoff*stream] = uint8(v.entry >> 8) 318 br[stream].bitsRead += uint8(v2.entry) 319 } 320 321 { 322 const stream = 2 323 val := br[stream].peekBitsFast(s.actualTableLog) 324 v := single[val&tlMask] 325 br[stream].bitsRead += uint8(v.entry) 326 327 val2 := br[stream].peekBitsFast(s.actualTableLog) 328 v2 := single[val2&tlMask] 329 tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) 330 tmp[off+bufoff*stream] = uint8(v.entry >> 8) 331 br[stream].bitsRead += uint8(v2.entry) 332 } 333 334 { 335 const stream = 3 336 val := br[stream].peekBitsFast(s.actualTableLog) 337 v := single[val&tlMask] 338 br[stream].bitsRead += uint8(v.entry) 339 340 val2 := br[stream].peekBitsFast(s.actualTableLog) 341 v2 := single[val2&tlMask] 342 tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) 343 tmp[off+bufoff*stream] = uint8(v.entry >> 8) 344 br[stream].bitsRead += uint8(v2.entry) 345 } 346 347 off += 2 348 349 if off == bufoff { 350 if bufoff > dstEvery { 351 return nil, errors.New("corruption detected: stream overrun 1") 352 } 353 copy(dstOut, tmp[:bufoff]) 354 copy(dstOut[dstEvery:], tmp[bufoff:bufoff*2]) 355 copy(dstOut[dstEvery*2:], tmp[bufoff*2:bufoff*3]) 356 copy(dstOut[dstEvery*3:], tmp[bufoff*3:bufoff*4]) 357 off = 0 358 dstOut = dstOut[bufoff:] 359 decoded += 256 360 // There must at least be 3 buffers left. 361 if len(dstOut) < dstEvery*3 { 362 return nil, errors.New("corruption detected: stream overrun 2") 363 } 364 } 365 } 366 if off > 0 { 367 ioff := int(off) 368 if len(dstOut) < dstEvery*3+ioff { 369 return nil, errors.New("corruption detected: stream overrun 3") 370 } 371 copy(dstOut, tmp[:off]) 372 copy(dstOut[dstEvery:dstEvery+ioff], tmp[bufoff:bufoff*2]) 373 copy(dstOut[dstEvery*2:dstEvery*2+ioff], tmp[bufoff*2:bufoff*3]) 374 copy(dstOut[dstEvery*3:dstEvery*3+ioff], tmp[bufoff*3:bufoff*4]) 375 decoded += int(off) * 4 376 dstOut = dstOut[off:] 377 } 378 379 // Decode remaining. 380 for i := range br { 381 offset := dstEvery * i 382 br := &br[i] 383 for !br.finished() { 384 br.fill() 385 if offset >= len(dstOut) { 386 return nil, errors.New("corruption detected: stream overrun 4") 387 } 388 dstOut[offset] = decode(br) 389 offset++ 390 } 391 decoded += offset - dstEvery*i 392 err = br.close() 393 if err != nil { 394 return nil, err 395 } 396 } 397 if dstSize != decoded { 398 return nil, errors.New("corruption detected: short output block") 399 } 400 return s.Out, nil 401} 402 403// matches will compare a decoding table to a coding table. 404// Errors are written to the writer. 405// Nothing will be written if table is ok. 406func (s *Scratch) matches(ct cTable, w io.Writer) { 407 if s == nil || len(s.dt.single) == 0 { 408 return 409 } 410 dt := s.dt.single[:1<<s.actualTableLog] 411 tablelog := s.actualTableLog 412 ok := 0 413 broken := 0 414 for sym, enc := range ct { 415 errs := 0 416 broken++ 417 if enc.nBits == 0 { 418 for _, dec := range dt { 419 if uint8(dec.entry>>8) == byte(sym) { 420 fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym) 421 errs++ 422 break 423 } 424 } 425 if errs == 0 { 426 broken-- 427 } 428 continue 429 } 430 // Unused bits in input 431 ub := tablelog - enc.nBits 432 top := enc.val << ub 433 // decoder looks at top bits. 434 dec := dt[top] 435 if uint8(dec.entry) != enc.nBits { 436 fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry)) 437 errs++ 438 } 439 if uint8(dec.entry>>8) != uint8(sym) { 440 fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8)) 441 errs++ 442 } 443 if errs > 0 { 444 fmt.Fprintf(w, "%d errros in base, stopping\n", errs) 445 continue 446 } 447 // Ensure that all combinations are covered. 448 for i := uint16(0); i < (1 << ub); i++ { 449 vval := top | i 450 dec := dt[vval] 451 if uint8(dec.entry) != enc.nBits { 452 fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry)) 453 errs++ 454 } 455 if uint8(dec.entry>>8) != uint8(sym) { 456 fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8)) 457 errs++ 458 } 459 if errs > 20 { 460 fmt.Fprintf(w, "%d errros, stopping\n", errs) 461 break 462 } 463 } 464 if errs == 0 { 465 ok++ 466 broken-- 467 } 468 } 469 if broken > 0 { 470 fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok) 471 } 472} 473