// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Copyright (c) 2019 Klaus Post. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package s2

import (
	"crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"math/bits"
	"runtime"
	"sync"
)
// Encode returns the encoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire encoded block.
// Otherwise, a newly allocated slice will be returned.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
//
// Blocks require the same amount of memory to decode as was used to encode them,
// and a single block cannot be decoded concurrently.
// Also note that blocks do not contain CRC information, so corruption may be undetected.
//
// If you need to encode larger amounts of data, consider using
// the streaming interface, which provides all of these features.
func Encode(dst, src []byte) []byte {
	if n := MaxEncodedLen(len(src)); n < 0 {
		panic(ErrTooLarge)
	} else if cap(dst) < n {
		dst = make([]byte, n)
	} else {
		dst = dst[:n]
	}

	// The block starts with the varint-encoded length of the decompressed bytes.
	d := binary.PutUvarint(dst, uint64(len(src)))

	if len(src) == 0 {
		return dst[:d]
	}
	if len(src) < minNonLiteralBlockSize {
		d += emitLiteral(dst[d:], src)
		return dst[:d]
	}
	n := encodeBlock(dst[d:], src)
	if n > 0 {
		d += n
		return dst[:d]
	}
	// Not compressible
	d += emitLiteral(dst[d:], src)
	return dst[:d]
}
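
// Illustrative round trip (an editor's sketch, not part of the original
// file). Decode is assumed to be this package's block decoder; the input is
// made up:
//
//	src := []byte("hello hello hello hello hello hello")
//	encoded := Encode(nil, src)
//	decoded, err := Decode(nil, encoded)
//	if err != nil || !bytes.Equal(decoded, src) {
//		// handle decode failure or corruption
//	}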

// EncodeBetter returns the encoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire encoded block.
// Otherwise, a newly allocated slice will be returned.
//
// EncodeBetter compresses better than Encode but typically with a
// 10-40% speed decrease on both compression and decompression.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
//
// Blocks require the same amount of memory to decode as was used to encode them,
// and a single block cannot be decoded concurrently.
// Also note that blocks do not contain CRC information, so corruption may be undetected.
//
// If you need to encode larger amounts of data, consider using
// the streaming interface, which provides all of these features.
func EncodeBetter(dst, src []byte) []byte {
	if n := MaxEncodedLen(len(src)); n < 0 {
		panic(ErrTooLarge)
	} else if cap(dst) < n {
		dst = make([]byte, n)
	} else {
		dst = dst[:n]
	}

	// The block starts with the varint-encoded length of the decompressed bytes.
	d := binary.PutUvarint(dst, uint64(len(src)))

	if len(src) == 0 {
		return dst[:d]
	}
	if len(src) < minNonLiteralBlockSize {
		d += emitLiteral(dst[d:], src)
		return dst[:d]
	}
	n := encodeBlockBetter(dst[d:], src)
	if n > 0 {
		d += n
		return dst[:d]
	}
	// Not compressible
	d += emitLiteral(dst[d:], src)
	return dst[:d]
}

// EncodeSnappy returns the encoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire encoded block.
// Otherwise, a newly allocated slice will be returned.
//
// The output is Snappy compatible and will likely decompress faster.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
//
// Blocks require the same amount of memory to decode as was used to encode them,
// and a single block cannot be decoded concurrently.
// Also note that blocks do not contain CRC information, so corruption may be undetected.
//
// If you need to encode larger amounts of data, consider using
// the streaming interface, which provides all of these features.
func EncodeSnappy(dst, src []byte) []byte {
	if n := MaxEncodedLen(len(src)); n < 0 {
		panic(ErrTooLarge)
	} else if cap(dst) < n {
		dst = make([]byte, n)
	} else {
		dst = dst[:n]
	}

	// The block starts with the varint-encoded length of the decompressed bytes.
	d := binary.PutUvarint(dst, uint64(len(src)))

	if len(src) == 0 {
		return dst[:d]
	}
	if len(src) < minNonLiteralBlockSize {
		d += emitLiteral(dst[d:], src)
		return dst[:d]
	}

	n := encodeBlockSnappy(dst[d:], src)
	if n > 0 {
		d += n
		return dst[:d]
	}
	// Not compressible
	d += emitLiteral(dst[d:], src)
	return dst[:d]
}
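
// Illustrative sketch (not part of the original file): because the output is
// Snappy compatible, it can be decoded by a regular Snappy block decoder,
// here assumed to be github.com/golang/snappy; src is any input []byte:
//
//	compressed := EncodeSnappy(nil, src)
//	decoded, err := snappy.Decode(nil, compressed)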

// ConcatBlocks will concatenate the supplied blocks and append them to the supplied destination.
// If the destination is nil or too small, a new slice will be allocated.
// The blocks are not validated, so garbage in = garbage out.
// dst may not overlap block data.
// Any data in dst is preserved as is, so it will not be considered a block.
func ConcatBlocks(dst []byte, blocks ...[]byte) ([]byte, error) {
	totalSize := uint64(0)
	compSize := 0
	for _, b := range blocks {
		l, hdr, err := decodedLen(b)
		if err != nil {
			return nil, err
		}
		totalSize += uint64(l)
		compSize += len(b) - hdr
	}
	if totalSize == 0 {
		dst = append(dst, 0)
		return dst, nil
	}
	if totalSize > math.MaxUint32 {
		return nil, ErrTooLarge
	}
	var tmp [binary.MaxVarintLen32]byte
	hdrSize := binary.PutUvarint(tmp[:], totalSize)
	wantSize := hdrSize + compSize

	if cap(dst)-len(dst) < wantSize {
		dst = append(make([]byte, 0, wantSize+len(dst)), dst...)
	}
	dst = append(dst, tmp[:hdrSize]...)
	for _, b := range blocks {
		_, hdr, err := decodedLen(b)
		if err != nil {
			return nil, err
		}
		dst = append(dst, b[hdr:]...)
	}
	return dst, nil
}
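
// Illustrative use (not part of the original file): two independently encoded
// blocks are joined into a single block that decodes to the concatenated
// inputs; part1 and part2 are arbitrary []byte values:
//
//	a := Encode(nil, part1)
//	b := Encode(nil, part2)
//	joined, err := ConcatBlocks(nil, a, b)
//	// Decode(nil, joined) now yields part1 followed by part2.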

// inputMargin is the minimum number of extra input bytes to keep, inside
// encodeBlock's inner loop. On some architectures, this margin lets us
// implement a fast path for emitLiteral, where the copy of short (<= 16 byte)
// literals can be implemented as a single load to and store from a 16-byte
// register. That literal's actual length can be as short as 1 byte, so this
// can copy up to 15 bytes too much, but that's OK as subsequent iterations of
// the encoding loop will fix up the copy overrun, and this inputMargin ensures
// that we don't overrun the dst and src buffers.
const inputMargin = 8

// minNonLiteralBlockSize is the minimum size of the input to encodeBlock that
// will be accepted by the encoder.
const minNonLiteralBlockSize = 32

// MaxBlockSize is the maximum value where MaxEncodedLen will return a valid block size.
// Blocks this big are highly discouraged, though.
const MaxBlockSize = math.MaxUint32 - binary.MaxVarintLen32 - 5

// MaxEncodedLen returns the maximum length of a snappy block, given its
// uncompressed length.
//
// It will return a negative value if srcLen is too large to encode.
// 32-bit platforms will have lower thresholds for rejecting big content.
func MaxEncodedLen(srcLen int) int {
	n := uint64(srcLen)
	if n > 0xffffffff {
		// Also includes negative.
		return -1
	}
	// Size of the varint encoded block size.
	n = n + uint64((bits.Len64(n)+7)/7)

	// Add maximum size of encoding block as literals.
	n += uint64(literalExtraSize(int64(srcLen)))
	if n > 0xffffffff {
		return -1
	}
	return int(n)
}
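
// Sizing sketch (an editor's example, not part of the original file):
// pre-allocating dst so Encode can work without allocating:
//
//	n := MaxEncodedLen(len(src))
//	if n < 0 {
//		return ErrTooLarge // src cannot be encoded as a single block
//	}
//	encoded := Encode(make([]byte, n), src)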

var errClosed = errors.New("s2: Writer is closed")

// NewWriter returns a new Writer that compresses to w, using the
// framing format described at
// https://github.com/google/snappy/blob/master/framing_format.txt
//
// Users must call Close to guarantee all data has been forwarded to
// the underlying io.Writer and that resources are released.
// They may also call Flush zero or more times before calling Close.
func NewWriter(w io.Writer, opts ...WriterOption) *Writer {
	w2 := Writer{
		blockSize:   defaultBlockSize,
		concurrency: runtime.GOMAXPROCS(0),
	}
	for _, opt := range opts {
		if err := opt(&w2); err != nil {
			w2.errState = err
			return &w2
		}
	}
	w2.obufLen = obufHeaderLen + MaxEncodedLen(w2.blockSize)
	w2.paramsOK = true
	w2.ibuf = make([]byte, 0, w2.blockSize)
	w2.buffers.New = func() interface{} {
		return make([]byte, w2.obufLen)
	}
	w2.Reset(w)
	return &w2
}
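
// Minimal streaming sketch (an editor's example, assuming the canonical
// import path github.com/klauspost/compress/s2):
//
//	func compress(dst io.Writer, src io.Reader) error {
//		enc := s2.NewWriter(dst, s2.WriterBetterCompression())
//		// Close flushes remaining data and releases resources, so it
//		// must run even on the error path; calling it twice is ok.
//		defer enc.Close()
//		if _, err := io.Copy(enc, src); err != nil {
//			return err
//		}
//		return enc.Close()
//	}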

// Writer is an io.Writer that can write Snappy-compressed bytes.
type Writer struct {
	errMu    sync.Mutex
	errState error

	// ibuf is a buffer for the incoming (uncompressed) bytes.
	ibuf []byte

	blockSize   int
	obufLen     int
	concurrency int
	written     int64
	output      chan chan result
	buffers     sync.Pool
	pad         int

	writer   io.Writer
	writerWg sync.WaitGroup

	// wroteStreamHeader is whether we have written the stream header.
	wroteStreamHeader bool
	paramsOK          bool
	better            bool
}

type result []byte

// err returns the previously set error.
// If no error has been set and err is not nil, the error state is set to err.
func (w *Writer) err(err error) error {
	w.errMu.Lock()
	errSet := w.errState
	if errSet == nil && err != nil {
		w.errState = err
		errSet = err
	}
	w.errMu.Unlock()
	return errSet
}

// Reset discards the writer's state and switches the Snappy writer to write to w.
// This permits reusing a Writer rather than allocating a new one.
func (w *Writer) Reset(writer io.Writer) {
	if !w.paramsOK {
		return
	}
	// Close previous writer, if any.
	if w.output != nil {
		close(w.output)
		w.writerWg.Wait()
		w.output = nil
	}
	w.errState = nil
	w.ibuf = w.ibuf[:0]
	w.wroteStreamHeader = false
	w.written = 0
	w.writer = writer
	// If we didn't get a writer, stop here.
	if writer == nil {
		return
	}
	// If no concurrency requested, don't spin up writer goroutine.
	if w.concurrency == 1 {
		return
	}

	toWrite := make(chan chan result, w.concurrency)
	w.output = toWrite
	w.writerWg.Add(1)

	// Start a writer goroutine that will write all output in order.
	go func() {
		defer w.writerWg.Done()

		// Get a queued write.
		for write := range toWrite {
			// Wait for the data to be available.
			in := <-write
			if len(in) > 0 {
				if w.err(nil) == nil {
					// Don't expose data from previous buffers.
					toWrite := in[:len(in):len(in)]
					// Write to output.
					n, err := writer.Write(toWrite)
					if err == nil && n != len(toWrite) {
						err = io.ErrShortWrite
					}
					_ = w.err(err)
					w.written += int64(n)
				}
			}
			if cap(in) >= w.obufLen {
				w.buffers.Put([]byte(in))
			}
			// close the incoming write request.
			// This can be used for synchronizing flushes.
			close(write)
		}
	}()
}

// Write satisfies the io.Writer interface.
func (w *Writer) Write(p []byte) (nRet int, errRet error) {
	// If we exceed the input buffer size, start writing
	for len(p) > (cap(w.ibuf)-len(w.ibuf)) && w.err(nil) == nil {
		var n int
		if len(w.ibuf) == 0 {
			// Large write, empty buffer.
			// Write directly from p to avoid copy.
			n, _ = w.write(p)
		} else {
			n = copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
			w.ibuf = w.ibuf[:len(w.ibuf)+n]
			w.write(w.ibuf)
			w.ibuf = w.ibuf[:0]
		}
		nRet += n
		p = p[n:]
	}
	if err := w.err(nil); err != nil {
		return nRet, err
	}
	// p should always be able to fit into w.ibuf now.
	n := copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
	w.ibuf = w.ibuf[:len(w.ibuf)+n]
	nRet += n
	return nRet, nil
}

// ReadFrom implements the io.ReaderFrom interface.
// Using this is typically more efficient since it avoids a memory copy.
// ReadFrom reads data from r until EOF or error.
// The return value n is the number of bytes read.
// Any error except io.EOF encountered during the read is also returned.
func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
	if len(w.ibuf) > 0 {
		err := w.Flush()
		if err != nil {
			return 0, err
		}
	}
	for {
		inbuf := w.buffers.Get().([]byte)[:w.blockSize+obufHeaderLen]
		n2, err := io.ReadFull(r, inbuf[obufHeaderLen:])
		if err != nil {
			if err == io.ErrUnexpectedEOF {
				err = io.EOF
			}
			if err != io.EOF {
				return n, w.err(err)
			}
		}
		if n2 == 0 {
			break
		}
		n += int64(n2)
		err2 := w.writeFull(inbuf[:n2+obufHeaderLen])
		if w.err(err2) != nil {
			break
		}

		if err != nil {
			// We got EOF and wrote everything
			break
		}
	}

	return n, w.err(nil)
}
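
// Illustrative note (not part of the original file): io.Copy picks this fast
// path up automatically when the source does not implement io.WriterTo, or
// it can be invoked directly:
//
//	n, err := enc.ReadFrom(f) // enc is a *Writer, f any io.Reader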

// EncodeBuffer will add a buffer to the stream.
// This is the fastest way to encode a stream,
// but the input buffer must not be modified by the caller
// until Flush or Close has been called.
//
// Note that input is not buffered.
// This means that each write will result in discrete blocks being created.
// For buffered writes, use the regular Write function.
func (w *Writer) EncodeBuffer(buf []byte) (err error) {
	if err := w.err(nil); err != nil {
		return err
	}

	// Flush queued data first.
	if len(w.ibuf) > 0 {
		err := w.Flush()
		if err != nil {
			return err
		}
	}
	if w.concurrency == 1 {
		_, err := w.writeSync(buf)
		return err
	}

	// Spawn goroutine and write block to output channel.
	if !w.wroteStreamHeader {
		w.wroteStreamHeader = true
		hWriter := make(chan result)
		w.output <- hWriter
		hWriter <- []byte(magicChunk)
	}

	for len(buf) > 0 {
		// Cut input.
		uncompressed := buf
		if len(uncompressed) > w.blockSize {
			uncompressed = uncompressed[:w.blockSize]
		}
		buf = buf[len(uncompressed):]
		// Get an output buffer.
		obuf := w.buffers.Get().([]byte)[:len(uncompressed)+obufHeaderLen]
		output := make(chan result)
		// Queue output now, so we keep order.
		w.output <- output
		go func() {
			checksum := crc(uncompressed)

			// Set to uncompressed.
			chunkType := uint8(chunkTypeUncompressedData)
			chunkLen := 4 + len(uncompressed)

			// Attempt compressing.
			n := binary.PutUvarint(obuf[obufHeaderLen:], uint64(len(uncompressed)))
			var n2 int
			if w.better {
				n2 = encodeBlockBetter(obuf[obufHeaderLen+n:], uncompressed)
			} else {
				n2 = encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
			}

			// Check if we should use this, or store as uncompressed instead.
			if n2 > 0 {
				chunkType = uint8(chunkTypeCompressedData)
				chunkLen = 4 + n + n2
				obuf = obuf[:obufHeaderLen+n+n2]
			} else {
				// copy uncompressed
				copy(obuf[obufHeaderLen:], uncompressed)
			}

			// Fill in the per-chunk header that comes before the body.
			obuf[0] = chunkType
			obuf[1] = uint8(chunkLen >> 0)
			obuf[2] = uint8(chunkLen >> 8)
			obuf[3] = uint8(chunkLen >> 16)
			obuf[4] = uint8(checksum >> 0)
			obuf[5] = uint8(checksum >> 8)
			obuf[6] = uint8(checksum >> 16)
			obuf[7] = uint8(checksum >> 24)

			// Queue final output.
			output <- obuf
		}()
	}
	return nil
}
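
// Illustrative use (not part of the original file), showing the ownership
// rule from the doc comment: buf must stay untouched until Flush or Close
// has been called:
//
//	if err := enc.EncodeBuffer(buf); err != nil {
//		return err
//	}
//	if err := enc.Flush(); err != nil {
//		return err
//	}
//	// buf may be reused or modified from here on.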

func (w *Writer) write(p []byte) (nRet int, errRet error) {
	if err := w.err(nil); err != nil {
		return 0, err
	}
	if w.concurrency == 1 {
		return w.writeSync(p)
	}

	// Spawn goroutine and write block to output channel.
	for len(p) > 0 {
		if !w.wroteStreamHeader {
			w.wroteStreamHeader = true
			hWriter := make(chan result)
			w.output <- hWriter
			hWriter <- []byte(magicChunk)
		}

		var uncompressed []byte
		if len(p) > w.blockSize {
			uncompressed, p = p[:w.blockSize], p[w.blockSize:]
		} else {
			uncompressed, p = p, nil
		}

		// Copy input.
		// If the block is incompressible, this is used for the result.
		inbuf := w.buffers.Get().([]byte)[:len(uncompressed)+obufHeaderLen]
		obuf := w.buffers.Get().([]byte)[:w.obufLen]
		copy(inbuf[obufHeaderLen:], uncompressed)
		uncompressed = inbuf[obufHeaderLen:]

		output := make(chan result)
		// Queue output now, so we keep order.
		w.output <- output
		go func() {
			checksum := crc(uncompressed)

			// Set to uncompressed.
			chunkType := uint8(chunkTypeUncompressedData)
			chunkLen := 4 + len(uncompressed)

			// Attempt compressing.
			n := binary.PutUvarint(obuf[obufHeaderLen:], uint64(len(uncompressed)))
			var n2 int
			if w.better {
				n2 = encodeBlockBetter(obuf[obufHeaderLen+n:], uncompressed)
			} else {
				n2 = encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
			}

			// Check if we should use this, or store as uncompressed instead.
			if n2 > 0 {
				chunkType = uint8(chunkTypeCompressedData)
				chunkLen = 4 + n + n2
				obuf = obuf[:obufHeaderLen+n+n2]
			} else {
				// Use input as output.
				obuf, inbuf = inbuf, obuf
			}

			// Fill in the per-chunk header that comes before the body.
			obuf[0] = chunkType
			obuf[1] = uint8(chunkLen >> 0)
			obuf[2] = uint8(chunkLen >> 8)
			obuf[3] = uint8(chunkLen >> 16)
			obuf[4] = uint8(checksum >> 0)
			obuf[5] = uint8(checksum >> 8)
			obuf[6] = uint8(checksum >> 16)
			obuf[7] = uint8(checksum >> 24)

			// Queue final output.
			output <- obuf

			// Put unused buffer back in pool.
			w.buffers.Put(inbuf)
		}()
		nRet += len(uncompressed)
	}
	return nRet, nil
}

// writeFull is a special version of write that will always write the full buffer.
// Data to be compressed should start at offset obufHeaderLen and fill the remainder of the buffer.
// The data will be written as a single block.
// The caller is not allowed to use inbuf after this function has been called.
func (w *Writer) writeFull(inbuf []byte) (errRet error) {
	if err := w.err(nil); err != nil {
		return err
	}

	if w.concurrency == 1 {
		_, err := w.writeSync(inbuf[obufHeaderLen:])
		return err
	}

	// Spawn goroutine and write block to output channel.
	if !w.wroteStreamHeader {
		w.wroteStreamHeader = true
		hWriter := make(chan result)
		w.output <- hWriter
		hWriter <- []byte(magicChunk)
	}

	// Get an output buffer.
	obuf := w.buffers.Get().([]byte)[:w.obufLen]
	uncompressed := inbuf[obufHeaderLen:]

	output := make(chan result)
	// Queue output now, so we keep order.
	w.output <- output
	go func() {
		checksum := crc(uncompressed)

		// Set to uncompressed.
		chunkType := uint8(chunkTypeUncompressedData)
		chunkLen := 4 + len(uncompressed)

		// Attempt compressing.
		n := binary.PutUvarint(obuf[obufHeaderLen:], uint64(len(uncompressed)))
		var n2 int
		if w.better {
			n2 = encodeBlockBetter(obuf[obufHeaderLen+n:], uncompressed)
		} else {
			n2 = encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
		}

		// Check if we should use this, or store as uncompressed instead.
		if n2 > 0 {
			chunkType = uint8(chunkTypeCompressedData)
			chunkLen = 4 + n + n2
			obuf = obuf[:obufHeaderLen+n+n2]
		} else {
			// Use input as output.
			obuf, inbuf = inbuf, obuf
		}

		// Fill in the per-chunk header that comes before the body.
		obuf[0] = chunkType
		obuf[1] = uint8(chunkLen >> 0)
		obuf[2] = uint8(chunkLen >> 8)
		obuf[3] = uint8(chunkLen >> 16)
		obuf[4] = uint8(checksum >> 0)
		obuf[5] = uint8(checksum >> 8)
		obuf[6] = uint8(checksum >> 16)
		obuf[7] = uint8(checksum >> 24)

		// Queue final output.
		output <- obuf

		// Put unused buffer back in pool.
		w.buffers.Put(inbuf)
	}()
	return nil
}

func (w *Writer) writeSync(p []byte) (nRet int, errRet error) {
	if err := w.err(nil); err != nil {
		return 0, err
	}
	if !w.wroteStreamHeader {
		w.wroteStreamHeader = true
		n, err := w.writer.Write([]byte(magicChunk))
		if err != nil {
			return 0, w.err(err)
		}
		if n != len(magicChunk) {
			return 0, w.err(io.ErrShortWrite)
		}
		w.written += int64(n)
	}

	for len(p) > 0 {
		var uncompressed []byte
		if len(p) > w.blockSize {
			uncompressed, p = p[:w.blockSize], p[w.blockSize:]
		} else {
			uncompressed, p = p, nil
		}

		obuf := w.buffers.Get().([]byte)[:w.obufLen]
		checksum := crc(uncompressed)

		// Set to uncompressed.
		chunkType := uint8(chunkTypeUncompressedData)
		chunkLen := 4 + len(uncompressed)

		// Attempt compressing.
		n := binary.PutUvarint(obuf[obufHeaderLen:], uint64(len(uncompressed)))
		var n2 int
		if w.better {
			n2 = encodeBlockBetter(obuf[obufHeaderLen+n:], uncompressed)
		} else {
			n2 = encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
		}

		if n2 > 0 {
			chunkType = uint8(chunkTypeCompressedData)
			chunkLen = 4 + n + n2
			obuf = obuf[:obufHeaderLen+n+n2]
		} else {
			obuf = obuf[:8]
		}

		// Fill in the per-chunk header that comes before the body.
		obuf[0] = chunkType
		obuf[1] = uint8(chunkLen >> 0)
		obuf[2] = uint8(chunkLen >> 8)
		obuf[3] = uint8(chunkLen >> 16)
		obuf[4] = uint8(checksum >> 0)
		obuf[5] = uint8(checksum >> 8)
		obuf[6] = uint8(checksum >> 16)
		obuf[7] = uint8(checksum >> 24)

		n, err := w.writer.Write(obuf)
		if err != nil {
			return 0, w.err(err)
		}
		if n != len(obuf) {
			return 0, w.err(io.ErrShortWrite)
		}
		w.written += int64(n)
		if chunkType == chunkTypeUncompressedData {
			// Write uncompressed data.
			n, err := w.writer.Write(uncompressed)
			if err != nil {
				return 0, w.err(err)
			}
			if n != len(uncompressed) {
				return 0, w.err(io.ErrShortWrite)
			}
			w.written += int64(n)
		}
		w.buffers.Put(obuf)
		// Account for the uncompressed bytes consumed.
		nRet += len(uncompressed)
	}
	return nRet, nil
}

// Flush flushes the Writer to its underlying io.Writer.
// This does not apply padding.
func (w *Writer) Flush() error {
	if err := w.err(nil); err != nil {
		return err
	}

	// Queue any data still in input buffer.
	if len(w.ibuf) != 0 {
		_, err := w.write(w.ibuf)
		w.ibuf = w.ibuf[:0]
		err = w.err(err)
		if err != nil {
			return err
		}
	}
	if w.output == nil {
		return w.err(nil)
	}

	// Send empty buffer
	res := make(chan result)
	w.output <- res
	// Block until this has been picked up.
	res <- nil
	// When it is closed, we have flushed.
	<-res
	return w.err(nil)
}

// Close calls Flush and then closes the Writer.
// Calling Close multiple times is ok.
func (w *Writer) Close() error {
	err := w.Flush()
	if w.output != nil {
		close(w.output)
		w.writerWg.Wait()
		w.output = nil
	}
	if w.err(nil) == nil && w.writer != nil && w.pad > 0 {
		add := calcSkippableFrame(w.written, int64(w.pad))
		frame, err := skippableFrame(w.ibuf[:0], add, rand.Reader)
		if err = w.err(err); err != nil {
			return err
		}
		_, err2 := w.writer.Write(frame)
		_ = w.err(err2)
	}
	_ = w.err(errClosed)
	if err == errClosed {
		return nil
	}
	return err
}

const skippableFrameHeader = 4

// calcSkippableFrame will return the total size to be added for written
// to be divisible by wantMultiple.
// The returned value will be 0, or at least skippableFrameHeader.
// The function will panic if written < 0 or wantMultiple <= 0.
func calcSkippableFrame(written, wantMultiple int64) int {
	if wantMultiple <= 0 {
		panic("wantMultiple <= 0")
	}
	if written < 0 {
		panic("written < 0")
	}
	leftOver := written % wantMultiple
	if leftOver == 0 {
		return 0
	}
	toAdd := wantMultiple - leftOver
	for toAdd < skippableFrameHeader {
		toAdd += wantMultiple
	}
	return int(toAdd)
}
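
// Worked example (an editor's note): with written = 100 and wantMultiple = 64,
// leftOver is 36 and toAdd is 64 - 36 = 28, so 28 is returned; 100 + 28 = 128,
// a multiple of 64. Had toAdd come out below skippableFrameHeader (4), another
// wantMultiple would be added so the frame header always fits.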

// skippableFrame will add a skippable frame with a total size of total bytes.
// total should be 0, or >= skippableFrameHeader and < maxBlockSize + skippableFrameHeader.
func skippableFrame(dst []byte, total int, r io.Reader) ([]byte, error) {
	if total == 0 {
		return dst, nil
	}
	if total < skippableFrameHeader {
		return dst, fmt.Errorf("s2: requested skippable frame (%d) < 4", total)
	}
	if int64(total) >= maxBlockSize+skippableFrameHeader {
		return dst, fmt.Errorf("s2: requested skippable frame (%d) >= max 1<<24", total)
	}
	// Chunk type 0xfe "Section 4.4 Padding (chunk type 0xfe)"
	dst = append(dst, chunkTypePadding)
	f := uint32(total - skippableFrameHeader)
	// Add chunk length.
	dst = append(dst, uint8(f), uint8(f>>8), uint8(f>>16))
	// Add data
	start := len(dst)
	dst = append(dst, make([]byte, f)...)
	_, err := io.ReadFull(r, dst[start:])
	return dst, err
}
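
// Illustrative layout (an editor's note): skippableFrame(nil, 8, r) emits
// 8 bytes in total:
//
//	0xfe           chunk type: padding
//	0x04 0x00 0x00 chunk length 4, little endian (total minus the 4 header bytes)
//	4 random bytes read from r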

// WriterOption is an option for creating an encoder.
type WriterOption func(*Writer) error

// WriterConcurrency will set the concurrency,
// meaning the maximum number of encoders to run concurrently.
// The value supplied must be at least 1.
// By default this will be set to GOMAXPROCS.
func WriterConcurrency(n int) WriterOption {
	return func(w *Writer) error {
		if n <= 0 {
			return errors.New("concurrency must be at least 1")
		}
		w.concurrency = n
		return nil
	}
}

// WriterBetterCompression will enable better compression.
// EncodeBetter compresses better than Encode but typically with a
// 10-40% speed decrease on both compression and decompression.
func WriterBetterCompression() WriterOption {
	return func(w *Writer) error {
		w.better = true
		return nil
	}
}

// WriterBlockSize allows overriding the default block size.
// Blocks will be this size or smaller.
// Minimum size is 4KB and maximum size is 4MB.
//
// Bigger blocks may give bigger throughput on systems with many cores,
// and will increase compression slightly, but it will limit the possible
// concurrency for smaller payloads for both encoding and decoding.
// Default block size is 1MB.
func WriterBlockSize(n int) WriterOption {
	return func(w *Writer) error {
		if n > maxBlockSize || n < minBlockSize {
			return errors.New("s2: block size out of range. Must be >= 4KB and <= 4MB")
		}
		w.blockSize = n
		return nil
	}
}

// WriterPadding will add padding to all output so the size will be a multiple of n.
// This can be used to obfuscate the exact output size or make blocks of a certain size.
// The contents will be a skippable frame, so it will be invisible to the decoder.
// n must be > 0 and <= 4MB.
// The padded area will be filled with data from crypto/rand.Reader.
// The padding will be applied whenever Close is called on the writer.
func WriterPadding(n int) WriterOption {
	return func(w *Writer) error {
		if n <= 0 {
			return fmt.Errorf("s2: padding must be at least 1")
		}
		// No need to waste our time.
		if n == 1 {
			w.pad = 0
			return nil
		}
		if n > maxBlockSize {
			return fmt.Errorf("s2: padding must be no more than 4MB")
		}
		w.pad = n
		return nil
	}
}

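// Illustrative option use (an editor's sketch): pad the stream to 4KiB
// multiples, e.g. for fixed-size storage segments:
//
//	enc := NewWriter(w, WriterPadding(4<<10), WriterBlockSize(1<<20))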