1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"bytes"
9	"encoding/hex"
10	"errors"
11	"hash"
12	"io"
13	"sync"
14
15	"github.com/klauspost/compress/zstd/internal/xxhash"
16)
17
18type frameDec struct {
19	o      decoderOptions
20	crc    hash.Hash64
21	offset int64
22
23	WindowSize uint64
24
25	// maxWindowSize is the maximum windows size to support.
26	// should never be bigger than max-int.
27	maxWindowSize uint64
28
29	// In order queue of blocks being decoded.
30	decoding chan *blockDec
31
32	// Frame history passed between blocks
33	history history
34
35	rawInput byteBuffer
36
37	// Byte buffer that can be reused for small input blocks.
38	bBuf byteBuf
39
40	FrameContentSize uint64
41	frameDone        sync.WaitGroup
42
43	DictionaryID  *uint32
44	HasCheckSum   bool
45	SingleSegment bool
46
47	// asyncRunning indicates whether the async routine processes input on 'decoding'.
48	asyncRunningMu sync.Mutex
49	asyncRunning   bool
50}
51
// Bounds for the window size a frame is allowed to declare.
const (
	// MinWindowSize is the minimum Window_Size, 1 KB.
	MinWindowSize = 1024

	// MaxWindowSize is the largest window size accepted by default, 512 MB.
	MaxWindowSize = 512 << 20
)
57
var (
	// frameMagic is the four byte magic number that starts a regular frame.
	frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd}

	// skippableFrameMagic holds the last three bytes of the magic number of a
	// skippable frame. The first byte is matched separately: its upper four
	// bits must be 0x5.
	skippableFrameMagic = []byte{0x2a, 0x4d, 0x18}
)
62
63func newFrameDec(o decoderOptions) *frameDec {
64	d := frameDec{
65		o:             o,
66		maxWindowSize: MaxWindowSize,
67	}
68	if d.maxWindowSize > o.maxDecodedSize {
69		d.maxWindowSize = o.maxDecodedSize
70	}
71	return &d
72}
73
74// reset will read the frame header and prepare for block decoding.
75// If nothing can be read from the input, io.EOF will be returned.
76// Any other error indicated that the stream contained data, but
77// there was a problem.
78func (d *frameDec) reset(br byteBuffer) error {
79	d.HasCheckSum = false
80	d.WindowSize = 0
81	var b []byte
82	for {
83		b = br.readSmall(4)
84		if b == nil {
85			return io.EOF
86		}
87		if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
88			if debug {
89				println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
90			}
91			// Break if not skippable frame.
92			break
93		}
94		// Read size to skip
95		b = br.readSmall(4)
96		if b == nil {
97			println("Reading Frame Size EOF")
98			return io.ErrUnexpectedEOF
99		}
100		n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
101		println("Skipping frame with", n, "bytes.")
102		err := br.skipN(int(n))
103		if err != nil {
104			if debug {
105				println("Reading discarded frame", err)
106			}
107			return err
108		}
109	}
110	if !bytes.Equal(b, frameMagic) {
111		println("Got magic numbers: ", b, "want:", frameMagic)
112		return ErrMagicMismatch
113	}
114
115	// Read Frame_Header_Descriptor
116	fhd, err := br.readByte()
117	if err != nil {
118		println("Reading Frame_Header_Descriptor", err)
119		return err
120	}
121	d.SingleSegment = fhd&(1<<5) != 0
122
123	if fhd&(1<<3) != 0 {
124		return errors.New("Reserved bit set on frame header")
125	}
126
127	// Read Window_Descriptor
128	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
129	d.WindowSize = 0
130	if !d.SingleSegment {
131		wd, err := br.readByte()
132		if err != nil {
133			println("Reading Window_Descriptor", err)
134			return err
135		}
136		printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
137		windowLog := 10 + (wd >> 3)
138		windowBase := uint64(1) << windowLog
139		windowAdd := (windowBase / 8) * uint64(wd&0x7)
140		d.WindowSize = windowBase + windowAdd
141	}
142
143	// Read Dictionary_ID
144	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
145	d.DictionaryID = nil
146	if size := fhd & 3; size != 0 {
147		if size == 3 {
148			size = 4
149		}
150		b = br.readSmall(int(size))
151		if b == nil {
152			if debug {
153				println("Reading Dictionary_ID", io.ErrUnexpectedEOF)
154			}
155			return io.ErrUnexpectedEOF
156		}
157		var id uint32
158		switch size {
159		case 1:
160			id = uint32(b[0])
161		case 2:
162			id = uint32(b[0]) | (uint32(b[1]) << 8)
163		case 4:
164			id = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
165		}
166		if debug {
167			println("Dict size", size, "ID:", id)
168		}
169		if id > 0 {
170			// ID 0 means "sorry, no dictionary anyway".
171			// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format
172			d.DictionaryID = &id
173		}
174	}
175
176	// Read Frame_Content_Size
177	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
178	var fcsSize int
179	v := fhd >> 6
180	switch v {
181	case 0:
182		if d.SingleSegment {
183			fcsSize = 1
184		}
185	default:
186		fcsSize = 1 << v
187	}
188	d.FrameContentSize = 0
189	if fcsSize > 0 {
190		b := br.readSmall(fcsSize)
191		if b == nil {
192			println("Reading Frame content", io.ErrUnexpectedEOF)
193			return io.ErrUnexpectedEOF
194		}
195		switch fcsSize {
196		case 1:
197			d.FrameContentSize = uint64(b[0])
198		case 2:
199			// When FCS_Field_Size is 2, the offset of 256 is added.
200			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
201		case 4:
202			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
203		case 8:
204			d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
205			d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
206			d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
207		}
208		if debug {
209			println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize)
210		}
211	}
212	// Move this to shared.
213	d.HasCheckSum = fhd&(1<<2) != 0
214	if d.HasCheckSum {
215		if d.crc == nil {
216			d.crc = xxhash.New()
217		}
218		d.crc.Reset()
219	}
220
221	if d.WindowSize == 0 && d.SingleSegment {
222		// We may not need window in this case.
223		d.WindowSize = d.FrameContentSize
224		if d.WindowSize < MinWindowSize {
225			d.WindowSize = MinWindowSize
226		}
227	}
228
229	if d.WindowSize > d.maxWindowSize {
230		printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
231		return ErrWindowSizeExceeded
232	}
233	// The minimum Window_Size is 1 KB.
234	if d.WindowSize < MinWindowSize {
235		println("got window size: ", d.WindowSize)
236		return ErrWindowSizeTooSmall
237	}
238	d.history.windowSize = int(d.WindowSize)
239	if d.o.lowMem && d.history.windowSize < maxBlockSize {
240		d.history.maxSize = d.history.windowSize * 2
241	} else {
242		d.history.maxSize = d.history.windowSize + maxBlockSize
243	}
244	// history contains input - maybe we do something
245	d.rawInput = br
246	return nil
247}
248
249// next will start decoding the next block from stream.
250func (d *frameDec) next(block *blockDec) error {
251	if debug {
252		printf("decoding new block %p:%p", block, block.data)
253	}
254	err := block.reset(d.rawInput, d.WindowSize)
255	if err != nil {
256		println("block error:", err)
257		// Signal the frame decoder we have a problem.
258		d.sendErr(block, err)
259		return err
260	}
261	block.input <- struct{}{}
262	if debug {
263		println("next block:", block)
264	}
265	d.asyncRunningMu.Lock()
266	defer d.asyncRunningMu.Unlock()
267	if !d.asyncRunning {
268		return nil
269	}
270	if block.Last {
271		// We indicate the frame is done by sending io.EOF
272		d.decoding <- block
273		return io.EOF
274	}
275	d.decoding <- block
276	return nil
277}
278
279// sendEOF will queue an error block on the frame.
280// This will cause the frame decoder to return when it encounters the block.
281// Returns true if the decoder was added.
282func (d *frameDec) sendErr(block *blockDec, err error) bool {
283	d.asyncRunningMu.Lock()
284	defer d.asyncRunningMu.Unlock()
285	if !d.asyncRunning {
286		return false
287	}
288
289	println("sending error", err.Error())
290	block.sendErr(err)
291	d.decoding <- block
292	return true
293}
294
295// checkCRC will check the checksum if the frame has one.
296// Will return ErrCRCMismatch if crc check failed, otherwise nil.
297func (d *frameDec) checkCRC() error {
298	if !d.HasCheckSum {
299		return nil
300	}
301	var tmp [4]byte
302	got := d.crc.Sum64()
303	// Flip to match file order.
304	tmp[0] = byte(got >> 0)
305	tmp[1] = byte(got >> 8)
306	tmp[2] = byte(got >> 16)
307	tmp[3] = byte(got >> 24)
308
309	// We can overwrite upper tmp now
310	want := d.rawInput.readSmall(4)
311	if want == nil {
312		println("CRC missing?")
313		return io.ErrUnexpectedEOF
314	}
315
316	if !bytes.Equal(tmp[:], want) {
317		if debug {
318			println("CRC Check Failed:", tmp[:], "!=", want)
319		}
320		return ErrCRCMismatch
321	}
322	if debug {
323		println("CRC ok", tmp[:])
324	}
325	return nil
326}
327
328func (d *frameDec) initAsync() {
329	if !d.o.lowMem && !d.SingleSegment {
330		// set max extra size history to 10MB.
331		d.history.maxSize = d.history.windowSize + maxBlockSize*5
332	}
333	// re-alloc if more than one extra block size.
334	if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
335		d.history.b = make([]byte, 0, d.history.maxSize)
336	}
337	if cap(d.history.b) < d.history.maxSize {
338		d.history.b = make([]byte, 0, d.history.maxSize)
339	}
340	if cap(d.decoding) < d.o.concurrent {
341		d.decoding = make(chan *blockDec, d.o.concurrent)
342	}
343	if debug {
344		h := d.history
345		printf("history init. len: %d, cap: %d", len(h.b), cap(h.b))
346	}
347	d.asyncRunningMu.Lock()
348	d.asyncRunning = true
349	d.asyncRunningMu.Unlock()
350}
351
352// startDecoder will start decoding blocks and write them to the writer.
353// The decoder will stop as soon as an error occurs or at end of frame.
354// When the frame has finished decoding the *bufio.Reader
355// containing the remaining input will be sent on frameDec.frameDone.
356func (d *frameDec) startDecoder(output chan decodeOutput) {
357	written := int64(0)
358
359	defer func() {
360		d.asyncRunningMu.Lock()
361		d.asyncRunning = false
362		d.asyncRunningMu.Unlock()
363
364		// Drain the currently decoding.
365		d.history.error = true
366	flushdone:
367		for {
368			select {
369			case b := <-d.decoding:
370				b.history <- &d.history
371				output <- <-b.result
372			default:
373				break flushdone
374			}
375		}
376		println("frame decoder done, signalling done")
377		d.frameDone.Done()
378	}()
379	// Get decoder for first block.
380	block := <-d.decoding
381	block.history <- &d.history
382	for {
383		var next *blockDec
384		// Get result
385		r := <-block.result
386		if r.err != nil {
387			println("Result contained error", r.err)
388			output <- r
389			return
390		}
391		if debug {
392			println("got result, from ", d.offset, "to", d.offset+int64(len(r.b)))
393			d.offset += int64(len(r.b))
394		}
395		if !block.Last {
396			// Send history to next block
397			select {
398			case next = <-d.decoding:
399				if debug {
400					println("Sending ", len(d.history.b), "bytes as history")
401				}
402				next.history <- &d.history
403			default:
404				// Wait until we have sent the block, so
405				// other decoders can potentially get the decoder.
406				next = nil
407			}
408		}
409
410		// Add checksum, async to decoding.
411		if d.HasCheckSum {
412			n, err := d.crc.Write(r.b)
413			if err != nil {
414				r.err = err
415				if n != len(r.b) {
416					r.err = io.ErrShortWrite
417				}
418				output <- r
419				return
420			}
421		}
422		written += int64(len(r.b))
423		if d.SingleSegment && uint64(written) > d.FrameContentSize {
424			println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize)
425			r.err = ErrFrameSizeExceeded
426			output <- r
427			return
428		}
429		if block.Last {
430			r.err = d.checkCRC()
431			output <- r
432			return
433		}
434		output <- r
435		if next == nil {
436			// There was no decoder available, we wait for one now that we have sent to the writer.
437			if debug {
438				println("Sending ", len(d.history.b), " bytes as history")
439			}
440			next = <-d.decoding
441			next.history <- &d.history
442		}
443		block = next
444	}
445}
446
447// runDecoder will create a sync decoder that will decode a block of data.
448func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
449	saved := d.history.b
450
451	// We use the history for output to avoid copying it.
452	d.history.b = dst
453	// Store input length, so we only check new data.
454	crcStart := len(dst)
455	var err error
456	for {
457		err = dec.reset(d.rawInput, d.WindowSize)
458		if err != nil {
459			break
460		}
461		if debug {
462			println("next block:", dec)
463		}
464		err = dec.decodeBuf(&d.history)
465		if err != nil || dec.Last {
466			break
467		}
468		if uint64(len(d.history.b)) > d.o.maxDecodedSize {
469			err = ErrDecoderSizeExceeded
470			break
471		}
472		if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize {
473			println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize)
474			err = ErrFrameSizeExceeded
475			break
476		}
477	}
478	dst = d.history.b
479	if err == nil {
480		if d.HasCheckSum {
481			var n int
482			n, err = d.crc.Write(dst[crcStart:])
483			if err == nil {
484				if n != len(dst)-crcStart {
485					err = io.ErrShortWrite
486				} else {
487					err = d.checkCRC()
488				}
489			}
490		}
491	}
492	d.history.b = saved
493	return dst, err
494}
495