1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"crypto/rand"
9	"fmt"
10	"io"
11	rdebug "runtime/debug"
12	"sync"
13
14	"github.com/klauspost/compress/zstd/internal/xxhash"
15)
16
// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	o        encoderOptions // options applied via EOption values at construction
	encoders chan encoder   // pool of reusable encoders, handed out by EncodeAll
	state    encoderState   // state for the streaming (Write/Flush/Close) interface
	init     sync.Once      // guards one-time creation of the encoders pool
}
29
// encoder is the internal compression backend used by Encoder.
// A single instance encodes one stream or one independent block at a time.
type encoder interface {
	// Encode compresses src into blk, using history carried over from
	// previously encoded blocks of the same stream.
	Encode(blk *blockEnc, src []byte)
	// EncodeNoHist compresses src into blk without using prior history.
	// Used for single-block encodes without a dictionary.
	EncodeNoHist(blk *blockEnc, src []byte)
	// Block returns the encoder's current block.
	Block() *blockEnc
	// CRC returns the running xxhash digest of the uncompressed input.
	CRC() *xxhash.Digest
	// AppendCRC appends the frame checksum to the supplied slice and
	// returns the result.
	AppendCRC([]byte) []byte
	// WindowSize returns the window size to use for the given content
	// size; used to populate the frame header.
	WindowSize(size int64) int32
	// UseBlock makes the encoder adopt state (e.g. recent offsets) from
	// the provided block.
	UseBlock(*blockEnc)
	// Reset prepares the encoder for a new stream, optionally primed with
	// a dictionary. singleBlock indicates a one-shot, history-free encode.
	Reset(d *dict, singleBlock bool)
}
40
// encoderState holds all mutable state for a streaming encode
// (the Write/ReadFrom/Flush/Close interface).
type encoderState struct {
	w                io.Writer // destination for the compressed stream
	filling          []byte    // collects input until a full block is ready
	current          []byte    // input currently being encoded
	previous         []byte    // input of the previously encoded block
	encoder          encoder   // backend encoder for this stream
	writing          *blockEnc // block most recently handed to the write goroutine
	err              error     // sticky error from the encode goroutine
	writeErr         error     // sticky error from the write goroutine
	nWritten         int64     // compressed bytes written to w
	nInput           int64     // uncompressed input bytes consumed
	frameContentSize int64     // expected content size; verified on Close when > 0
	headerWritten    bool      // the frame header has been written
	eofWritten       bool      // the final (last) block has been written
	fullFrameWritten bool      // a complete frame was written synchronously
	
	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}
62
63// NewWriter will create a new Zstandard encoder.
64// If the encoder will be used for encoding blocks a nil writer can be used.
65func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
66	initPredefined()
67	var e Encoder
68	e.o.setDefault()
69	for _, o := range opts {
70		err := o(&e.o)
71		if err != nil {
72			return nil, err
73		}
74	}
75	if w != nil {
76		e.Reset(w)
77	}
78	return &e, nil
79}
80
81func (e *Encoder) initialize() {
82	if e.o.concurrent == 0 {
83		e.o.setDefault()
84	}
85	e.encoders = make(chan encoder, e.o.concurrent)
86	for i := 0; i < e.o.concurrent; i++ {
87		enc := e.o.encoder()
88		e.encoders <- enc
89	}
90}
91
92// Reset will re-initialize the writer and new writes will encode to the supplied writer
93// as a new, independent stream.
94func (e *Encoder) Reset(w io.Writer) {
95	s := &e.state
96	s.wg.Wait()
97	s.wWg.Wait()
98	if cap(s.filling) == 0 {
99		s.filling = make([]byte, 0, e.o.blockSize)
100	}
101	if cap(s.current) == 0 {
102		s.current = make([]byte, 0, e.o.blockSize)
103	}
104	if cap(s.previous) == 0 {
105		s.previous = make([]byte, 0, e.o.blockSize)
106	}
107	if s.encoder == nil {
108		s.encoder = e.o.encoder()
109	}
110	if s.writing == nil {
111		s.writing = &blockEnc{lowMem: e.o.lowMem}
112		s.writing.init()
113	}
114	s.writing.initNewEncode()
115	s.filling = s.filling[:0]
116	s.current = s.current[:0]
117	s.previous = s.previous[:0]
118	s.encoder.Reset(e.o.dict, false)
119	s.headerWritten = false
120	s.eofWritten = false
121	s.fullFrameWritten = false
122	s.w = w
123	s.err = nil
124	s.nWritten = 0
125	s.nInput = 0
126	s.writeErr = nil
127	s.frameContentSize = 0
128}
129
130// ResetContentSize will reset and set a content size for the next stream.
131// If the bytes written does not match the size given an error will be returned
132// when calling Close().
133// This is removed when Reset is called.
134// Sizes <= 0 results in no content size set.
135func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
136	e.Reset(w)
137	if size >= 0 {
138		e.state.frameContentSize = size
139	}
140}
141
// Write data to the encoder.
// Input data will be buffered and as the buffer fills up
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
func (e *Encoder) Write(p []byte) (n int, err error) {
	s := &e.state
	for len(p) > 0 {
		// Fast path: everything fits in the current block with room to
		// spare, so just buffer it and return.
		if len(p)+len(s.filling) < e.o.blockSize {
			if e.o.crc {
				// CRC covers the uncompressed input.
				_, _ = s.encoder.CRC().Write(p)
			}
			s.filling = append(s.filling, p...)
			return n + len(p), nil
		}
		// Take only as much of p as completes the current block.
		add := p
		if len(p)+len(s.filling) > e.o.blockSize {
			add = add[:e.o.blockSize-len(s.filling)]
		}
		if e.o.crc {
			_, _ = s.encoder.CRC().Write(add)
		}
		s.filling = append(s.filling, add...)
		p = p[len(add):]
		n += len(add)
		// Defensive: filling should be exactly blockSize here, since the
		// fast path above handled the strictly-smaller case.
		if len(s.filling) < e.o.blockSize {
			return n, nil
		}
		// Block is full; hand it off for (async) compression.
		err := e.nextBlock(false)
		if err != nil {
			return n, err
		}
		if debugAsserts && len(s.filling) > 0 {
			panic(len(s.filling))
		}
	}
	return n, nil
}
180
// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
//
// Encoding runs on a pipeline of two goroutines: one compressing the block
// (guarded by s.wg) and one entropy-coding + writing the previous block
// (guarded by s.wWg), so compression of block N overlaps the write of N-1.
// final marks the block as the last of the frame.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		// Errors are sticky; once set nothing more is encoded.
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			// Nothing was ever written and zero frames are disabled:
			// emit no output at all.
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		if final && len(s.filling) > 0 {
			// All input fits in one block: delegate to EncodeAll, which
			// writes a complete frame (header, block, optional CRC).
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		// Multi-block stream: write the frame header before any block.
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		// No writes may be in flight while we write to s.w directly.
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			// Emit an empty raw block marked "last" to terminate the frame.
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Move blocks forward.
	// filling becomes the (cleared) previous buffer; current holds the data
	// we are about to encode, keeping its backing array alive until the
	// goroutine below is done with it.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	s.wg.Add(1)
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			}
			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so we don't have to deal with changing the.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
