1// Copyright 2019+ Klaus Post. All rights reserved.
2// License information can be found in the LICENSE file.
3// Based on work by Yann Collet, released under BSD License.
4
5package zstd
6
7import (
8	"fmt"
9	"math/bits"
10)
11
// Table sizing for the "best" encoder. Two separate hash tables are used:
// a long-match table and a smaller short-match table (see bestFastEncoder).
const (
	bestLongTableBits = 20                     // Bits used in the long match table
	bestLongTableSize = 1 << bestLongTableBits // Size of the table

	// Note: Increasing the short table bits or making the hash shorter
	// can actually lead to compression degradation since it will 'steal' more from the
	// long match table and match offsets are quite big.
	// This greatly depends on the type of input.
	bestShortTableBits = 16                      // Bits used in the short match table
	bestShortTableSize = 1 << bestShortTableBits // Size of the table
)
23
// bestFastEncoder uses 2 tables, one for short matches (4 bytes, hashed with
// hash4x64) and one for long matches (8 bytes, hashed with hash8).
// Each table entry contains the previous entry with the same hash,
// effectively making it a "chain" of length 2.
// When we find a long match we choose between the two values and select the longest.
// When we find a short match, after checking the long, we check if we can find a long at n+1
// and that it is longer (lazy matching).
type bestFastEncoder struct {
	fastBase
	table         [bestShortTableSize]prevEntry // short-match table, 2-deep chain per slot
	longTable     [bestLongTableSize]prevEntry  // long-match table, 2-deep chain per slot
	dictTable     []prevEntry                   // snapshot of table seeded from a dictionary (see Reset)
	dictLongTable []prevEntry                   // snapshot of longTable seeded from a dictionary (see Reset)
}
37
// Encode improves compression at the expense of speed.
// For every position it gathers candidates from both chained tables and the
// three most recent repeat offsets, scores them, and lazily looks one and two
// bytes ahead before committing to the best match found.
// Matched sequences and literals are appended to blk; src is added to the
// encoder history.
func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (4).
		inputMargin            = 8 + 4
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			// No history to preserve: clear both tables and restart the offset base.
			for i := range e.table[:] {
				e.table[i] = prevEntry{}
			}
			for i := range e.longTable[:] {
				e.longTable[i] = prevEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			v2 := e.table[i].prev
			if v < minOff {
				// Entry (and therefore its older prev) is out of reach; drop both.
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.table[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			v2 := e.longTable[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.longTable[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		e.cur = e.maxMatchOff
		break
	}

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		// Too small to bother matching; emit everything as literals.
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Override src
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	const kSearchStrength = 12

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])
	offset3 := int32(blk.recentOffsets[2])

	// addLiterals copies src[nextEmit:until] into blk.literals and records
	// the literal length on s. Callers update nextEmit afterwards.
	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	_ = addLiterals

	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		// We allow the encoder to optionally turn off repeat offsets across blocks
		canRepeat := len(blk.sequences) > 2

		if debugAsserts && canRepeat && offset1 == 0 {
			panic("offset0 was 0")
		}

		// match is a candidate: where it starts in src (s), the position it
		// matches against (offset), its length, and which repeat offset it
		// used (rep; -1 for a regular, non-repeat match).
		type match struct {
			offset int32
			s      int32
			length int32
			rep    int32
		}
		// matchAt returns a candidate at offset, with length filled in only
		// when offset is within maxMatchOff and the first 4 bytes match.
		matchAt := func(offset int32, s int32, first uint32, rep int32) match {
			if s-offset >= e.maxMatchOff || load3232(src, offset) != first {
				return match{offset: offset, s: s}
			}
			return match{offset: offset, s: s, length: 4 + e.matchlen(s+4, offset+4, src), rep: rep}
		}

		// bestOf scores each candidate as its length adjusted for how far
		// ahead it starts, penalizing non-repeat matches by roughly one byte
		// per 8 bits needed to encode the offset.
		bestOf := func(a, b match) match {
			aScore := b.s - a.s + a.length
			bScore := a.s - b.s + b.length
			if a.rep < 0 {
				aScore = aScore - int32(bits.Len32(uint32(a.offset)))/8
			}
			if b.rep < 0 {
				bScore = bScore - int32(bits.Len32(uint32(b.offset)))/8
			}
			if aScore >= bScore {
				return a
			}
			return b
		}
		const goodEnough = 100

		nextHashL := hash8(cv, bestLongTableBits)
		nextHashS := hash4x64(cv, bestShortTableBits)
		candidateL := e.longTable[nextHashL]
		candidateS := e.table[nextHashS]

		// Try both chain entries of each table...
		best := bestOf(matchAt(candidateL.offset-e.cur, s, uint32(cv), -1), matchAt(candidateL.prev-e.cur, s, uint32(cv), -1))
		best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1))
		best = bestOf(best, matchAt(candidateS.prev-e.cur, s, uint32(cv), -1))
		if canRepeat && best.length < goodEnough {
			// ...and the three repeat offsets at s+1 and s+3.
			best = bestOf(best, matchAt(s-offset1+1, s+1, uint32(cv>>8), 1))
			best = bestOf(best, matchAt(s-offset2+1, s+1, uint32(cv>>8), 2))
			best = bestOf(best, matchAt(s-offset3+1, s+1, uint32(cv>>8), 3))
			best = bestOf(best, matchAt(s-offset1+3, s+3, uint32(cv>>24), 1))
			best = bestOf(best, matchAt(s-offset2+3, s+3, uint32(cv>>24), 2))
			best = bestOf(best, matchAt(s-offset3+3, s+3, uint32(cv>>24), 3))
		}
		// Load next and check...
		e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: candidateL.offset}
		e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: candidateS.offset}

		// Look far ahead, unless we have a really long match already...
		if best.length < goodEnough {
			// No match found, move forward on input, no need to check forward...
			if best.length < 4 {
				// Skip faster the longer we go without finding a match.
				s += 1 + (s-nextEmit)>>(kSearchStrength-1)
				if s >= sLimit {
					break encodeLoop
				}
				cv = load6432(src, s)
				continue
			}

			// Lazy matching: check candidates at s+1 and s+2 as well.
			s++
			candidateS = e.table[hash4x64(cv>>8, bestShortTableBits)]
			cv = load6432(src, s)
			cv2 := load6432(src, s+1)
			candidateL = e.longTable[hash8(cv, bestLongTableBits)]
			candidateL2 := e.longTable[hash8(cv2, bestLongTableBits)]

			best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL.offset-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL.prev-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL2.offset-e.cur, s+1, uint32(cv2), -1))
			best = bestOf(best, matchAt(candidateL2.prev-e.cur, s+1, uint32(cv2), -1))
		}

		// We have a match, we can store the forward value
		if best.rep > 0 {
			// Repeat-offset match.
			s = best.s
			var seq seq
			seq.matchLen = uint32(best.length - zstdMinMatch)

			// We might be able to match backwards.
			// Extend as long as we can.
			start := best.s
			// We end the search early, so we don't risk 0 literals
			// and have to do special offset treatment.
			startLimit := nextEmit + 1

			tMin := s - e.maxMatchOff
			if tMin < 0 {
				tMin = 0
			}
			repIndex := best.offset
			for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
				repIndex--
				start--
				seq.matchLen++
			}
			addLiterals(&seq, start)

			// rep 0
			seq.offset = uint32(best.rep)
			if debugSequences {
				println("repeat sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Index match start+1 (long) -> s - 1
			index0 := s
			s = best.s + best.length

			nextEmit = s
			if s >= sLimit {
				if debug {
					println("repeat ended", s, best.length)

				}
				break encodeLoop
			}
			// Index skipped...
			off := index0 + e.cur
			for index0 < s-1 {
				cv0 := load6432(src, index0)
				h0 := hash8(cv0, bestLongTableBits)
				h1 := hash4x64(cv0, bestShortTableBits)
				e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
				e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
				off++
				index0++
			}
			// Promote the used repeat offset to the front of the list.
			switch best.rep {
			case 2:
				offset1, offset2 = offset2, offset1
			case 3:
				offset1, offset2, offset3 = offset3, offset1, offset2
			}
			cv = load6432(src, s)
			continue
		}

		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if more than 4 bytes.
		s = best.s
		t := best.offset
		offset1, offset2, offset3 = s-t, offset1, offset2

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && canRepeat && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the n-byte match as long as possible.
		l := best.length

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		// Offsets 1-3 encode repeats, so literal offsets are stored +3.
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

		// Index match start+1 (long) -> s - 1
		index0 := s - l + 1
		// every entry
		for index0 < s-1 {
			cv0 := load6432(src, index0)
			h0 := hash8(cv0, bestLongTableBits)
			h1 := hash4x64(cv0, bestShortTableBits)
			off := index0 + e.cur
			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
			e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
			index0++
		}

		cv = load6432(src, s)
		if !canRepeat {
			continue
		}

		// Check offset 2
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hash4x64(cv, bestShortTableBits)
			nextHashL := hash8(cv, bestLongTableBits)

			// We have at least 4 byte match.
			// No need to check backwards. We come straight from a match
			l := 4 + e.matchlen(s+4, o2+4, src)

			e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: e.longTable[nextHashL].offset}
			e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: e.table[nextHashS].offset}
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litlen is always 0, this is offset 1.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		// Flush trailing literals after the last match.
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	blk.recentOffsets[2] = uint32(offset3)
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}
403
// EncodeNoHist will encode a block with no history and no following blocks.
// Most notable difference is that src will not be copied for history and
// we do not need to check for max match length.
// Currently this simply delegates to Encode; no separate no-history fast
// path is implemented for the best encoder.
func (e *bestFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	e.Encode(blk, src)
}
410
// Reset will reset the encoder and, if d is not nil, seed the match tables
// from the dictionary content. The seeded tables are cached in dictTable /
// dictLongTable so repeated resets with the same dictionary only copy.
func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
	e.resetBase(d, singleBlock)
	if d == nil {
		return
	}
	// Init or copy dict (short) table
	if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
		if len(e.dictTable) != len(e.table) {
			e.dictTable = make([]prevEntry, len(e.table))
		}
		// Dictionary content occupies offsets [maxMatchOff, end) so entries
		// stay addressable after e.cur is reset to maxMatchOff below.
		end := int32(len(d.content)) - 8 + e.maxMatchOff
		for i := e.maxMatchOff; i < end; i += 4 {
			const hashLog = bestShortTableBits

			// One 8-byte load yields hashes for 4 consecutive positions.
			cv := load6432(d.content, i-e.maxMatchOff)
			nextHash := hash4x64(cv, hashLog)      // 0 -> 4
			nextHash1 := hash4x64(cv>>8, hashLog)  // 1 -> 5
			nextHash2 := hash4x64(cv>>16, hashLog) // 2 -> 6
			nextHash3 := hash4x64(cv>>24, hashLog) // 3 -> 7
			e.dictTable[nextHash] = prevEntry{
				prev:   e.dictTable[nextHash].offset,
				offset: i,
			}
			e.dictTable[nextHash1] = prevEntry{
				prev:   e.dictTable[nextHash1].offset,
				offset: i + 1,
			}
			e.dictTable[nextHash2] = prevEntry{
				prev:   e.dictTable[nextHash2].offset,
				offset: i + 2,
			}
			e.dictTable[nextHash3] = prevEntry{
				prev:   e.dictTable[nextHash3].offset,
				offset: i + 3,
			}
		}
		e.lastDictID = d.id
	}

	// Init or copy dict long table
	if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
		if len(e.dictLongTable) != len(e.longTable) {
			e.dictLongTable = make([]prevEntry, len(e.longTable))
		}
		if len(d.content) >= 8 {
			cv := load6432(d.content, 0)
			h := hash8(cv, bestLongTableBits)
			e.dictLongTable[h] = prevEntry{
				offset: e.maxMatchOff,
				prev:   e.dictLongTable[h].offset,
			}

			end := int32(len(d.content)) - 8 + e.maxMatchOff
			off := 8 // First to read
			for i := e.maxMatchOff + 1; i < end; i++ {
				// Slide the 8-byte window one byte at a time.
				cv = cv>>8 | (uint64(d.content[off]) << 56)
				h := hash8(cv, bestLongTableBits)
				e.dictLongTable[h] = prevEntry{
					offset: i,
					prev:   e.dictLongTable[h].offset,
				}
				off++
			}
		}
		e.lastDictID = d.id
	}
	// Reset table to initial state
	copy(e.longTable[:], e.dictLongTable)

	e.cur = e.maxMatchOff
	// Reset table to initial state
	copy(e.table[:], e.dictTable)
}
485