1// Copyright 2011 The Snappy-Go Authors. All rights reserved.
2// Copyright (c) 2019 Klaus Post. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6package s2
7
8import (
9	"encoding/binary"
10	"errors"
11	"io"
12)
13
// Sentinel errors reported by the block and stream decoders in this file.
var (
	// ErrCorrupt reports that the input is invalid.
	ErrCorrupt = errors.New("s2: corrupt input")
	// ErrCRC reports that the input failed CRC validation (streams only).
	// Reader.Read returns it when a decoded chunk does not match the checksum
	// stored in front of it.
	ErrCRC = errors.New("s2: corrupt input, crc mismatch")
	// ErrTooLarge reports that the uncompressed length is too large.
	ErrTooLarge = errors.New("s2: decoded block is too large")
	// ErrUnsupported reports that the input isn't supported.
	ErrUnsupported = errors.New("s2: unsupported input")
)
24
25// DecodedLen returns the length of the decoded block.
26func DecodedLen(src []byte) (int, error) {
27	v, _, err := decodedLen(src)
28	return v, err
29}
30
31// decodedLen returns the length of the decoded block and the number of bytes
32// that the length header occupied.
33func decodedLen(src []byte) (blockLen, headerLen int, err error) {
34	v, n := binary.Uvarint(src)
35	if n <= 0 || v > 0xffffffff {
36		return 0, 0, ErrCorrupt
37	}
38
39	const wordSize = 32 << (^uint(0) >> 32 & 1)
40	if wordSize == 32 && v > 0x7fffffff {
41		return 0, 0, ErrTooLarge
42	}
43	return int(v), n, nil
44}
45
const (
	// decodeErrCodeCorrupt is a non-zero decoder status code.
	// NOTE(review): presumably the value s2Decode returns for corrupt input;
	// Decode in this file only tests the result against zero, so confirm
	// against the s2Decode implementation.
	decodeErrCodeCorrupt = 1
)
49
50// Decode returns the decoded form of src. The returned slice may be a sub-
51// slice of dst if dst was large enough to hold the entire decoded block.
52// Otherwise, a newly allocated slice will be returned.
53//
54// The dst and src must not overlap. It is valid to pass a nil dst.
55func Decode(dst, src []byte) ([]byte, error) {
56	dLen, s, err := decodedLen(src)
57	if err != nil {
58		return nil, err
59	}
60	if dLen <= cap(dst) {
61		dst = dst[:dLen]
62	} else {
63		dst = make([]byte, dLen)
64	}
65	if s2Decode(dst, src[s:]) != 0 {
66		return nil, ErrCorrupt
67	}
68	return dst, nil
69}
70
71// NewReader returns a new Reader that decompresses from r, using the framing
72// format described at
73// https://github.com/google/snappy/blob/master/framing_format.txt with S2 changes.
74func NewReader(r io.Reader, opts ...ReaderOption) *Reader {
75	nr := Reader{
76		r:        r,
77		maxBlock: maxBlockSize,
78	}
79	for _, opt := range opts {
80		if err := opt(&nr); err != nil {
81			nr.err = err
82			return &nr
83		}
84	}
85	nr.maxBufSize = MaxEncodedLen(nr.maxBlock) + checksumSize
86	if nr.lazyBuf > 0 {
87		nr.buf = make([]byte, MaxEncodedLen(nr.lazyBuf)+checksumSize)
88	} else {
89		nr.buf = make([]byte, MaxEncodedLen(defaultBlockSize)+checksumSize)
90	}
91	nr.paramsOK = true
92	return &nr
93}
94
// ReaderOption is an option for creating a decoder.
// Options are passed to NewReader, which applies them in order to the Reader
// under construction. If an option returns an error, NewReader stores that
// error on the Reader and stops applying options.
type ReaderOption func(*Reader) error
97
98// ReaderMaxBlockSize allows to control allocations if the stream
99// has been compressed with a smaller WriterBlockSize, or with the default 1MB.
100// Blocks must be this size or smaller to decompress,
101// otherwise the decoder will return ErrUnsupported.
102//
103// For streams compressed with Snappy this can safely be set to 64KB (64 << 10).
104//
105// Default is the maximum limit of 4MB.
106func ReaderMaxBlockSize(blockSize int) ReaderOption {
107	return func(r *Reader) error {
108		if blockSize > maxBlockSize || blockSize <= 0 {
109			return errors.New("s2: block size too large. Must be <= 4MB and > 0")
110		}
111		if r.lazyBuf == 0 && blockSize < defaultBlockSize {
112			r.lazyBuf = blockSize
113		}
114		r.maxBlock = blockSize
115		return nil
116	}
117}
118
119// ReaderAllocBlock allows to control upfront stream allocations
120// and not allocate for frames bigger than this initially.
121// If frames bigger than this is seen a bigger buffer will be allocated.
122//
123// Default is 1MB, which is default output size.
124func ReaderAllocBlock(blockSize int) ReaderOption {
125	return func(r *Reader) error {
126		if blockSize > maxBlockSize || blockSize < 1024 {
127			return errors.New("s2: invalid ReaderAllocBlock. Must be <= 4MB and >= 1024")
128		}
129		r.lazyBuf = blockSize
130		return nil
131	}
132}
133
// Reader is an io.Reader that can read Snappy-compressed bytes.
// It accepts streams identified by either the S2 or the Snappy stream
// identifier and hands out the decompressed content through Read, Skip and
// ReadByte. Create one with NewReader.
type Reader struct {
	r       io.Reader // underlying source of framed, compressed data
	err     error     // sticky error; once set, Read, Skip and ReadByte return it without further work
	decoded []byte    // storage for the most recently decoded block
	buf     []byte    // scratch space for chunk headers and chunk bodies
	// decoded[i:j] contains decoded bytes that have not yet been passed on.
	i, j int
	// maximum block size allowed.
	maxBlock int
	// maximum expected buffer size.
	maxBufSize int
	// alloc a buffer this size if > 0.
	lazyBuf     int
	readHeader  bool // true once the leading stream identifier chunk has been seen
	paramsOK    bool // set by NewReader after every option applied cleanly; Reset does nothing without it
	snappyFrame bool // true when the stream identifier carried the Snappy magic rather than the S2 one
}
152
153// ensureBufferSize will ensure that the buffer can take at least n bytes.
154// If false is returned the buffer exceeds maximum allowed size.
155func (r *Reader) ensureBufferSize(n int) bool {
156	if len(r.buf) >= n {
157		return true
158	}
159	if n > r.maxBufSize {
160		r.err = ErrCorrupt
161		return false
162	}
163	// Realloc buffer.
164	r.buf = make([]byte, n)
165	return true
166}
167
168// Reset discards any buffered data, resets all state, and switches the Snappy
169// reader to read from r. This permits reusing a Reader rather than allocating
170// a new one.
171func (r *Reader) Reset(reader io.Reader) {
172	if !r.paramsOK {
173		return
174	}
175	r.r = reader
176	r.err = nil
177	r.i = 0
178	r.j = 0
179	r.readHeader = false
180}
181
182func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
183	if _, r.err = io.ReadFull(r.r, p); r.err != nil {
184		if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
185			r.err = ErrCorrupt
186		}
187		return false
188	}
189	return true
190}
191
192// skipN will skip n bytes.
193// If the supplied reader supports seeking that is used.
194// tmp is used as a temporary buffer for reading.
195// The supplied slice does not need to be the size of the read.
196func (r *Reader) skipN(tmp []byte, n int, allowEOF bool) (ok bool) {
197	if rs, ok := r.r.(io.ReadSeeker); ok {
198		_, err := rs.Seek(int64(n), io.SeekCurrent)
199		if err == nil {
200			return true
201		}
202		if err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
203			r.err = ErrCorrupt
204			return false
205		}
206	}
207	for n > 0 {
208		if n < len(tmp) {
209			tmp = tmp[:n]
210		}
211		if _, r.err = io.ReadFull(r.r, tmp); r.err != nil {
212			if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
213				r.err = ErrCorrupt
214			}
215			return false
216		}
217		n -= len(tmp)
218	}
219	return true
220}
221
222// Read satisfies the io.Reader interface.
223func (r *Reader) Read(p []byte) (int, error) {
224	if r.err != nil {
225		return 0, r.err
226	}
227	for {
228		if r.i < r.j {
229			n := copy(p, r.decoded[r.i:r.j])
230			r.i += n
231			return n, nil
232		}
233		if !r.readFull(r.buf[:4], true) {
234			return 0, r.err
235		}
236		chunkType := r.buf[0]
237		if !r.readHeader {
238			if chunkType != chunkTypeStreamIdentifier {
239				r.err = ErrCorrupt
240				return 0, r.err
241			}
242			r.readHeader = true
243		}
244		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
245
246		// The chunk types are specified at
247		// https://github.com/google/snappy/blob/master/framing_format.txt
248		switch chunkType {
249		case chunkTypeCompressedData:
250			// Section 4.2. Compressed data (chunk type 0x00).
251			if chunkLen < checksumSize {
252				r.err = ErrCorrupt
253				return 0, r.err
254			}
255			if !r.ensureBufferSize(chunkLen) {
256				if r.err == nil {
257					r.err = ErrUnsupported
258				}
259				return 0, r.err
260			}
261			buf := r.buf[:chunkLen]
262			if !r.readFull(buf, false) {
263				return 0, r.err
264			}
265			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
266			buf = buf[checksumSize:]
267
268			n, err := DecodedLen(buf)
269			if err != nil {
270				r.err = err
271				return 0, r.err
272			}
273			if r.snappyFrame && n > maxSnappyBlockSize {
274				r.err = ErrCorrupt
275				return 0, r.err
276			}
277
278			if n > len(r.decoded) {
279				if n > r.maxBlock {
280					r.err = ErrCorrupt
281					return 0, r.err
282				}
283				r.decoded = make([]byte, n)
284			}
285			if _, err := Decode(r.decoded, buf); err != nil {
286				r.err = err
287				return 0, r.err
288			}
289			if crc(r.decoded[:n]) != checksum {
290				r.err = ErrCRC
291				return 0, r.err
292			}
293			r.i, r.j = 0, n
294			continue
295
296		case chunkTypeUncompressedData:
297			// Section 4.3. Uncompressed data (chunk type 0x01).
298			if chunkLen < checksumSize {
299				r.err = ErrCorrupt
300				return 0, r.err
301			}
302			if !r.ensureBufferSize(chunkLen) {
303				if r.err == nil {
304					r.err = ErrUnsupported
305				}
306				return 0, r.err
307			}
308			buf := r.buf[:checksumSize]
309			if !r.readFull(buf, false) {
310				return 0, r.err
311			}
312			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
313			// Read directly into r.decoded instead of via r.buf.
314			n := chunkLen - checksumSize
315			if r.snappyFrame && n > maxSnappyBlockSize {
316				r.err = ErrCorrupt
317				return 0, r.err
318			}
319			if n > len(r.decoded) {
320				if n > r.maxBlock {
321					r.err = ErrCorrupt
322					return 0, r.err
323				}
324				r.decoded = make([]byte, n)
325			}
326			if !r.readFull(r.decoded[:n], false) {
327				return 0, r.err
328			}
329			if crc(r.decoded[:n]) != checksum {
330				r.err = ErrCRC
331				return 0, r.err
332			}
333			r.i, r.j = 0, n
334			continue
335
336		case chunkTypeStreamIdentifier:
337			// Section 4.1. Stream identifier (chunk type 0xff).
338			if chunkLen != len(magicBody) {
339				r.err = ErrCorrupt
340				return 0, r.err
341			}
342			if !r.readFull(r.buf[:len(magicBody)], false) {
343				return 0, r.err
344			}
345			if string(r.buf[:len(magicBody)]) != magicBody {
346				if string(r.buf[:len(magicBody)]) != magicBodySnappy {
347					r.err = ErrCorrupt
348					return 0, r.err
349				} else {
350					r.snappyFrame = true
351				}
352			} else {
353				r.snappyFrame = false
354			}
355			continue
356		}
357
358		if chunkType <= 0x7f {
359			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
360			r.err = ErrUnsupported
361			return 0, r.err
362		}
363		// Section 4.4 Padding (chunk type 0xfe).
364		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
365		if chunkLen > maxBlockSize {
366			r.err = ErrUnsupported
367			return 0, r.err
368		}
369
370		if !r.skipN(r.buf, chunkLen, false) {
371			return 0, r.err
372		}
373	}
374}
375
376// Skip will skip n bytes forward in the decompressed output.
377// For larger skips this consumes less CPU and is faster than reading output and discarding it.
378// CRC is not checked on skipped blocks.
379// io.ErrUnexpectedEOF is returned if the stream ends before all bytes have been skipped.
380// If a decoding error is encountered subsequent calls to Read will also fail.
381func (r *Reader) Skip(n int64) error {
382	if n < 0 {
383		return errors.New("attempted negative skip")
384	}
385	if r.err != nil {
386		return r.err
387	}
388
389	for n > 0 {
390		if r.i < r.j {
391			// Skip in buffer.
392			// decoded[i:j] contains decoded bytes that have not yet been passed on.
393			left := int64(r.j - r.i)
394			if left >= n {
395				r.i += int(n)
396				return nil
397			}
398			n -= int64(r.j - r.i)
399			r.i, r.j = 0, 0
400		}
401
402		// Buffer empty; read blocks until we have content.
403		if !r.readFull(r.buf[:4], true) {
404			if r.err == io.EOF {
405				r.err = io.ErrUnexpectedEOF
406			}
407			return r.err
408		}
409		chunkType := r.buf[0]
410		if !r.readHeader {
411			if chunkType != chunkTypeStreamIdentifier {
412				r.err = ErrCorrupt
413				return r.err
414			}
415			r.readHeader = true
416		}
417		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
418
419		// The chunk types are specified at
420		// https://github.com/google/snappy/blob/master/framing_format.txt
421		switch chunkType {
422		case chunkTypeCompressedData:
423			// Section 4.2. Compressed data (chunk type 0x00).
424			if chunkLen < checksumSize {
425				r.err = ErrCorrupt
426				return r.err
427			}
428			if !r.ensureBufferSize(chunkLen) {
429				if r.err == nil {
430					r.err = ErrUnsupported
431				}
432				return r.err
433			}
434			buf := r.buf[:chunkLen]
435			if !r.readFull(buf, false) {
436				return r.err
437			}
438			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
439			buf = buf[checksumSize:]
440
441			dLen, err := DecodedLen(buf)
442			if err != nil {
443				r.err = err
444				return r.err
445			}
446			if dLen > r.maxBlock {
447				r.err = ErrCorrupt
448				return r.err
449			}
450			// Check if destination is within this block
451			if int64(dLen) > n {
452				if len(r.decoded) < dLen {
453					r.decoded = make([]byte, dLen)
454				}
455				if _, err := Decode(r.decoded, buf); err != nil {
456					r.err = err
457					return r.err
458				}
459				if crc(r.decoded[:dLen]) != checksum {
460					r.err = ErrCorrupt
461					return r.err
462				}
463			} else {
464				// Skip block completely
465				n -= int64(dLen)
466				dLen = 0
467			}
468			r.i, r.j = 0, dLen
469			continue
470		case chunkTypeUncompressedData:
471			// Section 4.3. Uncompressed data (chunk type 0x01).
472			if chunkLen < checksumSize {
473				r.err = ErrCorrupt
474				return r.err
475			}
476			if !r.ensureBufferSize(chunkLen) {
477				if r.err != nil {
478					r.err = ErrUnsupported
479				}
480				return r.err
481			}
482			buf := r.buf[:checksumSize]
483			if !r.readFull(buf, false) {
484				return r.err
485			}
486			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
487			// Read directly into r.decoded instead of via r.buf.
488			n2 := chunkLen - checksumSize
489			if n2 > len(r.decoded) {
490				if n2 > r.maxBlock {
491					r.err = ErrCorrupt
492					return r.err
493				}
494				r.decoded = make([]byte, n2)
495			}
496			if !r.readFull(r.decoded[:n2], false) {
497				return r.err
498			}
499			if int64(n2) < n {
500				if crc(r.decoded[:n2]) != checksum {
501					r.err = ErrCorrupt
502					return r.err
503				}
504			}
505			r.i, r.j = 0, n2
506			continue
507		case chunkTypeStreamIdentifier:
508			// Section 4.1. Stream identifier (chunk type 0xff).
509			if chunkLen != len(magicBody) {
510				r.err = ErrCorrupt
511				return r.err
512			}
513			if !r.readFull(r.buf[:len(magicBody)], false) {
514				return r.err
515			}
516			if string(r.buf[:len(magicBody)]) != magicBody {
517				if string(r.buf[:len(magicBody)]) != magicBodySnappy {
518					r.err = ErrCorrupt
519					return r.err
520				}
521			}
522
523			continue
524		}
525
526		if chunkType <= 0x7f {
527			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
528			r.err = ErrUnsupported
529			return r.err
530		}
531		if chunkLen > maxBlockSize {
532			r.err = ErrUnsupported
533			return r.err
534		}
535		// Section 4.4 Padding (chunk type 0xfe).
536		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
537		if !r.skipN(r.buf, chunkLen, false) {
538			return r.err
539		}
540	}
541	return nil
542}
543
544// ReadByte satisfies the io.ByteReader interface.
545func (r *Reader) ReadByte() (byte, error) {
546	if r.err != nil {
547		return 0, r.err
548	}
549	if r.i < r.j {
550		c := r.decoded[r.i]
551		r.i++
552		return c, nil
553	}
554	var tmp [1]byte
555	for i := 0; i < 10; i++ {
556		n, err := r.Read(tmp[:])
557		if err != nil {
558			return 0, err
559		}
560		if n == 1 {
561			return tmp[0], nil
562		}
563	}
564	return 0, io.ErrNoProgress
565}
566