1// Copyright 2019+ Klaus Post. All rights reserved. 2// License information can be found in the LICENSE file. 3// Based on work by Yann Collet, released under BSD License. 4 5package zstd 6 7import ( 8 "errors" 9 "fmt" 10 "io" 11 "sync" 12 13 "github.com/klauspost/compress/huff0" 14 "github.com/klauspost/compress/zstd/internal/xxhash" 15) 16 17type blockType uint8 18 19//go:generate stringer -type=blockType,literalsBlockType,seqCompMode,tableIndex 20 21const ( 22 blockTypeRaw blockType = iota 23 blockTypeRLE 24 blockTypeCompressed 25 blockTypeReserved 26) 27 28type literalsBlockType uint8 29 30const ( 31 literalsBlockRaw literalsBlockType = iota 32 literalsBlockRLE 33 literalsBlockCompressed 34 literalsBlockTreeless 35) 36 37const ( 38 // maxCompressedBlockSize is the biggest allowed compressed block size (128KB) 39 maxCompressedBlockSize = 128 << 10 40 41 // Maximum possible block size (all Raw+Uncompressed). 42 maxBlockSize = (1 << 21) - 1 43 44 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals_section_header 45 maxCompressedLiteralSize = 1 << 18 46 maxRLELiteralSize = 1 << 20 47 maxMatchLen = 131074 48 maxSequences = 0x7f00 + 0xffff 49 50 // We support slightly less than the reference decoder to be able to 51 // use ints on 32 bit archs. 52 maxOffsetBits = 30 53) 54 55var ( 56 huffDecoderPool = sync.Pool{New: func() interface{} { 57 return &huff0.Scratch{} 58 }} 59 60 fseDecoderPool = sync.Pool{New: func() interface{} { 61 return &fseDecoder{} 62 }} 63) 64 65type blockDec struct { 66 // Raw source data of the block. 67 data []byte 68 dataStorage []byte 69 70 // Destination of the decoded data. 71 dst []byte 72 73 // Buffer for literals data. 74 literalBuf []byte 75 76 // Window size of the block. 77 WindowSize uint64 78 79 history chan *history 80 input chan struct{} 81 result chan decodeOutput 82 sequenceBuf []seq 83 err error 84 decWG sync.WaitGroup 85 86 // Frame to use for singlethreaded decoding. 
87 // Should not be used by the decoder itself since parent may be another frame. 88 localFrame *frameDec 89 90 // Block is RLE, this is the size. 91 RLESize uint32 92 tmp [4]byte 93 94 Type blockType 95 96 // Is this the last block of a frame? 97 Last bool 98 99 // Use less memory 100 lowMem bool 101} 102 103func (b *blockDec) String() string { 104 if b == nil { 105 return "<nil>" 106 } 107 return fmt.Sprintf("Steam Size: %d, Type: %v, Last: %t, Window: %d", len(b.data), b.Type, b.Last, b.WindowSize) 108} 109 110func newBlockDec(lowMem bool) *blockDec { 111 b := blockDec{ 112 lowMem: lowMem, 113 result: make(chan decodeOutput, 1), 114 input: make(chan struct{}, 1), 115 history: make(chan *history, 1), 116 } 117 b.decWG.Add(1) 118 go b.startDecoder() 119 return &b 120} 121 122// reset will reset the block. 123// Input must be a start of a block and will be at the end of the block when returned. 124func (b *blockDec) reset(br byteBuffer, windowSize uint64) error { 125 b.WindowSize = windowSize 126 tmp := br.readSmall(3) 127 if tmp == nil { 128 if debug { 129 println("Reading block header:", io.ErrUnexpectedEOF) 130 } 131 return io.ErrUnexpectedEOF 132 } 133 bh := uint32(tmp[0]) | (uint32(tmp[1]) << 8) | (uint32(tmp[2]) << 16) 134 b.Last = bh&1 != 0 135 b.Type = blockType((bh >> 1) & 3) 136 // find size. 
137 cSize := int(bh >> 3) 138 maxSize := maxBlockSize 139 switch b.Type { 140 case blockTypeReserved: 141 return ErrReservedBlockType 142 case blockTypeRLE: 143 b.RLESize = uint32(cSize) 144 if b.lowMem { 145 maxSize = cSize 146 } 147 cSize = 1 148 case blockTypeCompressed: 149 if debug { 150 println("Data size on stream:", cSize) 151 } 152 b.RLESize = 0 153 maxSize = maxCompressedBlockSize 154 if windowSize < maxCompressedBlockSize && b.lowMem { 155 maxSize = int(windowSize) 156 } 157 if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize { 158 if debug { 159 printf("compressed block too big: csize:%d block: %+v\n", uint64(cSize), b) 160 } 161 return ErrCompressedSizeTooBig 162 } 163 case blockTypeRaw: 164 b.RLESize = 0 165 // We do not need a destination for raw blocks. 166 maxSize = -1 167 default: 168 panic("Invalid block type") 169 } 170 171 // Read block data. 172 if cap(b.dataStorage) < cSize { 173 if b.lowMem { 174 b.dataStorage = make([]byte, 0, cSize) 175 } else { 176 b.dataStorage = make([]byte, 0, maxBlockSize) 177 } 178 } 179 if cap(b.dst) <= maxSize { 180 b.dst = make([]byte, 0, maxSize+1) 181 } 182 var err error 183 b.data, err = br.readBig(cSize, b.dataStorage) 184 if err != nil { 185 if debug { 186 println("Reading block:", err, "(", cSize, ")", len(b.data)) 187 printf("%T", br) 188 } 189 return err 190 } 191 return nil 192} 193 194// sendEOF will make the decoder send EOF on this frame. 195func (b *blockDec) sendErr(err error) { 196 b.Last = true 197 b.Type = blockTypeReserved 198 b.err = err 199 b.input <- struct{}{} 200} 201 202// Close will release resources. 203// Closed blockDec cannot be reset. 204func (b *blockDec) Close() { 205 close(b.input) 206 close(b.history) 207 close(b.result) 208 b.decWG.Wait() 209} 210 211// decodeAsync will prepare decoding the block when it receives input. 212// This will separate output and history. 
func (b *blockDec) startDecoder() {
	// Signal Close that the goroutine has finished.
	defer b.decWG.Done()
	// Each send on b.input marks one block ready for decoding.
	// The loop exits when b.input is closed by Close.
	for range b.input {
		//println("blockDec: Got block input")
		switch b.Type {
		case blockTypeRLE:
			if cap(b.dst) < int(b.RLESize) {
				if b.lowMem {
					b.dst = make([]byte, b.RLESize)
				} else {
					b.dst = make([]byte, maxBlockSize)
				}
			}
			o := decodeOutput{
				d:   b,
				b:   b.dst[:b.RLESize],
				err: nil,
			}
			// Repeat the single payload byte RLESize times.
			v := b.data[0]
			for i := range o.b {
				o.b[i] = v
			}
			// History must be consumed before delivering the result,
			// even though RLE output does not depend on it.
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeRaw:
			// Raw blocks are output as-is; b.data is handed to the caller.
			o := decodeOutput{
				d:   b,
				b:   b.data,
				err: nil,
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeCompressed:
			b.dst = b.dst[:0]
			// nil history: decodeCompressed fetches it from b.history
			// itself once it reaches the sequential part.
			err := b.decodeCompressed(nil)
			o := decodeOutput{
				d:   b,
				b:   b.dst,
				err: err,
			}
			if debug {
				println("Decompressed to", len(b.dst), "bytes, error:", err)
			}
			b.result <- o
		case blockTypeReserved:
			// Used for returning errors.
			<-b.history
			b.result <- decodeOutput{
				d:   b,
				b:   nil,
				err: b.err,
			}
		default:
			panic("Invalid block type")
		}
		if debug {
			println("blockDec: Finished block")
		}
	}
}

// decodeBuf will decode the block synchronously into the provided history.
// If history is provided, it will not fetch it from the channel.
func (b *blockDec) decodeBuf(hist *history) error {
	switch b.Type {
	case blockTypeRLE:
		if cap(b.dst) < int(b.RLESize) {
			if b.lowMem {
				b.dst = make([]byte, b.RLESize)
			} else {
				b.dst = make([]byte, maxBlockSize)
			}
		}
		b.dst = b.dst[:b.RLESize]
		// Repeat the single payload byte RLESize times.
		v := b.data[0]
		for i := range b.dst {
			b.dst[i] = v
		}
		hist.appendKeep(b.dst)
		return nil
	case blockTypeRaw:
		hist.appendKeep(b.data)
		return nil
	case blockTypeCompressed:
		// Decode directly into the history buffer: temporarily swap
		// hist.b into b.dst so decodeCompressed appends to it, then
		// swap back. saved restores b.dst's own storage afterwards.
		saved := b.dst
		b.dst = hist.b
		hist.b = nil
		err := b.decodeCompressed(hist)
		if debug {
			println("Decompressed to total", len(b.dst), "bytes, hash:", xxhash.Sum64(b.dst), "error:", err)
		}
		hist.b = b.dst
		b.dst = saved
		return err
	case blockTypeReserved:
		// Used for returning errors.
		return b.err
	case default // (unreachable marker removed; see below)
	default:
		panic("Invalid block type")
	}
}

// decodeCompressed will start decompressing a block.
// If no history is supplied the decoder will decodeAsync as much as possible
// before fetching from blockDec.history
func (b *blockDec) decodeCompressed(hist *history) error {
	in := b.data
	delayedHistory := hist == nil

	if delayedHistory {
		// We must always grab history.
		// If an error path returns before hist was fetched, drain the
		// channel anyway so the frame's handoff protocol stays in sync.
		defer func() {
			if hist == nil {
				<-b.history
			}
		}()
	}
	// There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header
	if len(in) < 2 {
		return ErrBlockTooSmall
	}
	// Literals section header:
	// bits 0-1: literals block type, bits 2-3: size format.
	litType := literalsBlockType(in[0] & 3)
	var litRegenSize int
	var litCompSize int
	sizeFormat := (in[0] >> 2) & 3
	var fourStreams bool
	switch litType {
	case literalsBlockRaw, literalsBlockRLE:
		switch sizeFormat {
		case 0, 2:
			// Regenerated_Size uses 5 bits (0-31). Literals_Section_Header uses 1 byte.
			litRegenSize = int(in[0] >> 3)
			in = in[1:]
		case 1:
			// Regenerated_Size uses 12 bits (0-4095). Literals_Section_Header uses 2 bytes.
			// in[1] is safe: len(in) >= 2 was checked above.
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4)
			in = in[2:]
		case 3:
			// Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes.
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12)
			in = in[3:]
		}
	case literalsBlockCompressed, literalsBlockTreeless:
		switch sizeFormat {
		case 0, 1:
			// Both Regenerated_Size and Compressed_Size use 10 bits (0-1023).
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12)
			litRegenSize = int(n & 1023)
			litCompSize = int(n >> 10)
			fourStreams = sizeFormat == 1
			in = in[3:]
		case 2:
			// 14-bit sizes, 4-byte header, always four huffman streams.
			fourStreams = true
			if len(in) < 4 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20)
			litRegenSize = int(n & 16383)
			litCompSize = int(n >> 14)
			in = in[4:]
		case 3:
			// 18-bit sizes, 5-byte header, always four huffman streams.
			fourStreams = true
			if len(in) < 5 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28)
			litRegenSize = int(n & 262143)
			litCompSize = int(n >> 18)
			in = in[5:]
		}
	}
	if debug {
		println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams)
	}
	var literals []byte
	var huff *huff0.Scratch
	switch litType {
	case literalsBlockRaw:
		if len(in) < litRegenSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize)
			return ErrBlockTooSmall
		}
		literals = in[:litRegenSize]
		in = in[litRegenSize:]
		//printf("Found %d uncompressed literals\n", litRegenSize)
	case literalsBlockRLE:
		if len(in) < 1 {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1)
			return ErrBlockTooSmall
		}
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, litRegenSize)
			} else {
				if litRegenSize > maxCompressedLiteralSize {
					// Exceptional
					b.literalBuf = make([]byte, litRegenSize)
				} else {
					b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize)

				}
			}
		}
		// Repeat the single literal byte litRegenSize times.
		literals = b.literalBuf[:litRegenSize]
		v := in[0]
		for i := range literals {
			literals[i] = v
		}
		in = in[1:]
		if debug {
			printf("Found %d RLE compressed literals\n", litRegenSize)
		}
	case literalsBlockTreeless:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		// Store compressed literals, so we defer decoding until we get history.
		literals = in[:litCompSize]
		in = in[litCompSize:]
		if debug {
			printf("Found %d compressed literals\n", litCompSize)
		}
	case literalsBlockCompressed:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		literals = in[:litCompSize]
		in = in[litCompSize:]
		huff = huffDecoderPool.Get().(*huff0.Scratch)
		var err error
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		if huff == nil {
			huff = &huff0.Scratch{}
		}
		huff, literals, err = huff0.ReadTable(literals, huff)
		if err != nil {
			println("reading huffman table:", err)
			return err
		}
		// Use our out buffer.
		// The three-index slice caps the output at litRegenSize so a
		// malformed stream cannot grow past the declared size.
		if fourStreams {
			literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
		} else {
			literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
		}
		if err != nil {
			println("decoding compressed literals:", err)
			return err
		}
		// Make sure we don't leak our literals buffer
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
		if debug {
			printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize)
		}
	}

	// Decode Sequences
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section
	if len(in) < 1 {
		return ErrBlockTooSmall
	}
	// Number_of_Sequences is 1-3 bytes depending on the first byte.
	seqHeader := in[0]
	nSeqs := 0
	switch {
	case seqHeader == 0:
		in = in[1:]
	case seqHeader < 128:
		nSeqs = int(seqHeader)
		in = in[1:]
	case seqHeader < 255:
		if len(in) < 2 {
			return ErrBlockTooSmall
		}
		nSeqs = int(seqHeader-128)<<8 | int(in[1])
		in = in[2:]
	case seqHeader == 255:
		if len(in) < 3 {
			return ErrBlockTooSmall
		}
		nSeqs = 0x7f00 + int(in[1]) + (int(in[2]) << 8)
		in = in[3:]
	}
	// Allocate sequences
	if cap(b.sequenceBuf) < nSeqs {
		if b.lowMem {
			b.sequenceBuf = make([]seq, nSeqs)
		} else {
			// Allocate max
			b.sequenceBuf = make([]seq, nSeqs, maxSequences)
		}
	} else {
		// Reuse buffer
		b.sequenceBuf = b.sequenceBuf[:nSeqs]
	}
	var seqs = &sequenceDecs{}
	if nSeqs > 0 {
		if len(in) < 1 {
			return ErrBlockTooSmall
		}
		br := byteReader{b: in, off: 0}
		compMode := br.Uint8()
		br.advance(1)
		if debug {
			printf("Compression modes: 0b%b", compMode)
		}
		// Three tables in stream order: literal lengths, offsets,
		// match lengths — each with a 2-bit compression mode.
		for i := uint(0); i < 3; i++ {
			mode := seqCompMode((compMode >> (6 - i*2)) & 3)
			if debug {
				println("Table", tableIndex(i), "is", mode)
			}
			var seq *sequenceDec
			switch tableIndex(i) {
			case tableLiteralLengths:
				seq = &seqs.litLengths
			case tableOffsets:
				seq = &seqs.offsets
			case tableMatchLengths:
				seq = &seqs.matchLengths
			default:
				panic("unknown table")
			}
			switch mode {
			case compModePredefined:
				// Use the spec-defined default distribution.
				seq.fse = &fsePredef[i]
			case compModeRLE:
				// A single symbol repeated for every sequence.
				if br.remain() < 1 {
					return ErrBlockTooSmall
				}
				v := br.Uint8()
				br.advance(1)
				dec := fseDecoderPool.Get().(*fseDecoder)
				symb, err := decSymbolValue(v, symbolTableX[i])
				if err != nil {
					printf("RLE Transform table (%v) error: %v", tableIndex(i), err)
					return err
				}
				dec.setRLE(symb)
				seq.fse = dec
				if debug {
					printf("RLE set to %+v, code: %v", symb, v)
				}
			case compModeFSE:
				// A full FSE table is encoded in the stream.
				println("Reading table for", tableIndex(i))
				dec := fseDecoderPool.Get().(*fseDecoder)
				err := dec.readNCount(&br, uint16(maxTableSymbol[i]))
				if err != nil {
					println("Read table error:", err)
					return err
				}
				err = dec.transform(symbolTableX[i])
				if err != nil {
					println("Transform table error:", err)
					return err
				}
				if debug {
					println("Read table ok", "symbolLen:", dec.symbolLen)
				}
				seq.fse = dec
			case compModeRepeat:
				// Reuse the table from the previous block (merged below).
				seq.repeat = true
			}
			if br.overread() {
				return io.ErrUnexpectedEOF
			}
		}
		in = br.unread()
	}

	// Wait for history.
	// All time spent after this is critical since it is strictly sequential.
	if hist == nil {
		hist = <-b.history
		if hist.error {
			return ErrDecoderClosed
		}
	}

	// Decode treeless literal block.
	if litType == literalsBlockTreeless {
		// TODO: We could send the history early WITHOUT the stream history.
		// This would allow decoding treeless literals before the byte history is available.
		// Silencia stats: Treeless 4393, with: 32775, total: 37168, 11% treeless.
		// So not much obvious gain here.

		if hist.huffTree == nil {
			return errors.New("literal block was treeless, but no history was defined")
		}
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		var err error
		// Use our out buffer.
		huff = hist.huffTree
		if fourStreams {
			literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
		} else {
			literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
		}
		// Make sure we don't leak our literals buffer
		if err != nil {
			println("decompressing literals:", err)
			return err
		}
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
	} else {
		// Not treeless: a stale tree in history is returned to the pool,
		// unless it belongs to a dictionary (which owns its encoder).
		if hist.huffTree != nil && huff != nil {
			if hist.dict == nil || hist.dict.litEnc != hist.huffTree {
				huffDecoderPool.Put(hist.huffTree)
			}
			hist.huffTree = nil
		}
	}
	if huff != nil {
		// Keep the tree for possible treeless blocks that follow.
		hist.huffTree = huff
	}
	if debug {
		println("Final literals:", len(literals), "hash:", xxhash.Sum64(literals), "and", nSeqs, "sequences.")
	}

	if nSeqs == 0 {
		// Decompressed content is defined entirely as Literals Section content.
		b.dst = append(b.dst, literals...)
		if delayedHistory {
			hist.append(literals)
		}
		return nil
	}

	seqs, err := seqs.mergeHistory(&hist.decoders)
	if err != nil {
		return err
	}
	if debug {
		println("History merged ok")
	}
	br := &bitReader{}
	if err := br.init(in); err != nil {
		return err
	}

	// TODO: Investigate if sending history without decoders are faster.
	// This would allow the sequences to be decoded async and only have to construct stream history.
	// If only recent offsets were not transferred, this would be an obvious win.
	// Also, if first 3 sequences don't reference recent offsets, all sequences can be decoded.

	// Trim history to the window; matches cannot reach further back.
	hbytes := hist.b
	if len(hbytes) > hist.windowSize {
		hbytes = hbytes[len(hbytes)-hist.windowSize:]
		// We do not need history any more.
		if hist.dict != nil {
			hist.dict.content = nil
		}
	}

	if err := seqs.initialize(br, hist, literals, b.dst); err != nil {
		println("initializing sequences:", err)
		return err
	}

	err = seqs.decode(nSeqs, br, hbytes)
	if err != nil {
		return err
	}
	if !br.finished() {
		return fmt.Errorf("%d extra bits on block, should be 0", br.remain())
	}

	err = br.close()
	if err != nil {
		printf("Closing sequences: %v, %+v\n", err, *br)
	}
	if len(b.data) > maxCompressedBlockSize {
		return fmt.Errorf("compressed block size too large (%d)", len(b.data))
	}
	// Set output and release references.
	b.dst = seqs.out
	seqs.out, seqs.literals, seqs.hist = nil, nil, nil

	if !delayedHistory {
		// If we don't have delayed history, no need to update.
		hist.recentOffsets = seqs.prevOffset
		return nil
	}
	if b.Last {
		// if last block we don't care about history.
		println("Last block, no history returned")
		hist.b = hist.b[:0]
		return nil
	}
	hist.append(b.dst)
	hist.recentOffsets = seqs.prevOffset
	if debug {
		println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.")
	}

	return nil
}