// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"crypto/rand"
	"fmt"
	"io"
	rdebug "runtime/debug"
	"sync"

	"github.com/klauspost/compress/zstd/internal/xxhash"
)

// Encoder provides encoding to Zstandard.
// An Encoder can be used either for compressing a stream via the
// io.WriteCloser interface it implements, or for multiple independent
// tasks via the EncodeAll function.
// Smaller inputs are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
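//
// A minimal streaming sketch (error handling elided; out is a placeholder
// io.Writer and payload a placeholder []byte supplied by the caller):
//
//	enc, err := NewWriter(out)
//	if err != nil {
//		// handle invalid options
//	}
//	_, werr := enc.Write(payload)
//	cerr := enc.Close() // flushes remaining data and finishes the frame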
type Encoder struct {
	o        encoderOptions
	encoders chan encoder
	state    encoderState
	init     sync.Once
}

type encoder interface {
	Encode(blk *blockEnc, src []byte)
	EncodeNoHist(blk *blockEnc, src []byte)
	Block() *blockEnc
	CRC() *xxhash.Digest
	AppendCRC([]byte) []byte
	WindowSize(size int) int32
	UseBlock(*blockEnc)
	Reset(d *dict, singleBlock bool)
}

type encoderState struct {
	w                io.Writer
	filling          []byte
	current          []byte
	previous         []byte
	encoder          encoder
	writing          *blockEnc
	err              error
	writeErr         error
	nWritten         int64
	headerWritten    bool
	eofWritten       bool
	fullFrameWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}

// NewWriter will create a new Zstandard encoder.
// If the encoder will only be used for encoding blocks, a nil writer can be used.
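//
// Behavior is adjusted with EOption values; a sketch, assuming the option
// constructors defined elsewhere in this package (for example
// WithEncoderConcurrency and WithEncoderCRC) and a placeholder out writer:
//
//	enc, err := NewWriter(out, WithEncoderConcurrency(2), WithEncoderCRC(true))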
func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
	initPredefined()
	var e Encoder
	e.o.setDefault()
	for _, o := range opts {
		err := o(&e.o)
		if err != nil {
			return nil, err
		}
	}
	if w != nil {
		e.Reset(w)
	}
	return &e, nil
}

func (e *Encoder) initialize() {
	if e.o.concurrent == 0 {
		e.o.setDefault()
	}
	e.encoders = make(chan encoder, e.o.concurrent)
	for i := 0; i < e.o.concurrent; i++ {
		enc := e.o.encoder()
		e.encoders <- enc
	}
}

// Reset will re-initialize the writer and new writes will encode to the supplied writer
// as a new, independent stream.
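//
// A sketch of reusing a single Encoder for several outputs (f1, f2, data1
// and data2 are placeholders; error handling elided):
//
//	enc, _ := NewWriter(f1)
//	_, _ = enc.Write(data1)
//	_ = enc.Close()
//	enc.Reset(f2) // start a new, independent stream
//	_, _ = enc.Write(data2)
//	_ = enc.Close()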
func (e *Encoder) Reset(w io.Writer) {
	s := &e.state
	s.wg.Wait()
	s.wWg.Wait()
	if cap(s.filling) == 0 {
		s.filling = make([]byte, 0, e.o.blockSize)
	}
	if cap(s.current) == 0 {
		s.current = make([]byte, 0, e.o.blockSize)
	}
	if cap(s.previous) == 0 {
		s.previous = make([]byte, 0, e.o.blockSize)
	}
	if s.encoder == nil {
		s.encoder = e.o.encoder()
	}
	if s.writing == nil {
		s.writing = &blockEnc{lowMem: e.o.lowMem}
		s.writing.init()
	}
	s.writing.initNewEncode()
	s.filling = s.filling[:0]
	s.current = s.current[:0]
	s.previous = s.previous[:0]
	s.encoder.Reset(e.o.dict, false)
	s.headerWritten = false
	s.eofWritten = false
	s.fullFrameWritten = false
	s.w = w
	s.err = nil
	s.nWritten = 0
	s.writeErr = nil
}

// Write data to the encoder.
// Input data will be buffered, and as the buffer fills up,
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write the CRC if requested.
func (e *Encoder) Write(p []byte) (n int, err error) {
	s := &e.state
	for len(p) > 0 {
		if len(p)+len(s.filling) < e.o.blockSize {
			if e.o.crc {
				_, _ = s.encoder.CRC().Write(p)
			}
			s.filling = append(s.filling, p...)
			return n + len(p), nil
		}
		add := p
		if len(p)+len(s.filling) > e.o.blockSize {
			add = add[:e.o.blockSize-len(s.filling)]
		}
		if e.o.crc {
			_, _ = s.encoder.CRC().Write(add)
		}
		s.filling = append(s.filling, add...)
		p = p[len(add):]
		n += len(add)
		if len(s.filling) < e.o.blockSize {
			return n, nil
		}
		err := e.nextBlock(false)
		if err != nil {
			return n, err
		}
		if debugAsserts && len(s.filling) > 0 {
			panic(len(s.filling))
		}
	}
	return n, nil
}

// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding, it will be returned.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		if final && len(s.filling) > 0 {
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   0,
			WindowSize:    uint32(s.encoder.WindowSize(0)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Move blocks forward.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.wg.Add(1)
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			}
			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so we don't have to deal with changing the recent offsets.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}

// ReadFrom reads data from r until EOF or error.
// The return value n is the number of bytes read.
// Any error except io.EOF encountered during the read is also returned.
//
// The io.Copy function uses ReaderFrom if available.
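//
// A sketch of feeding a reader straight into the encoder (r and out are
// placeholders; error handling elided):
//
//	enc, _ := NewWriter(out)
//	_, err := io.Copy(enc, r) // uses enc.ReadFrom unless r implements io.WriterTo
//	cerr := enc.Close()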
func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
	if debugEncoder {
		println("Using ReadFrom")
	}

	// Flush any current writes.
	if len(e.state.filling) > 0 {
		if err := e.nextBlock(false); err != nil {
			return 0, err
		}
	}
	e.state.filling = e.state.filling[:e.o.blockSize]
	src := e.state.filling
	for {
		n2, err := r.Read(src)
		if e.o.crc {
			_, _ = e.state.encoder.CRC().Write(src[:n2])
		}
		// src is now the unfilled part...
		src = src[n2:]
		n += int64(n2)
		switch err {
		case io.EOF:
			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
			if debugEncoder {
				println("ReadFrom: got EOF final block:", len(e.state.filling))
			}
			return n, nil
		case nil:
		default:
			if debugEncoder {
				println("ReadFrom: got error:", err)
			}
			e.state.err = err
			return n, err
		}
		if len(src) > 0 {
			if debugEncoder {
				println("ReadFrom: got space left in source:", len(src))
			}
			continue
		}
		err = e.nextBlock(false)
		if err != nil {
			return n, err
		}
		e.state.filling = e.state.filling[:e.o.blockSize]
		src = e.state.filling
	}
}

// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
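//
// A sketch of pushing queued data out at a logical boundary (conn and
// record are placeholders; error handling elided):
//
//	enc, _ := NewWriter(conn)
//	_, _ = enc.Write(record)
//	_ = enc.Flush() // compress and write what has been buffered so far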
func (e *Encoder) Flush() error {
	s := &e.state
	if len(s.filling) > 0 {
		err := e.nextBlock(false)
		if err != nil {
			return err
		}
	}
	s.wg.Wait()
	s.wWg.Wait()
	if s.err != nil {
		return s.err
	}
	return s.writeErr
}

// Close will flush the final output and close the stream.
// The function will block until everything has been written.
// The Encoder can still be re-used after calling this.
func (e *Encoder) Close() error {
	s := &e.state
	if s.encoder == nil {
		return nil
	}
	err := e.nextBlock(true)
	if err != nil {
		return err
	}
	if e.state.fullFrameWritten {
		return s.err
	}
	s.wg.Wait()
	s.wWg.Wait()

	if s.err != nil {
		return s.err
	}
	if s.writeErr != nil {
		return s.writeErr
	}

	// Write CRC
	if e.o.crc && s.err == nil {
		// Note: tmp escapes to the heap through the interface call below.
		var tmp [4]byte
		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
		s.nWritten += 4
	}

	// Add padding with content from crypto/rand.Reader
	if s.err == nil && e.o.pad > 0 {
		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
		if err != nil {
			return err
		}
		_, s.err = s.w.Write(frame)
	}
	return s.err
}


// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
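//
// A sketch of one-shot, buffer-to-buffer use (payload is a placeholder
// []byte; passing a reusable dst avoids allocations across calls):
//
//	enc, _ := NewWriter(nil) // a nil writer is fine for EncodeAll-only use
//	compressed := enc.EncodeAll(payload, make([]byte, 0, len(payload)))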
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	e.init.Do(e.initialize)
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(len(src))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	// If we can do everything in one block, prefer that.
	if len(src) <= maxCompressedBlockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}

		switch err {
		case errIncompressible:
			if debugEncoder {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			dst = blk.output
		default:
			panic(err)
		}
		blk.output = oldout
	} else {
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			}

			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}