1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"crypto/rand"
9	"fmt"
10	"io"
11	rdebug "runtime/debug"
12	"sync"
13
14	"github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	o        encoderOptions // Options: defaults are set in NewWriter and then modified by each EOption.
	encoders chan encoder   // Pool of encoders borrowed by EncodeAll; filled by initialize.
	state    encoderState   // State of the streaming interface (Reset, Write, ReadFrom, Flush, Close).
	init     sync.Once      // Runs initialize once, on the first EncodeAll call with non-empty input.
}
29
// encoder is the set of operations the Encoder needs from a compression
// level implementation. Instances are obtained from encoderOptions.encoder();
// the implementations are not part of this file, so the notes below describe
// how this file uses each method rather than how it is implemented.
type encoder interface {
	// Encode processes src into blk. It is used for streamed blocks and for
	// EncodeAll input that spans several blocks or uses a dictionary.
	Encode(blk *blockEnc, src []byte)
	// EncodeNoHist processes src into blk. EncodeAll uses it when the whole
	// input fits in one block and no dictionary is set.
	// NOTE(review): presumably skips history handling, going by the name — confirm.
	EncodeNoHist(blk *blockEnc, src []byte)
	// Block returns the block that callers pass to Encode/EncodeNoHist.
	Block() *blockEnc
	// CRC returns the digest that callers feed the uncompressed input into
	// when checksums are enabled.
	CRC() *xxhash.Digest
	// AppendCRC appends the checksum to the given slice and returns it.
	// Callers emit the result after the last block of a frame.
	AppendCRC([]byte) []byte
	// WindowSize returns the window size to put in the frame header.
	// Callers pass the input size when it is known and 0 for streams.
	WindowSize(size int) int32
	// UseBlock hands a block to the encoder. nextBlock calls it with the
	// previously written block to transfer recent offsets to the next one.
	UseBlock(*blockEnc)
	// Reset prepares the encoder for a new frame using dictionary d (may be nil).
	// singleBlock is true when the caller will encode the whole input as one block.
	Reset(d *dict, singleBlock bool)
}
40
// encoderState holds everything the streaming interface of an Encoder keeps
// between calls. It is cleared by Encoder.Reset.
type encoderState struct {
	w                io.Writer // Destination of the compressed stream; set by Reset.
	filling          []byte    // Input collected by Write/ReadFrom for the next block.
	current          []byte    // Input of the block last handed to the encode goroutine; also reused as output scratch for the synchronous single-frame path.
	previous         []byte    // Input of the block before current; becomes the next filling buffer when blocks rotate.
	encoder          encoder   // Encoder used for the stream; created on first Reset.
	writing          *blockEnc // Block most recently handed over for output; swapped with the encoder's block in nextBlock.
	err              error     // Sticky error; once set, nextBlock returns it.
	writeErr         error     // Error from the goroutine that encodes and writes a block.
	nWritten         int64     // Output bytes accounted for so far; Close uses it to size padding.
	headerWritten    bool      // Set once the frame header (or a complete frame) has been emitted.
	eofWritten       bool      // Set once the last block of the frame has been produced.
	fullFrameWritten bool      // Set when the whole frame was emitted synchronously via EncodeAll.

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}
60
61// NewWriter will create a new Zstandard encoder.
62// If the encoder will be used for encoding blocks a nil writer can be used.
63func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
64	initPredefined()
65	var e Encoder
66	e.o.setDefault()
67	for _, o := range opts {
68		err := o(&e.o)
69		if err != nil {
70			return nil, err
71		}
72	}
73	if w != nil {
74		e.Reset(w)
75	}
76	return &e, nil
77}
78
79func (e *Encoder) initialize() {
80	if e.o.concurrent == 0 {
81		e.o.setDefault()
82	}
83	e.encoders = make(chan encoder, e.o.concurrent)
84	for i := 0; i < e.o.concurrent; i++ {
85		enc := e.o.encoder()
86		e.encoders <- enc
87	}
88}
89
90// Reset will re-initialize the writer and new writes will encode to the supplied writer
91// as a new, independent stream.
92func (e *Encoder) Reset(w io.Writer) {
93	s := &e.state
94	s.wg.Wait()
95	s.wWg.Wait()
96	if cap(s.filling) == 0 {
97		s.filling = make([]byte, 0, e.o.blockSize)
98	}
99	if cap(s.current) == 0 {
100		s.current = make([]byte, 0, e.o.blockSize)
101	}
102	if cap(s.previous) == 0 {
103		s.previous = make([]byte, 0, e.o.blockSize)
104	}
105	if s.encoder == nil {
106		s.encoder = e.o.encoder()
107	}
108	if s.writing == nil {
109		s.writing = &blockEnc{}
110		s.writing.init()
111	}
112	s.writing.initNewEncode()
113	s.filling = s.filling[:0]
114	s.current = s.current[:0]
115	s.previous = s.previous[:0]
116	s.encoder.Reset(e.o.dict, false)
117	s.headerWritten = false
118	s.eofWritten = false
119	s.fullFrameWritten = false
120	s.w = w
121	s.err = nil
122	s.nWritten = 0
123	s.writeErr = nil
124}
125
126// Write data to the encoder.
127// Input data will be buffered and as the buffer fills up
128// content will be compressed and written to the output.
129// When done writing, use Close to flush the remaining output
130// and write CRC if requested.
131func (e *Encoder) Write(p []byte) (n int, err error) {
132	s := &e.state
133	for len(p) > 0 {
134		if len(p)+len(s.filling) < e.o.blockSize {
135			if e.o.crc {
136				_, _ = s.encoder.CRC().Write(p)
137			}
138			s.filling = append(s.filling, p...)
139			return n + len(p), nil
140		}
141		add := p
142		if len(p)+len(s.filling) > e.o.blockSize {
143			add = add[:e.o.blockSize-len(s.filling)]
144		}
145		if e.o.crc {
146			_, _ = s.encoder.CRC().Write(add)
147		}
148		s.filling = append(s.filling, add...)
149		p = p[len(add):]
150		n += len(add)
151		if len(s.filling) < e.o.blockSize {
152			return n, nil
153		}
154		err := e.nextBlock(false)
155		if err != nil {
156			return n, err
157		}
158		if debugAsserts && len(s.filling) > 0 {
159			panic(len(s.filling))
160		}
161	}
162	return n, nil
163}
164
// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
//
// final marks the buffered input as the end of the stream; in this file it is
// only passed as true by Close.
//
// The function first waits for the previous encode to finish. Then:
//   - If no header has been written and this is the final block with data,
//     the whole buffer is compressed synchronously as a complete frame via
//     EncodeAll and written out.
//   - Otherwise a frame header is written if it is still missing.
//   - If there is no buffered data, an empty last block is written when final
//     is set (and the last block was not produced already).
//   - Otherwise the buffers are rotated and the block is encoded in a
//     goroutine, which in turn starts a second goroutine that entropy-encodes
//     and writes the block. Errors from those goroutines surface through
//     s.err and s.writeErr on later calls.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	// Note: the limit checked is the configured block size, although the
	// message names maxStoreBlockSize.
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) > 0 {
			// s.current is reused as the output buffer for the complete frame.
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			// Everything, including the end of the frame, has been emitted.
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		// Streaming frame: the content size is not known up front, so it is left at 0.
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   0,
			WindowSize:    uint32(s.encoder.WindowSize(0)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		// Make sure no block write is in progress before writing to s.w.
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			// Emit an empty raw block flagged as last to terminate the frame.
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Move blocks forward.
	// filling becomes current (the block to encode), current becomes previous,
	// and the old previous buffer is emptied and reused for new input.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.wg.Add(1)
	go func(src []byte) {
		if debug {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			// A panic in the encoder is converted into a stream error.
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				// A panic while encoding/writing is reported through writeErr.
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			}
			switch err {
			case errIncompressible:
				if debug {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so nothing has to be adjusted here.
				// NOTE(review): the original comment was cut off mid-sentence
				// ("...deal with changing the."); confirm the intended meaning.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
303
304// ReadFrom reads data from r until EOF or error.
305// The return value n is the number of bytes read.
306// Any error except io.EOF encountered during the read is also returned.
307//
308// The Copy function uses ReaderFrom if available.
309func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
310	if debug {
311		println("Using ReadFrom")
312	}
313
314	// Flush any current writes.
315	if len(e.state.filling) > 0 {
316		if err := e.nextBlock(false); err != nil {
317			return 0, err
318		}
319	}
320	e.state.filling = e.state.filling[:e.o.blockSize]
321	src := e.state.filling
322	for {
323		n2, err := r.Read(src)
324		if e.o.crc {
325			_, _ = e.state.encoder.CRC().Write(src[:n2])
326		}
327		// src is now the unfilled part...
328		src = src[n2:]
329		n += int64(n2)
330		switch err {
331		case io.EOF:
332			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
333			if debug {
334				println("ReadFrom: got EOF final block:", len(e.state.filling))
335			}
336			return n, nil
337		default:
338			if debug {
339				println("ReadFrom: got error:", err)
340			}
341			e.state.err = err
342			return n, err
343		case nil:
344		}
345		if len(src) > 0 {
346			if debug {
347				println("ReadFrom: got space left in source:", len(src))
348			}
349			continue
350		}
351		err = e.nextBlock(false)
352		if err != nil {
353			return n, err
354		}
355		e.state.filling = e.state.filling[:e.o.blockSize]
356		src = e.state.filling
357	}
358}
359
360// Flush will send the currently written data to output
361// and block until everything has been written.
362// This should only be used on rare occasions where pushing the currently queued data is critical.
363func (e *Encoder) Flush() error {
364	s := &e.state
365	if len(s.filling) > 0 {
366		err := e.nextBlock(false)
367		if err != nil {
368			return err
369		}
370	}
371	s.wg.Wait()
372	s.wWg.Wait()
373	if s.err != nil {
374		return s.err
375	}
376	return s.writeErr
377}
378
379// Close will flush the final output and close the stream.
380// The function will block until everything has been written.
381// The Encoder can still be re-used after calling this.
382func (e *Encoder) Close() error {
383	s := &e.state
384	if s.encoder == nil {
385		return nil
386	}
387	err := e.nextBlock(true)
388	if err != nil {
389		return err
390	}
391	if e.state.fullFrameWritten {
392		return s.err
393	}
394	s.wg.Wait()
395	s.wWg.Wait()
396
397	if s.err != nil {
398		return s.err
399	}
400	if s.writeErr != nil {
401		return s.writeErr
402	}
403
404	// Write CRC
405	if e.o.crc && s.err == nil {
406		// heap alloc.
407		var tmp [4]byte
408		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
409		s.nWritten += 4
410	}
411
412	// Add padding with content from crypto/rand.Reader
413	if s.err == nil && e.o.pad > 0 {
414		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
415		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
416		if err != nil {
417			return err
418		}
419		_, s.err = s.w.Write(frame)
420	}
421	return s.err
422}
423
// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, dst is returned unchanged, unless WithZeroFrames is specified,
// in which case a frame containing a single empty block is appended.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
//
// Each call borrows an encoder from the pool and blocks until one is available.
// The function panics if the frame header cannot be created, if block
// encoding fails with anything other than "incompressible", or if the
// padding frame cannot be generated.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	// Create the encoder pool on first use, then borrow an encoder from it.
	e.init.Do(e.initialize)
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	// An explicit option overrides the size-based choice.
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(len(src))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	// If we can do everything in one block, prefer that.
	if len(src) <= maxCompressedBlockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		// Remember the block's own output buffer so it can be restored below.
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}

		switch err {
		case errIncompressible:
			if debug {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			// The block was encoded in place on top of dst.
			dst = blk.output
		default:
			panic(err)
		}
		blk.output = oldout
	} else {
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		// Encode the input one block at a time.
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			// Save the offsets so they can be restored if the block ends up stored raw.
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			}

			switch err {
			case errIncompressible:
				if debug {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}
571