1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"bytes"
9	"errors"
10	"io"
11	"sync"
12)
13
// Decoder provides decoding of zstandard streams.
// The decoder has been designed to operate without allocations after a warmup.
// This means that you should store the decoder for best performance.
// To re-use a stream decoder, call Reset(r io.Reader) to switch to another stream.
// A decoder can safely be re-used even if the previous stream failed.
// To release the resources, you must call the Close() function on a decoder.
type Decoder struct {
	// Options, filled in by setDefault and the DOption functions passed to NewReader.
	o decoderOptions

	// Unreferenced decoders, ready for use.
	// Buffered with capacity o.concurrent; a block decoder is taken from here
	// before decoding and sent back when the caller is done with it.
	decoders chan *blockDec

	// Streams ready to be decoded.
	// Created lazily by Reset, consumed by startStreamDecoder, closed by Close.
	stream chan decodeStream

	// Current read position used for Reader functionality.
	current decoderState

	// Custom dictionaries, keyed by dictionary ID.
	// Always uses copies.
	dicts map[uint32]dict

	// streamWg is the waitgroup for all streams.
	// It tracks the startStreamDecoder goroutine so Close can wait for it to exit.
	streamWg sync.WaitGroup
}
39
// decoderState is used for maintaining state when the decoder
// is used for streaming (Read / WriteTo).
type decoderState struct {
	// current block being written to stream: its bytes, the block decoder
	// that owns them (if any) and the error that came with it.
	decodeOutput

	// output in order to be written to stream.
	// Created once in NewReader with capacity o.concurrent and shared by every
	// stream the decoder handles.
	output chan decodeOutput

	// cancel remaining output.
	// Closed by drainOutput to ask the stream decoder to stop; nil when there
	// is nothing to cancel.
	cancel chan struct{}

	// flushed is true when no further output is pending for the current
	// stream (after errEndOfStream was received, after a synchronous decode,
	// or before any stream was started).
	flushed bool
}
54
55var (
56	// Check the interfaces we want to support.
57	_ = io.WriterTo(&Decoder{})
58	_ = io.Reader(&Decoder{})
59)
60
61// NewReader creates a new decoder.
62// A nil Reader can be provided in which case Reset can be used to start a decode.
63//
64// A Decoder can be used in two modes:
65//
66// 1) As a stream, or
67// 2) For stateless decoding using DecodeAll.
68//
69// Only a single stream can be decoded concurrently, but the same decoder
70// can run multiple concurrent stateless decodes. It is even possible to
71// use stateless decodes while a stream is being decoded.
72//
73// The Reset function can be used to initiate a new stream, which is will considerably
74// reduce the allocations normally caused by NewReader.
75func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) {
76	initPredefined()
77	var d Decoder
78	d.o.setDefault()
79	for _, o := range opts {
80		err := o(&d.o)
81		if err != nil {
82			return nil, err
83		}
84	}
85	d.current.output = make(chan decodeOutput, d.o.concurrent)
86	d.current.flushed = true
87
88	// Transfer option dicts.
89	d.dicts = make(map[uint32]dict, len(d.o.dicts))
90	for _, dc := range d.o.dicts {
91		d.dicts[dc.id] = dc
92	}
93	d.o.dicts = nil
94
95	// Create decoders
96	d.decoders = make(chan *blockDec, d.o.concurrent)
97	for i := 0; i < d.o.concurrent; i++ {
98		dec := newBlockDec(d.o.lowMem)
99		dec.localFrame = newFrameDec(d.o)
100		d.decoders <- dec
101	}
102
103	if r == nil {
104		return &d, nil
105	}
106	return &d, d.Reset(r)
107}
108
109// Read bytes from the decompressed stream into p.
110// Returns the number of bytes written and any error that occurred.
111// When the stream is done, io.EOF will be returned.
112func (d *Decoder) Read(p []byte) (int, error) {
113	if d.stream == nil {
114		return 0, errors.New("no input has been initialized")
115	}
116	var n int
117	for {
118		if len(d.current.b) > 0 {
119			filled := copy(p, d.current.b)
120			p = p[filled:]
121			d.current.b = d.current.b[filled:]
122			n += filled
123		}
124		if len(p) == 0 {
125			break
126		}
127		if len(d.current.b) == 0 {
128			// We have an error and no more data
129			if d.current.err != nil {
130				break
131			}
132			if !d.nextBlock(n == 0) {
133				return n, nil
134			}
135		}
136	}
137	if len(d.current.b) > 0 {
138		if debug {
139			println("returning", n, "still bytes left:", len(d.current.b))
140		}
141		// Only return error at end of block
142		return n, nil
143	}
144	if d.current.err != nil {
145		d.drainOutput()
146	}
147	if debug {
148		println("returning", n, d.current.err, len(d.decoders))
149	}
150	return n, d.current.err
151}
152
153// Reset will reset the decoder the supplied stream after the current has finished processing.
154// Note that this functionality cannot be used after Close has been called.
155func (d *Decoder) Reset(r io.Reader) error {
156	if d.current.err == ErrDecoderClosed {
157		return d.current.err
158	}
159	if r == nil {
160		return errors.New("nil Reader sent as input")
161	}
162
163	if d.stream == nil {
164		d.stream = make(chan decodeStream, 1)
165		d.streamWg.Add(1)
166		go d.startStreamDecoder(d.stream)
167	}
168
169	d.drainOutput()
170
171	// If bytes buffer and < 1MB, do sync decoding anyway.
172	if bb, ok := r.(*bytes.Buffer); ok && bb.Len() < 1<<20 {
173		if debug {
174			println("*bytes.Buffer detected, doing sync decode, len:", bb.Len())
175		}
176		b := bb.Bytes()
177		var dst []byte
178		if cap(d.current.b) > 0 {
179			dst = d.current.b
180		}
181
182		dst, err := d.DecodeAll(b, dst[:0])
183		if err == nil {
184			err = io.EOF
185		}
186		d.current.b = dst
187		d.current.err = err
188		d.current.flushed = true
189		if debug {
190			println("sync decode to", len(dst), "bytes, err:", err)
191		}
192		return nil
193	}
194
195	// Remove current block.
196	d.current.decodeOutput = decodeOutput{}
197	d.current.err = nil
198	d.current.cancel = make(chan struct{})
199	d.current.flushed = false
200	d.current.d = nil
201
202	d.stream <- decodeStream{
203		r:      r,
204		output: d.current.output,
205		cancel: d.current.cancel,
206	}
207	return nil
208}
209
210// drainOutput will drain the output until errEndOfStream is sent.
211func (d *Decoder) drainOutput() {
212	if d.current.cancel != nil {
213		println("cancelling current")
214		close(d.current.cancel)
215		d.current.cancel = nil
216	}
217	if d.current.d != nil {
218		if debug {
219			printf("re-adding current decoder %p, decoders: %d", d.current.d, len(d.decoders))
220		}
221		d.decoders <- d.current.d
222		d.current.d = nil
223		d.current.b = nil
224	}
225	if d.current.output == nil || d.current.flushed {
226		println("current already flushed")
227		return
228	}
229	for {
230		select {
231		case v := <-d.current.output:
232			if v.d != nil {
233				if debug {
234					printf("re-adding decoder %p", v.d)
235				}
236				d.decoders <- v.d
237			}
238			if v.err == errEndOfStream {
239				println("current flushed")
240				d.current.flushed = true
241				return
242			}
243		}
244	}
245}
246
247// WriteTo writes data to w until there's no more data to write or when an error occurs.
248// The return value n is the number of bytes written.
249// Any error encountered during the write is also returned.
250func (d *Decoder) WriteTo(w io.Writer) (int64, error) {
251	if d.stream == nil {
252		return 0, errors.New("no input has been initialized")
253	}
254	var n int64
255	for {
256		if len(d.current.b) > 0 {
257			n2, err2 := w.Write(d.current.b)
258			n += int64(n2)
259			if err2 != nil && d.current.err == nil {
260				d.current.err = err2
261				break
262			}
263		}
264		if d.current.err != nil {
265			break
266		}
267		d.nextBlock(true)
268	}
269	err := d.current.err
270	if err != nil {
271		d.drainOutput()
272	}
273	if err == io.EOF {
274		err = nil
275	}
276	return n, err
277}
278
279// DecodeAll allows stateless decoding of a blob of bytes.
280// Output will be appended to dst, so if the destination size is known
281// you can pre-allocate the destination slice to avoid allocations.
282// DecodeAll can be used concurrently.
283// The Decoder concurrency limits will be respected.
284func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
285	if d.current.err == ErrDecoderClosed {
286		return dst, ErrDecoderClosed
287	}
288
289	// Grab a block decoder and frame decoder.
290	block := <-d.decoders
291	frame := block.localFrame
292	defer func() {
293		if debug {
294			printf("re-adding decoder: %p", block)
295		}
296		frame.rawInput = nil
297		frame.bBuf = nil
298		d.decoders <- block
299	}()
300	frame.bBuf = input
301
302	for {
303		frame.history.reset()
304		err := frame.reset(&frame.bBuf)
305		if err == io.EOF {
306			if debug {
307				println("frame reset return EOF")
308			}
309			return dst, nil
310		}
311		if frame.DictionaryID != nil {
312			dict, ok := d.dicts[*frame.DictionaryID]
313			if !ok {
314				return nil, ErrUnknownDictionary
315			}
316			frame.history.setDict(&dict)
317		}
318		if err != nil {
319			return dst, err
320		}
321		if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
322			return dst, ErrDecoderSizeExceeded
323		}
324		if frame.FrameContentSize > 0 && frame.FrameContentSize < 1<<30 {
325			// Never preallocate moe than 1 GB up front.
326			if cap(dst)-len(dst) < int(frame.FrameContentSize) {
327				dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize))
328				copy(dst2, dst)
329				dst = dst2
330			}
331		}
332		if cap(dst) == 0 {
333			// Allocate len(input) * 2 by default if nothing is provided
334			// and we didn't get frame content size.
335			size := len(input) * 2
336			// Cap to 1 MB.
337			if size > 1<<20 {
338				size = 1 << 20
339			}
340			if uint64(size) > d.o.maxDecodedSize {
341				size = int(d.o.maxDecodedSize)
342			}
343			dst = make([]byte, 0, size)
344		}
345
346		dst, err = frame.runDecoder(dst, block)
347		if err != nil {
348			return dst, err
349		}
350		if len(frame.bBuf) == 0 {
351			if debug {
352				println("frame dbuf empty")
353			}
354			break
355		}
356	}
357	return dst, nil
358}
359
360// nextBlock returns the next block.
361// If an error occurs d.err will be set.
362// Optionally the function can block for new output.
363// If non-blocking mode is used the returned boolean will be false
364// if no data was available without blocking.
365func (d *Decoder) nextBlock(blocking bool) (ok bool) {
366	if d.current.d != nil {
367		if debug {
368			printf("re-adding current decoder %p", d.current.d)
369		}
370		d.decoders <- d.current.d
371		d.current.d = nil
372	}
373	if d.current.err != nil {
374		// Keep error state.
375		return blocking
376	}
377
378	if blocking {
379		d.current.decodeOutput = <-d.current.output
380	} else {
381		select {
382		case d.current.decodeOutput = <-d.current.output:
383		default:
384			return false
385		}
386	}
387	if debug {
388		println("got", len(d.current.b), "bytes, error:", d.current.err)
389	}
390	return true
391}
392
393// Close will release all resources.
394// It is NOT possible to reuse the decoder after this.
395func (d *Decoder) Close() {
396	if d.current.err == ErrDecoderClosed {
397		return
398	}
399	d.drainOutput()
400	if d.stream != nil {
401		close(d.stream)
402		d.streamWg.Wait()
403		d.stream = nil
404	}
405	if d.decoders != nil {
406		close(d.decoders)
407		for dec := range d.decoders {
408			dec.Close()
409		}
410		d.decoders = nil
411	}
412	if d.current.d != nil {
413		d.current.d.Close()
414		d.current.d = nil
415	}
416	d.current.err = ErrDecoderClosed
417}
418
419// IOReadCloser returns the decoder as an io.ReadCloser for convenience.
420// Any changes to the decoder will be reflected, so the returned ReadCloser
421// can be reused along with the decoder.
422// io.WriterTo is also supported by the returned ReadCloser.
423func (d *Decoder) IOReadCloser() io.ReadCloser {
424	return closeWrapper{d: d}
425}
426
427// closeWrapper wraps a function call as a closer.
428type closeWrapper struct {
429	d *Decoder
430}
431
432// WriteTo forwards WriteTo calls to the decoder.
433func (c closeWrapper) WriteTo(w io.Writer) (n int64, err error) {
434	return c.d.WriteTo(w)
435}
436
437// Read forwards read calls to the decoder.
438func (c closeWrapper) Read(p []byte) (n int, err error) {
439	return c.d.Read(p)
440}
441
442// Close closes the decoder.
443func (c closeWrapper) Close() error {
444	c.d.Close()
445	return nil
446}
447
// decodeOutput is one message on a stream's output channel.
type decodeOutput struct {
	// d is the block decoder that produced b, if any. The receiver must send
	// it back to Decoder.decoders once b is no longer needed.
	d *blockDec
	// b holds decoded bytes.
	b []byte
	// err is the error for this message; errEndOfStream marks the final
	// message sent for a stream.
	err error
}
453
// decodeStream is a request, sent from Reset to startStreamDecoder, to
// decode one input stream.
type decodeStream struct {
	// r is the compressed input.
	r io.Reader

	// Blocks ready to be written to output.
	output chan decodeOutput

	// cancel reading from the input; closed by the consumer to stop decoding.
	cancel chan struct{}
}
463
// errEndOfStream is the final message the stream decoder sends for a stream.
// Receiving it means nothing more will arrive on that stream's output channel.
var errEndOfStream error = errors.New("end-of-stream")
466
467// Create Decoder:
468// Spawn n block decoders. These accept tasks to decode a block.
469// Create goroutine that handles stream processing, this will send history to decoders as they are available.
470// Decoders update the history as they decode.
471// When a block is returned:
472// 		a) history is sent to the next decoder,
473// 		b) content written to CRC.
474// 		c) return data to WRITER.
475// 		d) wait for next block to return data.
476// Once WRITTEN, the decoders reused by the writer frame decoder for re-use.
477func (d *Decoder) startStreamDecoder(inStream chan decodeStream) {
478	defer d.streamWg.Done()
479	frame := newFrameDec(d.o)
480	for stream := range inStream {
481		if debug {
482			println("got new stream")
483		}
484		br := readerWrapper{r: stream.r}
485	decodeStream:
486		for {
487			frame.history.reset()
488			err := frame.reset(&br)
489			if debug && err != nil {
490				println("Frame decoder returned", err)
491			}
492			if err == nil && frame.DictionaryID != nil {
493				dict, ok := d.dicts[*frame.DictionaryID]
494				if !ok {
495					err = ErrUnknownDictionary
496				} else {
497					frame.history.setDict(&dict)
498				}
499			}
500			if err != nil {
501				stream.output <- decodeOutput{
502					err: err,
503				}
504				break
505			}
506			if debug {
507				println("starting frame decoder")
508			}
509
510			// This goroutine will forward history between frames.
511			frame.frameDone.Add(1)
512			frame.initAsync()
513
514			go frame.startDecoder(stream.output)
515		decodeFrame:
516			// Go through all blocks of the frame.
517			for {
518				dec := <-d.decoders
519				select {
520				case <-stream.cancel:
521					if !frame.sendErr(dec, io.EOF) {
522						// To not let the decoder dangle, send it back.
523						stream.output <- decodeOutput{d: dec}
524					}
525					break decodeStream
526				default:
527				}
528				err := frame.next(dec)
529				switch err {
530				case io.EOF:
531					// End of current frame, no error
532					println("EOF on next block")
533					break decodeFrame
534				case nil:
535					continue
536				default:
537					println("block decoder returned", err)
538					break decodeStream
539				}
540			}
541			// All blocks have started decoding, check if there are more frames.
542			println("waiting for done")
543			frame.frameDone.Wait()
544			println("done waiting...")
545		}
546		frame.frameDone.Wait()
547		println("Sending EOS")
548		stream.output <- decodeOutput{err: errEndOfStream}
549	}
550}
551