1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"encoding/binary"
9	"errors"
10	"hash/crc32"
11	"io"
12
13	"github.com/klauspost/compress/huff0"
14	"github.com/klauspost/compress/snappy"
15)
16
// Snappy element tags. decodeSnappy reads the tag from the low two bits of
// the first byte of every element in a compressed block.
const (
	snappyTagLiteral = 0x00 // literal run; the upper six bits encode the length or a length-of-length marker (60..63)
	snappyTagCopy1   = 0x01 // copy, 2 bytes total: 3-bit length (+4) and 11-bit offset
	snappyTagCopy2   = 0x02 // copy, 3 bytes total: 6-bit length (+1) and 16-bit little-endian offset
	snappyTagCopy4   = 0x03 // copy, 5 bytes total: 6-bit length (+1) and 32-bit little-endian offset
)
23
const (
	// snappyChecksumSize is the number of checksum bytes that Convert expects
	// at the start of every compressed and uncompressed data chunk.
	// snappyMagicBody is the exact body Convert requires in a stream
	// identifier chunk.
	snappyChecksumSize = 4
	snappyMagicBody    = "sNaPpY"

	// snappyMaxBlockSize is the maximum size of the input to encodeBlock. It is not
	// part of the wire format per se, but some parts of the encoder assume
	// that an offset fits into a uint16.
	//
	// Also, for the framing format (Writer type instead of Encode function),
	// https://github.com/google/snappy/blob/master/framing_format.txt says
	// that "the uncompressed data in a chunk must be no longer than 65536
	// bytes".
	//
	// In this file it is used as the window size of the emitted frame header
	// and as the upper bound on the decoded size of a single chunk.
	snappyMaxBlockSize = 65536

	// snappyMaxEncodedLenOfMaxBlockSize equals MaxEncodedLen(snappyMaxBlockSize), but is
	// hard coded to be a const instead of a variable, so that obufLen can also
	// be a const. Their equivalence is confirmed by
	// TestMaxEncodedLenOfMaxBlockSize.
	//
	// Convert sizes its read buffer as this value plus snappyChecksumSize.
	snappyMaxEncodedLenOfMaxBlockSize = 76490
)
44
// Chunk types of the Snappy framing format handled by Convert. The section
// numbers refer to
// https://github.com/google/snappy/blob/master/framing_format.txt
const (
	chunkTypeCompressedData   = 0x00 // Section 4.2: checksum followed by a Snappy-compressed block
	chunkTypeUncompressedData = 0x01 // Section 4.3: checksum followed by raw data
	chunkTypePadding          = 0xfe // Section 4.4: padding; Convert reads and discards it
	chunkTypeStreamIdentifier = 0xff // Section 4.1: must be the first chunk, body is snappyMagicBody
)
51
var (
	// ErrSnappyCorrupt reports that the input is invalid.
	ErrSnappyCorrupt = errors.New("snappy: corrupt input")
	// ErrSnappyTooLarge reports that the uncompressed length is too large.
	ErrSnappyTooLarge = errors.New("snappy: decoded block is too large")
	// ErrSnappyUnsupported reports that the input isn't supported.
	ErrSnappyUnsupported = errors.New("snappy: unsupported input")

	// errUnsupportedLiteralLength is returned by decodeSnappy when the
	// computed length of a literal run is not positive.
	errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
)
62
// SnappyConverter can read Snappy-compressed streams and convert them to zstd.
// Conversion is done by converting the stream directly from Snappy without intermediate
// full decoding.
// Therefore the compression ratio is much less than what can be done by a full decompression
// and compression, and a faulty Snappy stream may lead to a faulty Zstandard stream without
// any errors being generated.
// No CRC value is being generated and not all CRC values of the Snappy stream are checked.
// However, it provides really fast recompression of Snappy streams.
// The converter can be reused to avoid allocations, even after errors.
type SnappyConverter struct {
	r     io.Reader // input stream of the current Convert call
	err   error     // result of the most recent read or write; reset at the start of Convert
	buf   []byte    // scratch buffer holding the frame header and one chunk at a time
	block *blockEnc // block encoder, created on first use and kept for later calls
}
78
79// Convert the Snappy stream supplied in 'in' and write the zStandard stream to 'w'.
80// If any error is detected on the Snappy stream it is returned.
81// The number of bytes written is returned.
82func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) {
83	initPredefined()
84	r.err = nil
85	r.r = in
86	if r.block == nil {
87		r.block = &blockEnc{}
88		r.block.init()
89	}
90	r.block.initNewEncode()
91	if len(r.buf) != snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize {
92		r.buf = make([]byte, snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize)
93	}
94	r.block.litEnc.Reuse = huff0.ReusePolicyNone
95	var written int64
96	var readHeader bool
97	{
98		var header []byte
99		var n int
100		header, r.err = frameHeader{WindowSize: snappyMaxBlockSize}.appendTo(r.buf[:0])
101
102		n, r.err = w.Write(header)
103		if r.err != nil {
104			return written, r.err
105		}
106		written += int64(n)
107	}
108
109	for {
110		if !r.readFull(r.buf[:4], true) {
111			// Add empty last block
112			r.block.reset(nil)
113			r.block.last = true
114			err := r.block.encodeLits(false)
115			if err != nil {
116				return written, err
117			}
118			n, err := w.Write(r.block.output)
119			if err != nil {
120				return written, err
121			}
122			written += int64(n)
123
124			return written, r.err
125		}
126		chunkType := r.buf[0]
127		if !readHeader {
128			if chunkType != chunkTypeStreamIdentifier {
129				println("chunkType != chunkTypeStreamIdentifier", chunkType)
130				r.err = ErrSnappyCorrupt
131				return written, r.err
132			}
133			readHeader = true
134		}
135		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
136		if chunkLen > len(r.buf) {
137			println("chunkLen > len(r.buf)", chunkType)
138			r.err = ErrSnappyUnsupported
139			return written, r.err
140		}
141
142		// The chunk types are specified at
143		// https://github.com/google/snappy/blob/master/framing_format.txt
144		switch chunkType {
145		case chunkTypeCompressedData:
146			// Section 4.2. Compressed data (chunk type 0x00).
147			if chunkLen < snappyChecksumSize {
148				println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize)
149				r.err = ErrSnappyCorrupt
150				return written, r.err
151			}
152			buf := r.buf[:chunkLen]
153			if !r.readFull(buf, false) {
154				return written, r.err
155			}
156			//checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
157			buf = buf[snappyChecksumSize:]
158
159			n, hdr, err := snappyDecodedLen(buf)
160			if err != nil {
161				r.err = err
162				return written, r.err
163			}
164			buf = buf[hdr:]
165			if n > snappyMaxBlockSize {
166				println("n > snappyMaxBlockSize", n, snappyMaxBlockSize)
167				r.err = ErrSnappyCorrupt
168				return written, r.err
169			}
170			r.block.reset(nil)
171			r.block.pushOffsets()
172			if err := decodeSnappy(r.block, buf); err != nil {
173				r.err = err
174				return written, r.err
175			}
176			if r.block.size+r.block.extraLits != n {
177				printf("invalid size, want %d, got %d\n", n, r.block.size+r.block.extraLits)
178				r.err = ErrSnappyCorrupt
179				return written, r.err
180			}
181			err = r.block.encode(false, false)
182			switch err {
183			case errIncompressible:
184				r.block.popOffsets()
185				r.block.reset(nil)
186				r.block.literals, err = snappy.Decode(r.block.literals[:n], r.buf[snappyChecksumSize:chunkLen])
187				if err != nil {
188					println("snappy.Decode:", err)
189					return written, err
190				}
191				err = r.block.encodeLits(false)
192				if err != nil {
193					return written, err
194				}
195			case nil:
196			default:
197				return written, err
198			}
199
200			n, r.err = w.Write(r.block.output)
201			if r.err != nil {
202				return written, err
203			}
204			written += int64(n)
205			continue
206		case chunkTypeUncompressedData:
207			if debug {
208				println("Uncompressed, chunklen", chunkLen)
209			}
210			// Section 4.3. Uncompressed data (chunk type 0x01).
211			if chunkLen < snappyChecksumSize {
212				println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize)
213				r.err = ErrSnappyCorrupt
214				return written, r.err
215			}
216			r.block.reset(nil)
217			buf := r.buf[:snappyChecksumSize]
218			if !r.readFull(buf, false) {
219				return written, r.err
220			}
221			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
222			// Read directly into r.decoded instead of via r.buf.
223			n := chunkLen - snappyChecksumSize
224			if n > snappyMaxBlockSize {
225				println("n > snappyMaxBlockSize", n, snappyMaxBlockSize)
226				r.err = ErrSnappyCorrupt
227				return written, r.err
228			}
229			r.block.literals = r.block.literals[:n]
230			if !r.readFull(r.block.literals, false) {
231				return written, r.err
232			}
233			if snappyCRC(r.block.literals) != checksum {
234				println("literals crc mismatch")
235				r.err = ErrSnappyCorrupt
236				return written, r.err
237			}
238			err := r.block.encodeLits(false)
239			if err != nil {
240				return written, err
241			}
242			n, r.err = w.Write(r.block.output)
243			if r.err != nil {
244				return written, err
245			}
246			written += int64(n)
247			continue
248
249		case chunkTypeStreamIdentifier:
250			if debug {
251				println("stream id", chunkLen, len(snappyMagicBody))
252			}
253			// Section 4.1. Stream identifier (chunk type 0xff).
254			if chunkLen != len(snappyMagicBody) {
255				println("chunkLen != len(snappyMagicBody)", chunkLen, len(snappyMagicBody))
256				r.err = ErrSnappyCorrupt
257				return written, r.err
258			}
259			if !r.readFull(r.buf[:len(snappyMagicBody)], false) {
260				return written, r.err
261			}
262			for i := 0; i < len(snappyMagicBody); i++ {
263				if r.buf[i] != snappyMagicBody[i] {
264					println("r.buf[i] != snappyMagicBody[i]", r.buf[i], snappyMagicBody[i], i)
265					r.err = ErrSnappyCorrupt
266					return written, r.err
267				}
268			}
269			continue
270		}
271
272		if chunkType <= 0x7f {
273			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
274			println("chunkType <= 0x7f")
275			r.err = ErrSnappyUnsupported
276			return written, r.err
277		}
278		// Section 4.4 Padding (chunk type 0xfe).
279		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
280		if !r.readFull(r.buf[:chunkLen], false) {
281			return written, r.err
282		}
283	}
284}
285
286// decodeSnappy writes the decoding of src to dst. It assumes that the varint-encoded
287// length of the decompressed bytes has already been read.
288func decodeSnappy(blk *blockEnc, src []byte) error {
289	//decodeRef(make([]byte, snappyMaxBlockSize), src)
290	var s, length int
291	lits := blk.extraLits
292	var offset uint32
293	for s < len(src) {
294		switch src[s] & 0x03 {
295		case snappyTagLiteral:
296			x := uint32(src[s] >> 2)
297			switch {
298			case x < 60:
299				s++
300			case x == 60:
301				s += 2
302				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
303					println("uint(s) > uint(len(src)", s, src)
304					return ErrSnappyCorrupt
305				}
306				x = uint32(src[s-1])
307			case x == 61:
308				s += 3
309				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
310					println("uint(s) > uint(len(src)", s, src)
311					return ErrSnappyCorrupt
312				}
313				x = uint32(src[s-2]) | uint32(src[s-1])<<8
314			case x == 62:
315				s += 4
316				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
317					println("uint(s) > uint(len(src)", s, src)
318					return ErrSnappyCorrupt
319				}
320				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
321			case x == 63:
322				s += 5
323				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
324					println("uint(s) > uint(len(src)", s, src)
325					return ErrSnappyCorrupt
326				}
327				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
328			}
329			if x > snappyMaxBlockSize {
330				println("x > snappyMaxBlockSize", x, snappyMaxBlockSize)
331				return ErrSnappyCorrupt
332			}
333			length = int(x) + 1
334			if length <= 0 {
335				println("length <= 0 ", length)
336
337				return errUnsupportedLiteralLength
338			}
339			//if length > snappyMaxBlockSize-d || uint32(length) > len(src)-s {
340			//	return ErrSnappyCorrupt
341			//}
342
343			blk.literals = append(blk.literals, src[s:s+length]...)
344			//println(length, "litLen")
345			lits += length
346			s += length
347			continue
348
349		case snappyTagCopy1:
350			s += 2
351			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
352				println("uint(s) > uint(len(src)", s, len(src))
353				return ErrSnappyCorrupt
354			}
355			length = 4 + int(src[s-2])>>2&0x7
356			offset = uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])
357
358		case snappyTagCopy2:
359			s += 3
360			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
361				println("uint(s) > uint(len(src)", s, len(src))
362				return ErrSnappyCorrupt
363			}
364			length = 1 + int(src[s-3])>>2
365			offset = uint32(src[s-2]) | uint32(src[s-1])<<8
366
367		case snappyTagCopy4:
368			s += 5
369			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
370				println("uint(s) > uint(len(src)", s, len(src))
371				return ErrSnappyCorrupt
372			}
373			length = 1 + int(src[s-5])>>2
374			offset = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
375		}
376
377		if offset <= 0 || blk.size+lits < int(offset) /*|| length > len(blk)-d */ {
378			println("offset <= 0 || blk.size+lits < int(offset)", offset, blk.size+lits, int(offset), blk.size, lits)
379
380			return ErrSnappyCorrupt
381		}
382
383		// Check if offset is one of the recent offsets.
384		// Adjusts the output offset accordingly.
385		// Gives a tiny bit of compression, typically around 1%.
386		if false {
387			offset = blk.matchOffset(offset, uint32(lits))
388		} else {
389			offset += 3
390		}
391
392		blk.sequences = append(blk.sequences, seq{
393			litLen:   uint32(lits),
394			offset:   offset,
395			matchLen: uint32(length) - zstdMinMatch,
396		})
397		blk.size += length + lits
398		lits = 0
399	}
400	blk.extraLits = lits
401	return nil
402}
403
404func (r *SnappyConverter) readFull(p []byte, allowEOF bool) (ok bool) {
405	if _, r.err = io.ReadFull(r.r, p); r.err != nil {
406		if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
407			r.err = ErrSnappyCorrupt
408		}
409		return false
410	}
411	return true
412}
413
// crcTable is the CRC-32 table for the Castagnoli polynomial used by snappyCRC.
var crcTable = crc32.MakeTable(crc32.Castagnoli)

// snappyCRC implements the checksum specified in section 3 of
// https://github.com/google/snappy/blob/master/framing_format.txt
//
// The CRC-32 (Castagnoli) of b is rotated right by 15 bits and a constant is
// added to it.
func snappyCRC(b []byte) uint32 {
	sum := crc32.Checksum(b, crcTable)
	rotated := sum>>15 | sum<<17
	return rotated + 0xa282ead8
}
422
423// snappyDecodedLen returns the length of the decoded block and the number of bytes
424// that the length header occupied.
425func snappyDecodedLen(src []byte) (blockLen, headerLen int, err error) {
426	v, n := binary.Uvarint(src)
427	if n <= 0 || v > 0xffffffff {
428		return 0, 0, ErrSnappyCorrupt
429	}
430
431	const wordSize = 32 << (^uint(0) >> 32 & 1)
432	if wordSize == 32 && v > 0x7fffffff {
433		return 0, 0, ErrSnappyTooLarge
434	}
435	return int(v), n, nil
436}
437