1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"errors"
9	"fmt"
10	"io"
11	"sync"
12
13	"github.com/klauspost/compress/huff0"
14	"github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
// blockType is the 2-bit Block_Type field from a block header.
type blockType uint8

//go:generate stringer -type=blockType,literalsBlockType,seqCompMode,tableIndex

const (
	blockTypeRaw blockType = iota
	blockTypeRLE
	blockTypeCompressed
	// blockTypeReserved is invalid in a stream; it is also used
	// internally to carry errors (see sendErr/startDecoder).
	blockTypeReserved
)
27
// literalsBlockType is the 2-bit Literals_Block_Type field from the
// literals section header.
type literalsBlockType uint8

const (
	literalsBlockRaw        literalsBlockType = iota // literals stored as-is
	literalsBlockRLE                                 // one byte, repeated litRegenSize times
	literalsBlockCompressed                          // Huffman compressed, table included in block
	literalsBlockTreeless                            // Huffman compressed, reuses table from history
)
36
const (
	// maxCompressedBlockSize is the biggest allowed compressed block size (128KB)
	maxCompressedBlockSize = 128 << 10

	// Maximum possible block size (all Raw+Uncompressed).
	maxBlockSize = (1 << 21) - 1

	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#literals_section_header
	maxCompressedLiteralSize = 1 << 18
	maxRLELiteralSize        = 1 << 20
	maxMatchLen              = 131074
	// maxSequences matches the largest count encodable in the
	// Sequences_Section header (seqHeader==255 case: 0x7f00 + 16-bit value).
	maxSequences = 0x7f00 + 0xffff

	// We support slightly less than the reference decoder to be able to
	// use ints on 32 bit archs.
	maxOffsetBits = 30
)
54
var (
	// huffDecoderPool recycles huff0 scratch space used when decoding
	// compressed literals (see decodeCompressed).
	huffDecoderPool = sync.Pool{New: func() interface{} {
		return &huff0.Scratch{}
	}}

	// fseDecoderPool recycles FSE decoders built for the sequence
	// section tables (RLE and FSE compression modes).
	fseDecoderPool = sync.Pool{New: func() interface{} {
		return &fseDecoder{}
	}}
)
64
// blockDec decodes a single block and is reused between blocks via reset.
// It can run asynchronously (startDecoder goroutine, fed through the
// channels below) or synchronously via decodeBuf.
type blockDec struct {
	// Raw source data of the block.
	data        []byte
	dataStorage []byte // backing storage that data may point into (see reset)

	// Destination of the decoded data.
	dst []byte

	// Buffer for literals data.
	literalBuf []byte

	// Window size of the block.
	WindowSize uint64

	// Channels used by the async decoder goroutine: input signals a
	// block is ready, history supplies/returns stream history and
	// result delivers the decoded output.
	history     chan *history
	input       chan struct{}
	result      chan decodeOutput
	sequenceBuf []seq          // scratch sized in decodeCompressed — presumably for sequence reuse; confirm usage
	err         error          // error delivered when Type is blockTypeReserved (see sendErr)
	decWG       sync.WaitGroup // tracks the startDecoder goroutine

	// Frame to use for singlethreaded decoding.
	// Should not be used by the decoder itself since parent may be another frame.
	localFrame *frameDec

	// Block is RLE, this is the size.
	RLESize uint32
	tmp     [4]byte

	Type blockType

	// Is this the last block of a frame?
	Last bool

	// Use less memory
	lowMem bool
}
102
103func (b *blockDec) String() string {
104	if b == nil {
105		return "<nil>"
106	}
107	return fmt.Sprintf("Steam Size: %d, Type: %v, Last: %t, Window: %d", len(b.data), b.Type, b.Last, b.WindowSize)
108}
109
110func newBlockDec(lowMem bool) *blockDec {
111	b := blockDec{
112		lowMem:  lowMem,
113		result:  make(chan decodeOutput, 1),
114		input:   make(chan struct{}, 1),
115		history: make(chan *history, 1),
116	}
117	b.decWG.Add(1)
118	go b.startDecoder()
119	return &b
120}
121
// reset will reset the block.
// Input must be a start of a block and will be at the end of the block when returned.
func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
	b.WindowSize = windowSize
	// The block header is always 3 bytes, little endian:
	// Last_Block (1 bit), Block_Type (2 bits), Block_Size (21 bits).
	tmp := br.readSmall(3)
	if tmp == nil {
		if debug {
			println("Reading block header:", io.ErrUnexpectedEOF)
		}
		return io.ErrUnexpectedEOF
	}
	bh := uint32(tmp[0]) | (uint32(tmp[1]) << 8) | (uint32(tmp[2]) << 16)
	b.Last = bh&1 != 0
	b.Type = blockType((bh >> 1) & 3)
	// find size.
	cSize := int(bh >> 3)
	maxSize := maxBlockSize
	switch b.Type {
	case blockTypeReserved:
		return ErrReservedBlockType
	case blockTypeRLE:
		// For RLE blocks the header size field is the regenerated size;
		// only a single payload byte is read from the stream.
		b.RLESize = uint32(cSize)
		if b.lowMem {
			maxSize = cSize
		}
		cSize = 1
	case blockTypeCompressed:
		if debug {
			println("Data size on stream:", cSize)
		}
		b.RLESize = 0
		maxSize = maxCompressedBlockSize
		if windowSize < maxCompressedBlockSize && b.lowMem {
			// Output can never exceed the window size.
			maxSize = int(windowSize)
		}
		if cSize > maxCompressedBlockSize || uint64(cSize) > b.WindowSize {
			if debug {
				printf("compressed block too big: csize:%d block: %+v\n", uint64(cSize), b)
			}
			return ErrCompressedSizeTooBig
		}
	case blockTypeRaw:
		b.RLESize = 0
		// We do not need a destination for raw blocks.
		maxSize = -1
	default:
		panic("Invalid block type")
	}

	// Read block data.
	if cap(b.dataStorage) < cSize {
		if b.lowMem {
			b.dataStorage = make([]byte, 0, cSize)
		} else {
			// Allocate the maximum once so later blocks can reuse it.
			b.dataStorage = make([]byte, 0, maxBlockSize)
		}
	}
	// Ensure the destination can hold the (worst-case) decoded block.
	// maxSize is -1 for raw blocks, which skips allocation.
	if cap(b.dst) <= maxSize {
		b.dst = make([]byte, 0, maxSize+1)
	}
	var err error
	b.data, err = br.readBig(cSize, b.dataStorage)
	if err != nil {
		if debug {
			println("Reading block:", err, "(", cSize, ")", len(b.data))
			printf("%T", br)
		}
		return err
	}
	return nil
}
193
194// sendEOF will make the decoder send EOF on this frame.
195func (b *blockDec) sendErr(err error) {
196	b.Last = true
197	b.Type = blockTypeReserved
198	b.err = err
199	b.input <- struct{}{}
200}
201
// Close will release resources.
// Closed blockDec cannot be reset.
// Closing input ends startDecoder's range loop; decWG.Wait then blocks
// until that goroutine has exited.
// NOTE(review): result/history are closed before waiting — this assumes
// no block is in flight when Close is called; confirm callers uphold that.
func (b *blockDec) Close() {
	close(b.input)
	close(b.history)
	close(b.result)
	b.decWG.Wait()
}
210
// startDecoder runs in its own goroutine (started by newBlockDec) and
// decodes one block per signal received on b.input.
// For each input it consumes exactly one history from b.history and
// delivers exactly one decodeOutput on b.result, keeping output and
// history separate. The loop exits when b.input is closed (see Close).
func (b *blockDec) startDecoder() {
	defer b.decWG.Done()
	for range b.input {
		//println("blockDec: Got block input")
		switch b.Type {
		case blockTypeRLE:
			// Expand the single payload byte to RLESize bytes.
			if cap(b.dst) < int(b.RLESize) {
				if b.lowMem {
					b.dst = make([]byte, b.RLESize)
				} else {
					b.dst = make([]byte, maxBlockSize)
				}
			}
			o := decodeOutput{
				d:   b,
				b:   b.dst[:b.RLESize],
				err: nil,
			}
			v := b.data[0]
			for i := range o.b {
				o.b[i] = v
			}
			// History is fetched only after output is ready, so the
			// critical sequential section stays as short as possible.
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeRaw:
			// Raw blocks pass the input data through unchanged.
			o := decodeOutput{
				d:   b,
				b:   b.data,
				err: nil,
			}
			hist := <-b.history
			hist.append(o.b)
			b.result <- o
		case blockTypeCompressed:
			b.dst = b.dst[:0]
			// decodeCompressed fetches history itself (hist == nil).
			err := b.decodeCompressed(nil)
			o := decodeOutput{
				d:   b,
				b:   b.dst,
				err: err,
			}
			if debug {
				println("Decompressed to", len(b.dst), "bytes, error:", err)
			}
			b.result <- o
		case blockTypeReserved:
			// Used for returning errors.
			<-b.history
			b.result <- decodeOutput{
				d:   b,
				b:   nil,
				err: b.err,
			}
		default:
			panic("Invalid block type")
		}
		if debug {
			println("blockDec: Finished block")
		}
	}
}
275
276// decodeAsync will prepare decoding the block when it receives the history.
277// If history is provided, it will not fetch it from the channel.
278func (b *blockDec) decodeBuf(hist *history) error {
279	switch b.Type {
280	case blockTypeRLE:
281		if cap(b.dst) < int(b.RLESize) {
282			if b.lowMem {
283				b.dst = make([]byte, b.RLESize)
284			} else {
285				b.dst = make([]byte, maxBlockSize)
286			}
287		}
288		b.dst = b.dst[:b.RLESize]
289		v := b.data[0]
290		for i := range b.dst {
291			b.dst[i] = v
292		}
293		hist.appendKeep(b.dst)
294		return nil
295	case blockTypeRaw:
296		hist.appendKeep(b.data)
297		return nil
298	case blockTypeCompressed:
299		saved := b.dst
300		b.dst = hist.b
301		hist.b = nil
302		err := b.decodeCompressed(hist)
303		if debug {
304			println("Decompressed to total", len(b.dst), "bytes, hash:", xxhash.Sum64(b.dst), "error:", err)
305		}
306		hist.b = b.dst
307		b.dst = saved
308		return err
309	case blockTypeReserved:
310		// Used for returning errors.
311		return b.err
312	default:
313		panic("Invalid block type")
314	}
315}
316
// decodeCompressed will start decompressing a block.
// If no history is supplied the decoder will decodeAsync as much as possible
// before fetching from blockDec.history
func (b *blockDec) decodeCompressed(hist *history) error {
	in := b.data
	delayedHistory := hist == nil

	if delayedHistory {
		// We must always grab history.
		defer func() {
			if hist == nil {
				<-b.history
			}
		}()
	}
	// There must be at least one byte for Literals_Block_Type and one for Sequences_Section_Header
	if len(in) < 2 {
		return ErrBlockTooSmall
	}
	// Parse the Literals_Section_Header: type (2 bits), size format
	// (2 bits), then regenerated/compressed sizes whose layout depends
	// on both.
	litType := literalsBlockType(in[0] & 3)
	var litRegenSize int
	var litCompSize int
	sizeFormat := (in[0] >> 2) & 3
	var fourStreams bool
	switch litType {
	case literalsBlockRaw, literalsBlockRLE:
		switch sizeFormat {
		case 0, 2:
			// Regenerated_Size uses 5 bits (0-31). Literals_Section_Header uses 1 byte.
			litRegenSize = int(in[0] >> 3)
			in = in[1:]
		case 1:
			// Regenerated_Size uses 12 bits (0-4095). Literals_Section_Header uses 2 bytes.
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4)
			in = in[2:]
		case 3:
			//  Regenerated_Size uses 20 bits (0-1048575). Literals_Section_Header uses 3 bytes.
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			litRegenSize = int(in[0]>>4) + (int(in[1]) << 4) + (int(in[2]) << 12)
			in = in[3:]
		}
	case literalsBlockCompressed, literalsBlockTreeless:
		switch sizeFormat {
		case 0, 1:
			// Both Regenerated_Size and Compressed_Size use 10 bits (0-1023).
			if len(in) < 3 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12)
			litRegenSize = int(n & 1023)
			litCompSize = int(n >> 10)
			// Size format 1 means the literals are split into 4 Huffman streams.
			fourStreams = sizeFormat == 1
			in = in[3:]
		case 2:
			fourStreams = true
			if len(in) < 4 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20)
			litRegenSize = int(n & 16383)
			litCompSize = int(n >> 14)
			in = in[4:]
		case 3:
			fourStreams = true
			if len(in) < 5 {
				println("too small: litType:", litType, " sizeFormat", sizeFormat, len(in))
				return ErrBlockTooSmall
			}
			n := uint64(in[0]>>4) + (uint64(in[1]) << 4) + (uint64(in[2]) << 12) + (uint64(in[3]) << 20) + (uint64(in[4]) << 28)
			litRegenSize = int(n & 262143)
			litCompSize = int(n >> 18)
			in = in[5:]
		}
	}
	if debug {
		println("literals type:", litType, "litRegenSize:", litRegenSize, "litCompSize:", litCompSize, "sizeFormat:", sizeFormat, "4X:", fourStreams)
	}
	// Read or decompress the literals themselves.
	var literals []byte
	var huff *huff0.Scratch
	switch litType {
	case literalsBlockRaw:
		if len(in) < litRegenSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litRegenSize)
			return ErrBlockTooSmall
		}
		// Reference the input directly; no copy needed for raw literals.
		literals = in[:litRegenSize]
		in = in[litRegenSize:]
		//printf("Found %d uncompressed literals\n", litRegenSize)
	case literalsBlockRLE:
		if len(in) < 1 {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", 1)
			return ErrBlockTooSmall
		}
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, litRegenSize)
			} else {
				if litRegenSize > maxCompressedLiteralSize {
					// Exceptional
					b.literalBuf = make([]byte, litRegenSize)
				} else {
					b.literalBuf = make([]byte, litRegenSize, maxCompressedLiteralSize)

				}
			}
		}
		// Expand the single byte into the literal buffer.
		literals = b.literalBuf[:litRegenSize]
		v := in[0]
		for i := range literals {
			literals[i] = v
		}
		in = in[1:]
		if debug {
			printf("Found %d RLE compressed literals\n", litRegenSize)
		}
	case literalsBlockTreeless:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		// Store compressed literals, so we defer decoding until we get history.
		literals = in[:litCompSize]
		in = in[litCompSize:]
		if debug {
			printf("Found %d compressed literals\n", litCompSize)
		}
	case literalsBlockCompressed:
		if len(in) < litCompSize {
			println("too small: litType:", litType, " sizeFormat", sizeFormat, "remain:", len(in), "want:", litCompSize)
			return ErrBlockTooSmall
		}
		literals = in[:litCompSize]
		in = in[litCompSize:]
		// Reuse pooled Huffman scratch; the table is read from the block.
		huff = huffDecoderPool.Get().(*huff0.Scratch)
		var err error
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		// Defensive: the pool's New function should make this impossible.
		if huff == nil {
			huff = &huff0.Scratch{}
		}
		huff, literals, err = huff0.ReadTable(literals, huff)
		if err != nil {
			println("reading huffman table:", err)
			return err
		}
		// Use our out buffer.
		if fourStreams {
			literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
		} else {
			literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
		}
		if err != nil {
			println("decoding compressed literals:", err)
			return err
		}
		// Make sure we don't leak our literals buffer
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
		if debug {
			printf("Decompressed %d literals into %d bytes\n", litCompSize, litRegenSize)
		}
	}

	// Decode Sequences
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#sequences-section
	if len(in) < 1 {
		return ErrBlockTooSmall
	}
	// Number_of_Sequences uses 1-3 bytes depending on the first byte.
	seqHeader := in[0]
	nSeqs := 0
	switch {
	case seqHeader == 0:
		in = in[1:]
	case seqHeader < 128:
		nSeqs = int(seqHeader)
		in = in[1:]
	case seqHeader < 255:
		if len(in) < 2 {
			return ErrBlockTooSmall
		}
		nSeqs = int(seqHeader-128)<<8 | int(in[1])
		in = in[2:]
	case seqHeader == 255:
		if len(in) < 3 {
			return ErrBlockTooSmall
		}
		nSeqs = 0x7f00 + int(in[1]) + (int(in[2]) << 8)
		in = in[3:]
	}
	// Allocate sequences
	if cap(b.sequenceBuf) < nSeqs {
		if b.lowMem {
			b.sequenceBuf = make([]seq, nSeqs)
		} else {
			// Allocate max
			b.sequenceBuf = make([]seq, nSeqs, maxSequences)
		}
	} else {
		// Reuse buffer
		b.sequenceBuf = b.sequenceBuf[:nSeqs]
	}
	var seqs = &sequenceDecs{}
	if nSeqs > 0 {
		if len(in) < 1 {
			return ErrBlockTooSmall
		}
		br := byteReader{b: in, off: 0}
		compMode := br.Uint8()
		br.advance(1)
		if debug {
			printf("Compression modes: 0b%b", compMode)
		}
		// Three table descriptors, 2 bits each, high to low:
		// literal lengths, offsets, match lengths.
		for i := uint(0); i < 3; i++ {
			mode := seqCompMode((compMode >> (6 - i*2)) & 3)
			if debug {
				println("Table", tableIndex(i), "is", mode)
			}
			var seq *sequenceDec
			switch tableIndex(i) {
			case tableLiteralLengths:
				seq = &seqs.litLengths
			case tableOffsets:
				seq = &seqs.offsets
			case tableMatchLengths:
				seq = &seqs.matchLengths
			default:
				panic("unknown table")
			}
			switch mode {
			case compModePredefined:
				// Use the spec-defined default distribution.
				seq.fse = &fsePredef[i]
			case compModeRLE:
				// A single symbol repeated for every sequence.
				if br.remain() < 1 {
					return ErrBlockTooSmall
				}
				v := br.Uint8()
				br.advance(1)
				dec := fseDecoderPool.Get().(*fseDecoder)
				symb, err := decSymbolValue(v, symbolTableX[i])
				if err != nil {
					printf("RLE Transform table (%v) error: %v", tableIndex(i), err)
					return err
				}
				dec.setRLE(symb)
				seq.fse = dec
				if debug {
					printf("RLE set to %+v, code: %v", symb, v)
				}
			case compModeFSE:
				// A full FSE table follows in the stream.
				println("Reading table for", tableIndex(i))
				dec := fseDecoderPool.Get().(*fseDecoder)
				err := dec.readNCount(&br, uint16(maxTableSymbol[i]))
				if err != nil {
					println("Read table error:", err)
					return err
				}
				err = dec.transform(symbolTableX[i])
				if err != nil {
					println("Transform table error:", err)
					return err
				}
				if debug {
					println("Read table ok", "symbolLen:", dec.symbolLen)
				}
				seq.fse = dec
			case compModeRepeat:
				// Reuse the table from the previous block (resolved
				// in mergeHistory below).
				seq.repeat = true
			}
			if br.overread() {
				return io.ErrUnexpectedEOF
			}
		}
		in = br.unread()
	}

	// Wait for history.
	// All time spent after this is critical since it is strictly sequential.
	if hist == nil {
		hist = <-b.history
		if hist.error {
			return ErrDecoderClosed
		}
	}

	// Decode treeless literal block.
	if litType == literalsBlockTreeless {
		// TODO: We could send the history early WITHOUT the stream history.
		//   This would allow decoding treeless literals before the byte history is available.
		//   Silencia stats: Treeless 4393, with: 32775, total: 37168, 11% treeless.
		//   So not much obvious gain here.

		if hist.huffTree == nil {
			return errors.New("literal block was treeless, but no history was defined")
		}
		// Ensure we have space to store it.
		if cap(b.literalBuf) < litRegenSize {
			if b.lowMem {
				b.literalBuf = make([]byte, 0, litRegenSize)
			} else {
				b.literalBuf = make([]byte, 0, maxCompressedLiteralSize)
			}
		}
		var err error
		// Use our out buffer.
		huff = hist.huffTree
		if fourStreams {
			literals, err = huff.Decoder().Decompress4X(b.literalBuf[:0:litRegenSize], literals)
		} else {
			literals, err = huff.Decoder().Decompress1X(b.literalBuf[:0:litRegenSize], literals)
		}
		// Make sure we don't leak our literals buffer
		if err != nil {
			println("decompressing literals:", err)
			return err
		}
		if len(literals) != litRegenSize {
			return fmt.Errorf("literal output size mismatch want %d, got %d", litRegenSize, len(literals))
		}
	} else {
		// The previous table is no longer needed; return it to the
		// pool unless it belongs to a dictionary.
		if hist.huffTree != nil && huff != nil {
			if hist.dict == nil || hist.dict.litEnc != hist.huffTree {
				huffDecoderPool.Put(hist.huffTree)
			}
			hist.huffTree = nil
		}
	}
	// Keep the table for treeless blocks that may follow.
	if huff != nil {
		hist.huffTree = huff
	}
	if debug {
		println("Final literals:", len(literals), "hash:", xxhash.Sum64(literals), "and", nSeqs, "sequences.")
	}

	if nSeqs == 0 {
		// Decompressed content is defined entirely as Literals Section content.
		b.dst = append(b.dst, literals...)
		if delayedHistory {
			hist.append(literals)
		}
		return nil
	}

	// Resolve "repeat" tables and recent offsets from history.
	seqs, err := seqs.mergeHistory(&hist.decoders)
	if err != nil {
		return err
	}
	if debug {
		println("History merged ok")
	}
	br := &bitReader{}
	if err := br.init(in); err != nil {
		return err
	}

	// TODO: Investigate if sending history without decoders are faster.
	//   This would allow the sequences to be decoded async and only have to construct stream history.
	//   If only recent offsets were not transferred, this would be an obvious win.
	// 	 Also, if first 3 sequences don't reference recent offsets, all sequences can be decoded.

	// Only the most recent window of history is addressable by matches.
	hbytes := hist.b
	if len(hbytes) > hist.windowSize {
		hbytes = hbytes[len(hbytes)-hist.windowSize:]
		// We do not need history any more.
		if hist.dict != nil {
			hist.dict.content = nil
		}
	}

	if err := seqs.initialize(br, hist, literals, b.dst); err != nil {
		println("initializing sequences:", err)
		return err
	}

	err = seqs.decode(nSeqs, br, hbytes)
	if err != nil {
		return err
	}
	if !br.finished() {
		return fmt.Errorf("%d extra bits on block, should be 0", br.remain())
	}

	err = br.close()
	if err != nil {
		printf("Closing sequences: %v, %+v\n", err, *br)
	}
	if len(b.data) > maxCompressedBlockSize {
		return fmt.Errorf("compressed block size too large (%d)", len(b.data))
	}
	// Set output and release references.
	b.dst = seqs.out
	seqs.out, seqs.literals, seqs.hist = nil, nil, nil

	if !delayedHistory {
		// If we don't have delayed history, no need to update.
		hist.recentOffsets = seqs.prevOffset
		return nil
	}
	if b.Last {
		// if last block we don't care about history.
		println("Last block, no history returned")
		hist.b = hist.b[:0]
		return nil
	}
	hist.append(b.dst)
	hist.recentOffsets = seqs.prevOffset
	if debug {
		println("Finished block with literals:", len(literals), "and", nSeqs, "sequences.")
	}

	return nil
}
740