327
// ReadFrom reads data from r until EOF or error.
// The return value n is the number of bytes read.
// Any error except io.EOF encountered during the read is also returned.
//
// The Copy function uses ReaderFrom if available.
func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
	if debugEncoder {
		println("Using ReadFrom")
	}

	// Flush any current writes.
	if len(e.state.filling) > 0 {
		if err := e.nextBlock(false); err != nil {
			return 0, err
		}
	}
	// Read directly into the filling buffer; src aliases its unfilled tail.
	e.state.filling = e.state.filling[:e.o.blockSize]
	src := e.state.filling
	for {
		n2, err := r.Read(src)
		if e.o.crc {
			// CRC covers the uncompressed input as it arrives.
			_, _ = e.state.encoder.CRC().Write(src[:n2])
		}
		// src is now the unfilled part...
		src = src[n2:]
		n += int64(n2)
		switch err {
		case io.EOF:
			// Trim filling to the bytes actually read; the caller is
			// expected to Close/Flush to emit this final partial block.
			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
			if debugEncoder {
				println("ReadFrom: got EOF final block:", len(e.state.filling))
			}
			return n, nil
		case nil:
		default:
			if debugEncoder {
				println("ReadFrom: got error:", err)
			}
			e.state.err = err
			return n, err
		}
		if len(src) > 0 {
			// Short read: keep filling the current block.
			if debugEncoder {
				println("ReadFrom: got space left in source:", len(src))
			}
			continue
		}
		// Block is full; compress it and start a fresh one.
		err = e.nextBlock(false)
		if err != nil {
			return n, err
		}
		e.state.filling = e.state.filling[:e.o.blockSize]
		src = e.state.filling
	}
}
383
384// Flush will send the currently written data to output
385// and block until everything has been written.
386// This should only be used on rare occasions where pushing the currently queued data is critical.
387func (e *Encoder) Flush() error {
388	s := &e.state
389	if len(s.filling) > 0 {
390		err := e.nextBlock(false)
391		if err != nil {
392			return err
393		}
394	}
395	s.wg.Wait()
396	s.wWg.Wait()
397	if s.err != nil {
398		return s.err
399	}
400	return s.writeErr
401}
402
// Close will flush the final output and close the stream.
// The function will block until everything has been written.
// The Encoder can still be re-used after calling this.
func (e *Encoder) Close() error {
	s := &e.state
	if s.encoder == nil {
		// Never initialized (Reset was not called); nothing to do.
		return nil
	}
	// Encode whatever remains in filling as the final block.
	err := e.nextBlock(true)
	if err != nil {
		return err
	}
	// Verify the content size promised via ResetContentSize, if any.
	if s.frameContentSize > 0 {
		if s.nInput != s.frameContentSize {
			return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
		}
	}
	if e.state.fullFrameWritten {
		// The whole frame was written synchronously (EncodeAll path),
		// including CRC and padding; nothing more to append.
		return s.err
	}
	// Drain the encode and write goroutines before touching s.w directly.
	s.wg.Wait()
	s.wWg.Wait()

	if s.err != nil {
		return s.err
	}
	if s.writeErr != nil {
		return s.writeErr
	}

	// Write CRC
	if e.o.crc && s.err == nil {
		// heap alloc.
		var tmp [4]byte
		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
		s.nWritten += 4
	}

	// Add padding with content from crypto/rand.Reader
	if s.err == nil && e.o.pad > 0 {
		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
		// filling is empty at this point; reuse it as scratch space.
		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
		if err != nil {
			return err
		}
		_, s.err = s.w.Write(frame)
	}
	return s.err
}
452
// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	// Lazily create the encoder pool, then borrow one for this call.
	e.init.Do(e.initialize)
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	// If we can do everything in one block, prefer that.
	if len(src) <= maxCompressedBlockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}

		switch err {
		case errIncompressible:
			if debugEncoder {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			dst = blk.output
		default:
			// Any other error is a programming error in the block encoder.
			panic(err)
		}
		// Restore the block's own output buffer for the next borrower.
		blk.output = oldout
	} else {
		// Multi-block path: encode blockSize chunks with shared history.
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			}

			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				// Raw block carries no offsets; restore the saved ones.
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}
600