// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"fmt"
	"math"
	"math/bits"

	"github.com/klauspost/compress/zstd/internal/xxhash"
)
const (
	tableBits      = 15             // Bits used in the table
	tableSize      = 1 << tableBits // Size of the table
	tableMask      = tableSize - 1  // Mask for table indices. Redundant, but can eliminate bounds checks.
	maxMatchLength = 131074
)

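// tableEntry is one hash table slot. offset is the absolute position of the
// entry (e.cur plus the index into e.hist), and val caches the four source
// bytes hashed into the slot, so candidates can be rejected without touching
// history.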
type tableEntry struct {
	val    uint32
	offset int32
}

type fastBase struct {
	// cur is the offset at the start of hist
	cur int32
	// maximum offset. Should be at least 2x block size.
	maxMatchOff int32
	hist        []byte
	crc         *xxhash.Digest
	tmp         [8]byte
	blk         *blockEnc
}

type fastEncoder struct {
	fastBase
	table [tableSize]tableEntry
}

// CRC returns the underlying CRC writer.
func (e *fastBase) CRC() *xxhash.Digest {
	return e.crc
}

// AppendCRC will append the CRC to the destination slice and return it.
func (e *fastBase) AppendCRC(dst []byte) []byte {
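	// The zstd frame checksum is the lowest 4 bytes of the XXH64 digest,
	// stored little-endian. Sum appends the digest big-endian, so the low
	// word sits at the tail of crc and is appended in reverse.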
	crc := e.crc.Sum(e.tmp[:0])
	dst = append(dst, crc[7], crc[6], crc[5], crc[4])
	return dst
}

// WindowSize returns the window size of the encoder, or, if 0 < size <
// maxMatchOff, a power of two (at least 1024) large enough to contain the
// input.
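//
// For example (illustrative): size 10000 gives bits.Len(10000) = 14, so the
// window is 1<<14 = 16384; size 100 is clamped to the 1024 minimum; size 0
// returns the encoder's maxMatchOff unchanged.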
func (e *fastBase) WindowSize(size int) int32 {
	if size > 0 && size < int(e.maxMatchOff) {
		b := int32(1) << uint(bits.Len(uint(size)))
		// Keep minimum window.
		if b < 1024 {
			b = 1024
		}
		return b
	}
	return e.maxMatchOff
}

// Block returns the current block.
func (e *fastBase) Block() *blockEnc {
	return e.blk
}

// Encode mimics functionality in zstd_fast.c
func (e *fastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		inputMargin            = 8
		minNonLiteralBlockSize = 1 + 1 + inputMargin
	)

	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			for i := range e.table[:] {
				e.table[i] = tableEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			if v < minOff {
				v = 0
			} else {
				v = v - e.cur + e.maxMatchOff
			}
			e.table[i].offset = v
		}
		e.cur = e.maxMatchOff
		break
	}
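	// Either way e.cur now restarts at e.maxMatchOff, safely above every
	// offset that remains in the table.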

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Override src
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	// stepSize is the number of bytes to skip on every main loop iteration.
	// It should be >= 2.
	const stepSize = 2

	// TEMPLATE
	const hashLog = tableBits
	// kSearchStrength controls how aggressively the search skips ahead after
	// repeated misses. It is effectively global, but would be nice to tweak.
	const kSearchStrength = 8

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

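	// zstd keeps a short history of recently used offsets; recentOffsets[0]
	// is the most recent and is what a repeat sequence (offset code 1)
	// refers to.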
	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		// t will contain the match offset when we find one.
		// When exiting the search loop, we have already checked 4 bytes.
		var t int32

		// We will not use repeat offsets across blocks:
		// the first 3 matches never use them.
		canRepeat := len(blk.sequences) > 2

		for {
			if debugAsserts && canRepeat && offset1 == 0 {
				panic("offset0 was 0")
			}

			nextHash := hash6(cv, hashLog)
			nextHash2 := hash6(cv>>8, hashLog)
			candidate := e.table[nextHash]
			candidate2 := e.table[nextHash2]
			repIndex := s - offset1 + 2

			e.table[nextHash] = tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.table[nextHash2] = tableEntry{offset: s + e.cur + 1, val: uint32(cv >> 8)}

			if canRepeat && repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>16) {
				// Consider history as well.
				var seq seq
				var length int32
				// length = 4 + e.matchlen(s+6, repIndex+4, src)
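				// Inlined match length: compare 8 bytes at a time; the XOR
				// of the first differing words pinpoints the first differing
				// byte via TrailingZeros64. endI is len(a) rounded down to a
				// multiple of 8, so up to 7 trailing bytes are ignored.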
				{
					a := src[s+6:]
					b := src[repIndex+4:]
					endI := len(a) & (math.MaxInt32 - 7)
					length = int32(endI) + 4
					for i := 0; i < endI; i += 8 {
						if diff := load64(a, i) ^ load64(b, i); diff != 0 {
							length = int32(i+bits.TrailingZeros64(diff)>>3) + 4
							break
						}
					}
				}

				seq.matchLen = uint32(length - zstdMinMatch)

				// We might be able to match backwards.
				// Extend as long as we can.
				start := s + 2
				// We end the search early, so we don't risk 0 literals
				// and have to do special offset treatment.
				startLimit := nextEmit + 1

				sMin := s - e.maxMatchOff
				if sMin < 0 {
					sMin = 0
				}
				for repIndex > sMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch {
					repIndex--
					start--
					seq.matchLen++
				}
				addLiterals(&seq, start)

				// rep 0
				seq.offset = 1
				if debugSequences {
					println("repeat sequence", seq, "next s:", s)
				}
				blk.sequences = append(blk.sequences, seq)
				s += length + 2
				nextEmit = s
				if s >= sLimit {
					if debug {
						println("repeat ended", s, length)
					}
					break encodeLoop
				}
				cv = load6432(src, s)
				continue
			}
			coffset0 := s - (candidate.offset - e.cur)
			coffset1 := s - (candidate2.offset - e.cur) + 1
			if coffset0 < e.maxMatchOff && uint32(cv) == candidate.val {
				// found a regular match
				t = candidate.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				break
			}

			if coffset1 < e.maxMatchOff && uint32(cv>>8) == candidate2.val {
				// found a regular match
				t = candidate2.offset - e.cur
				s++
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic("t<0")
				}
				break
			}
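			// No match: skip ahead. The skip distance grows with the gap
			// since the last emit (s - nextEmit), so long incompressible
			// stretches are scanned progressively faster.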
			s += stepSize + ((s - nextEmit) >> (kSearchStrength - 1))
			if s >= sLimit {
				break encodeLoop
			}
			cv = load6432(src, s)
		}
		// A 4-byte match has been found. We'll later see if more than 4 bytes match.
		offset2 = offset1
		offset1 = s - t

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && canRepeat && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the 4-byte match as long as possible.
		//l := e.matchlen(s+4, t+4, src) + 4
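		// Same inlined 8-bytes-at-a-time comparison as in the repeat case
		// above.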
		var l int32
		{
			a := src[s+4:]
			b := src[t+4:]
			endI := len(a) & (math.MaxInt32 - 7)
			l = int32(endI) + 4
			for i := 0; i < endI; i += 8 {
				if diff := load64(a, i) ^ load64(b, i); diff != 0 {
					l = int32(i+bits.TrailingZeros64(diff)>>3) + 4
					break
				}
			}
		}

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence.
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		// Don't use repeat offsets; codes 1-3 are reserved for repeats, so a
		// literal offset is stored as offset+3.
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}
		cv = load6432(src, s)

		// Check offset 2
		if o2 := s - offset2; canRepeat && load3232(src, o2) == uint32(cv) {
			// We have at least a 4-byte match.
			// No need to check backwards; we come straight from a match.
			//l := 4 + e.matchlen(s+4, o2+4, src)
			var l int32
			{
				a := src[s+4:]
				b := src[o2+4:]
				endI := len(a) & (math.MaxInt32 - 7)
				l = int32(endI) + 4
				for i := 0; i < endI; i += 8 {
					if diff := load64(a, i) ^ load64(b, i); diff != 0 {
						l = int32(i+bits.TrailingZeros64(diff)>>3) + 4
						break
					}
				}
			}

			// Store this, since we have it.
			nextHash := hash6(cv, hashLog)
			e.table[nextHash] = tableEntry{offset: s + e.cur, val: uint32(cv)}
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0
			// Since litLen is 0, offset code 1 refers to the second recent
			// offset (the spec shifts repeat codes when there are no
			// literals), which is exactly the o2 we just matched.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				break encodeLoop
			}
			// Prepare next loop.
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}

// EncodeNoHist will encode a block with no history and no following blocks.
// The most notable differences are that src will not be copied for history,
// and we do not need to check for max match length.
func (e *fastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	const (
		inputMargin            = 8
		minNonLiteralBlockSize = 1 + 1 + inputMargin
	)
	if debug {
		if len(src) > maxBlockSize {
			panic("src too big")
		}
	}

	// Protect against e.cur wraparound.
	if e.cur >= bufferReset {
		for i := range e.table[:] {
			e.table[i] = tableEntry{}
		}
		e.cur = e.maxMatchOff
	}

	s := int32(0)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	sLimit := int32(len(src)) - inputMargin
	// stepSize is the number of bytes to skip on every main loop iteration.
	// It should be >= 2.
	const stepSize = 2

	// TEMPLATE
	const hashLog = tableBits
	// kSearchStrength controls how aggressively the search skips ahead after
	// repeated misses. It is effectively global, but would be nice to tweak.
	const kSearchStrength = 8

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		// t will contain the match offset when we find one.
		// When exiting the search loop, we have already checked 4 bytes.
		var t int32

		// We will not use repeat offsets across blocks:
		// the first 3 matches never use them.

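		// Unlike Encode, matches can never reach into a previous block, so
		// the repeat check below tests len(blk.sequences) directly, and the
		// backwards extension needs no maxMatchLength cap.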
		for {
			nextHash := hash6(cv, hashLog)
			nextHash2 := hash6(cv>>8, hashLog)
			candidate := e.table[nextHash]
			candidate2 := e.table[nextHash2]
			repIndex := s - offset1 + 2

			e.table[nextHash] = tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.table[nextHash2] = tableEntry{offset: s + e.cur + 1, val: uint32(cv >> 8)}

			if len(blk.sequences) > 2 && load3232(src, repIndex) == uint32(cv>>16) {
				// Consider history as well.
				var seq seq
				// length := 4 + e.matchlen(s+6, repIndex+4, src)
				// length := 4 + int32(matchLen(src[s+6:], src[repIndex+4:]))
				var length int32
				{
					a := src[s+6:]
					b := src[repIndex+4:]
					endI := len(a) & (math.MaxInt32 - 7)
					length = int32(endI) + 4
					for i := 0; i < endI; i += 8 {
						if diff := load64(a, i) ^ load64(b, i); diff != 0 {
							length = int32(i+bits.TrailingZeros64(diff)>>3) + 4
							break
						}
					}
				}

				seq.matchLen = uint32(length - zstdMinMatch)

				// We might be able to match backwards.
				// Extend as long as we can.
				start := s + 2
				// We end the search early, so we don't risk 0 literals
				// and have to do special offset treatment.
				startLimit := nextEmit + 1

				sMin := s - e.maxMatchOff
				if sMin < 0 {
					sMin = 0
				}
				for repIndex > sMin && start > startLimit && src[repIndex-1] == src[start-1] {
					repIndex--
					start--
					seq.matchLen++
				}
				addLiterals(&seq, start)

				// rep 0
				seq.offset = 1
				if debugSequences {
					println("repeat sequence", seq, "next s:", s)
				}
				blk.sequences = append(blk.sequences, seq)
				s += length + 2
				nextEmit = s
				if s >= sLimit {
					if debug {
						println("repeat ended", s, length)
					}
					break encodeLoop
				}
				cv = load6432(src, s)
				continue
			}
			coffset0 := s - (candidate.offset - e.cur)
			coffset1 := s - (candidate2.offset - e.cur) + 1
			if coffset0 < e.maxMatchOff && uint32(cv) == candidate.val {
				// found a regular match
				t = candidate.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic(fmt.Sprintf("t (%d) < 0, candidate.offset: %d, e.cur: %d, coffset0: %d, e.maxMatchOff: %d", t, candidate.offset, e.cur, coffset0, e.maxMatchOff))
				}
				break
			}

			if coffset1 < e.maxMatchOff && uint32(cv>>8) == candidate2.val {
				// found a regular match
				t = candidate2.offset - e.cur
				s++
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic("t<0")
				}
				break
			}
			s += stepSize + ((s - nextEmit) >> (kSearchStrength - 1))
			if s >= sLimit {
				break encodeLoop
			}
			cv = load6432(src, s)
		}
		// A 4-byte match has been found. We'll later see if more than 4 bytes match.
		offset2 = offset1
		offset1 = s - t

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && t < 0 {
			panic(fmt.Sprintf("t (%d) < 0", t))
		}
		// Extend the 4-byte match as long as possible.
		//l := e.matchlenNoHist(s+4, t+4, src) + 4
		// l := int32(matchLen(src[s+4:], src[t+4:])) + 4
		var l int32
		{
			a := src[s+4:]
			b := src[t+4:]
			endI := len(a) & (math.MaxInt32 - 7)
			l = int32(endI) + 4
			for i := 0; i < endI; i += 8 {
				if diff := load64(a, i) ^ load64(b, i); diff != 0 {
					l = int32(i+bits.TrailingZeros64(diff)>>3) + 4
					break
				}
			}
		}

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] {
			s--
			t--
			l++
		}

		// Write our sequence.
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		// Don't use repeat offsets; codes 1-3 are reserved for repeats, so a
		// literal offset is stored as offset+3.
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}
		cv = load6432(src, s)

		// Check offset 2
		if o2 := s - offset2; len(blk.sequences) > 2 && load3232(src, o2) == uint32(cv) {
			// We have at least a 4-byte match.
			// No need to check backwards; we come straight from a match.
			//l := 4 + e.matchlenNoHist(s+4, o2+4, src)
			// l := 4 + int32(matchLen(src[s+4:], src[o2+4:]))
			var l int32
			{
				a := src[s+4:]
				b := src[o2+4:]
				endI := len(a) & (math.MaxInt32 - 7)
				l = int32(endI) + 4
				for i := 0; i < endI; i += 8 {
					if diff := load64(a, i) ^ load64(b, i); diff != 0 {
						l = int32(i+bits.TrailingZeros64(diff)>>3) + 4
						break
					}
				}
			}

			// Store this, since we have it.
			nextHash := hash6(cv, hashLog)
			e.table[nextHash] = tableEntry{offset: s + e.cur, val: uint32(cv)}
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0
			// Since litLen is 0, offset code 1 refers to the second recent
			// offset (the spec shifts repeat codes when there are no
			// literals), which is exactly the o2 we just matched.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				break encodeLoop
			}
			// Prepare next loop.
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
	// We do not store history, so we must offset e.cur to avoid false matches for next user.
	if e.cur < bufferReset {
		e.cur += int32(len(src))
	}
}

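// addBlock appends src to the history buffer, growing or sliding the buffer
// as needed, and returns the index in e.hist at which src now starts.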
func (e *fastBase) addBlock(src []byte) int32 {
	if debugAsserts && e.cur > bufferReset {
		panic(fmt.Sprintf("e.cur (%d) > bufferReset (%d)", e.cur, bufferReset))
	}
	// check if we have space already
	if len(e.hist)+len(src) > cap(e.hist) {
		if cap(e.hist) == 0 {
			l := e.maxMatchOff * 2
			// Make it at least 1MB.
			if l < 1<<20 {
				l = 1 << 20
			}
			e.hist = make([]byte, 0, l)
		} else {
			if cap(e.hist) < int(e.maxMatchOff*2) {
				panic("unexpected buffer size")
			}
			// Move down: keep the most recent e.maxMatchOff bytes and advance
			// e.cur by the amount discarded, so absolute offsets stay valid.
			offset := int32(len(e.hist)) - e.maxMatchOff
			copy(e.hist[0:e.maxMatchOff], e.hist[offset:])
			e.cur += offset
			e.hist = e.hist[:e.maxMatchOff]
		}
	}
	s := int32(len(e.hist))
	e.hist = append(e.hist, src...)
	return s
}

// UseBlock will replace the block with the provided one,
// but transfer recent offsets from the previous.
func (e *fastBase) UseBlock(enc *blockEnc) {
	enc.reset(e.blk)
	e.blk = enc
}

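// matchlenNoHist returns the number of matching bytes of src[s:] and src[t:].
// It omits the debug assertions of matchlen, since EncodeNoHist never matches
// into history.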
func (e *fastBase) matchlenNoHist(s, t int32, src []byte) int32 {
	// Extend the match to be as long as possible.
	return int32(matchLen(src[s:], src[t:]))
}

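// matchlen returns the number of matching bytes of src[s:] and src[t:],
// validating s, t and the offset when debugAsserts is enabled.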
func (e *fastBase) matchlen(s, t int32, src []byte) int32 {
	if debugAsserts {
		if s < 0 {
			err := fmt.Sprintf("s (%d) < 0", s)
			panic(err)
		}
		if t < 0 {
			err := fmt.Sprintf("t (%d) < 0", t)
			panic(err)
		}
		if s-t > e.maxMatchOff {
			err := fmt.Sprintf("s (%d) - t (%d) > maxMatchOff (%d)", s, t, e.maxMatchOff)
			panic(err)
		}
		if len(src)-int(s) > maxCompressedBlockSize {
			panic(fmt.Sprintf("len(src)-s (%d) > maxCompressedBlockSize (%d)", len(src)-int(s), maxCompressedBlockSize))
		}
	}

	// Extend the match to be as long as possible.
	return int32(matchLen(src[s:], src[t:]))
}

// Reset resets the encoding state. When singleBlock is set,
// the history buffer is not pre-allocated.
func (e *fastBase) Reset(singleBlock bool) {
	if e.blk == nil {
		e.blk = &blockEnc{}
		e.blk.init()
	} else {
		e.blk.reset(nil)
	}
	e.blk.initNewEncode()
	if e.crc == nil {
		e.crc = xxhash.New()
	} else {
		e.crc.Reset()
	}
	if !singleBlock && cap(e.hist) < int(e.maxMatchOff*2) {
		l := e.maxMatchOff * 2
		// Make it at least 1MB.
		if l < 1<<20 {
			l = 1 << 20
		}
		e.hist = make([]byte, 0, l)
	}
	// We offset current position so everything will be out of reach.
	// If above reset line, history will be purged.
	if e.cur < bufferReset {
		e.cur += e.maxMatchOff + int32(len(e.hist))
	}
	e.hist = e.hist[:0]
}