// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import "fmt"

const (
	betterLongTableBits = 19                       // Bits used in the long match table
	betterLongTableSize = 1 << betterLongTableBits // Size of the table

	// Note: Increasing the short table bits or making the hash shorter
	// can actually degrade compression, since it will 'steal' more entries
	// from the long match table, and match offsets are quite large.
	// This greatly depends on the type of input.
	betterShortTableBits = 13                        // Bits used in the short match table
	betterShortTableSize = 1 << betterShortTableBits // Size of the table
)
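
// With these settings the long table has 1<<19 (524288) slots and the short
// table 1<<13 (8192) slots; at 8 bytes per entry the long table alone uses
// 4 MiB per encoder.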

type prevEntry struct {
	offset int32
	prev   int32
}

// betterFastEncoder uses 2 tables, one for short matches (5 bytes) and one for long matches.
// The long match table contains the previous entry with the same hash,
// effectively making it a "chain" of length 2.
// When we find a long match we choose between the two values and select the longest.
// When we find a short match, after checking the long table, we check if we can
// find a longer long match at s+1 (lazy matching).
type betterFastEncoder struct {
	fastBase
	table     [betterShortTableSize]tableEntry
	longTable [betterLongTableSize]prevEntry
}
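
// For illustration only: a probe of the long table examines both links of
// the 2-entry chain and keeps the longer candidate. A simplified sketch,
// using a hypothetical matchAt helper in place of the inline offset checks
// and e.matchlen calls in Encode below:
//
//	cand := e.longTable[hash8(cv, betterLongTableBits)]
//	best := matchAt(cand.offset)
//	if m := matchAt(cand.prev); m > best {
//		best = m
//	}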

// Encode improves compression...
func (e *betterFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (2)
		inputMargin            = 8 + 2
		minNonLiteralBlockSize = 16
	)
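
	// The 10-byte margin guarantees that the 8-byte loads below, including
	// the lazy check at s+1, never read past the end of src while s < sLimit.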

	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			for i := range e.table[:] {
				e.table[i] = tableEntry{}
			}
			for i := range e.longTable[:] {
				e.longTable[i] = prevEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			if v < minOff {
				v = 0
			} else {
				v = v - e.cur + e.maxMatchOff
			}
			e.table[i].offset = v
		}
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			v2 := e.longTable[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.longTable[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		e.cur = e.maxMatchOff
		break
	}
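
	// After rebasing, an entry that held e.cur+k now holds e.maxMatchOff+k,
	// and since e.cur is reset to e.maxMatchOff, relative offsets computed
	// as (entry - e.cur) are preserved.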

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Override src
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	// stepSize is the number of bytes to skip on every main loop iteration.
	// It should be >= 1.
	const stepSize = 1

	const kSearchStrength = 9

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])

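	// addLiterals flushes pending literals into the block and records the
	// literal length on the sequence. For example, with nextEmit = 10 and
	// until = 14, src[10:14] is appended and s.litLen becomes 4.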
	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		var t int32
		// We allow the encoder to optionally turn off repeat offsets across blocks
		canRepeat := len(blk.sequences) > 2
		var matched int32

		for {
			if debugAsserts && canRepeat && offset1 == 0 {
				panic("offset0 was 0")
			}

			nextHashS := hash5(cv, betterShortTableBits)
			nextHashL := hash8(cv, betterLongTableBits)
			candidateL := e.longTable[nextHashL]
			candidateS := e.table[nextHashS]

			const repOff = 1
			repIndex := s - offset1 + repOff
			off := s + e.cur
			e.longTable[nextHashL] = prevEntry{offset: off, prev: candidateL.offset}
			e.table[nextHashS] = tableEntry{offset: off, val: uint32(cv)}

			if canRepeat {
				if repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>(repOff*8)) {
					// Consider history as well.
					var seq seq
					length := 4 + e.matchlen(s+4+repOff, repIndex+4, src)

					seq.matchLen = uint32(length - zstdMinMatch)

					// We might be able to match backwards.
					// Extend as long as we can.
					start := s + repOff
					// We end the search early, so we don't risk 0 literals
					// and have to do special offset treatment.
					startLimit := nextEmit + 1

					tMin := s - e.maxMatchOff
					if tMin < 0 {
						tMin = 0
					}
					for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
						repIndex--
						start--
						seq.matchLen++
					}
					addLiterals(&seq, start)

					// rep 0
					seq.offset = 1
					if debugSequences {
						println("repeat sequence", seq, "next s:", s)
					}
					blk.sequences = append(blk.sequences, seq)

					// Index match start+1 (long) -> s - 1
					index0 := s + repOff
					s += length + repOff

					nextEmit = s
					if s >= sLimit {
						if debug {
							println("repeat ended", s, length)
						}
						break encodeLoop
					}
					// Index the positions the match skipped: a long-table
					// entry at index0 and a short-table entry at index0+1,
					// stepping two bytes at a time.
					for index0 < s-1 {
						cv0 := load6432(src, index0)
						cv1 := cv0 >> 8
						h0 := hash8(cv0, betterLongTableBits)
						off := index0 + e.cur
						e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
						e.table[hash5(cv1, betterShortTableBits)] = tableEntry{offset: off + 1, val: uint32(cv1)}
						index0 += 2
					}
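					// Example: with index0 = 100 and s = 105, positions 100
					// and 102 get long-table entries, 101 and 103 get
					// short-table entries, and 104 is not indexed here.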
					cv = load6432(src, s)
					continue
				}
				const repOff2 = 1

				// We deviate from the reference encoder and also check offset 2.
				// Still slower and not much better, so disabled.
				// repIndex = s - offset2 + repOff2
				if false && repIndex >= 0 && load6432(src, repIndex) == load6432(src, s+repOff) {
					// Consider history as well.
					var seq seq
					length := 8 + e.matchlen(s+8+repOff2, repIndex+8, src)

					seq.matchLen = uint32(length - zstdMinMatch)

					// We might be able to match backwards.
					// Extend as long as we can.
					start := s + repOff2
					// We end the search early, so we don't risk 0 literals
					// and have to do special offset treatment.
					startLimit := nextEmit + 1

					tMin := s - e.maxMatchOff
					if tMin < 0 {
						tMin = 0
					}
					for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
						repIndex--
						start--
						seq.matchLen++
					}
					addLiterals(&seq, start)

					// rep 2
					seq.offset = 2
					if debugSequences {
						println("repeat sequence 2", seq, "next s:", s)
					}
					blk.sequences = append(blk.sequences, seq)

					index0 := s + repOff2
					s += length + repOff2
					nextEmit = s
					if s >= sLimit {
						if debug {
							println("repeat ended", s, length)
						}
						break encodeLoop
					}

					// Index the skipped positions, as above.
					for index0 < s-1 {
						cv0 := load6432(src, index0)
						cv1 := cv0 >> 8
						h0 := hash8(cv0, betterLongTableBits)
						off := index0 + e.cur
						e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
						e.table[hash5(cv1, betterShortTableBits)] = tableEntry{offset: off + 1, val: uint32(cv1)}
						index0 += 2
					}
					cv = load6432(src, s)
					// Swap offsets
					offset1, offset2 = offset2, offset1
					continue
				}
			}
			// Find the offsets of our two matches.
			coffsetL := candidateL.offset - e.cur
			coffsetLP := candidateL.prev - e.cur

			// Check if we have a long match.
			if s-coffsetL < e.maxMatchOff && cv == load6432(src, coffsetL) {
				// Found a long match, at least 8 bytes.
				matched = e.matchlen(s+8, coffsetL+8, src) + 8
				t = coffsetL
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugMatches {
					println("long match")
				}

				if s-coffsetLP < e.maxMatchOff && cv == load6432(src, coffsetLP) {
					// Found a long match, at least 8 bytes.
					prevMatch := e.matchlen(s+8, coffsetLP+8, src) + 8
					if prevMatch > matched {
						matched = prevMatch
						t = coffsetLP
					}
					if debugAsserts && s <= t {
						panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
					}
					if debugAsserts && s-t > e.maxMatchOff {
						panic("s - t > e.maxMatchOff")
					}
					if debugMatches {
						println("long match")
					}
				}
				break
			}

			// Check if we have a long match on prev.
			if s-coffsetLP < e.maxMatchOff && cv == load6432(src, coffsetLP) {
				// Found a long match, at least 8 bytes.
				matched = e.matchlen(s+8, coffsetLP+8, src) + 8
				t = coffsetLP
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugMatches {
					println("long match")
				}
				break
			}

			coffsetS := candidateS.offset - e.cur

			// Check if we have a short match.
			if s-coffsetS < e.maxMatchOff && uint32(cv) == candidateS.val {
				// found a regular match
				matched = e.matchlen(s+4, coffsetS+4, src) + 4

				// See if we can find a long match at s+1
				const checkAt = 1
				cv := load6432(src, s+checkAt)
				nextHashL = hash8(cv, betterLongTableBits)
				candidateL = e.longTable[nextHashL]
				coffsetL = candidateL.offset - e.cur

				// We can store it, since we have at least a 4 byte match.
				e.longTable[nextHashL] = prevEntry{offset: s + checkAt + e.cur, prev: candidateL.offset}
				if s-coffsetL < e.maxMatchOff && cv == load6432(src, coffsetL) {
					// Found a long match, at least 8 bytes.
					matchedNext := e.matchlen(s+8+checkAt, coffsetL+8, src) + 8
					if matchedNext > matched {
						t = coffsetL
						s += checkAt
						matched = matchedNext
						if debugMatches {
							println("long match (after short)")
						}
						break
					}
				}

				// Check prev long...
				coffsetL = candidateL.prev - e.cur
				if s-coffsetL < e.maxMatchOff && cv == load6432(src, coffsetL) {
					// Found a long match, at least 8 bytes.
					matchedNext := e.matchlen(s+8+checkAt, coffsetL+8, src) + 8
					if matchedNext > matched {
						t = coffsetL
						s += checkAt
						matched = matchedNext
						if debugMatches {
							println("prev long match (after short)")
						}
						break
					}
				}
				t = coffsetS
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic("t<0")
				}
				if debugMatches {
					println("short match")
				}
				break
			}

			// No match found, move forward in input.
			// The step grows with the distance since the last emitted literal,
			// so runs of incompressible data are skipped faster.
			s += stepSize + ((s - nextEmit) >> (kSearchStrength - 1))
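			// Example: with kSearchStrength = 9, after 256 consecutive
			// unmatched bytes the step becomes 1 + (256 >> 8) = 2 bytes.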
			if s >= sLimit {
				break encodeLoop
			}
			cv = load6432(src, s)
		}

		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if more than 4 bytes.
		offset2 = offset1
		offset1 = s - t

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && canRepeat && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the n-byte match as long as possible.
		l := matched

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		seq.offset = uint32(s-t) + 3
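		// Offset codes 1-3 are reserved for repeat offsets, so a new offset
		// is stored with a +3 bias.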
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

		// Index match start+1 (long) -> s - 1
		index0 := s - l + 1
		for index0 < s-1 {
			cv0 := load6432(src, index0)
			cv1 := cv0 >> 8
			h0 := hash8(cv0, betterLongTableBits)
			off := index0 + e.cur
			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
			e.table[hash5(cv1, betterShortTableBits)] = tableEntry{offset: off + 1, val: uint32(cv1)}
			index0 += 2
		}

		cv = load6432(src, s)
		if !canRepeat {
			continue
		}

		// Check offset 2: see if the data at s immediately repeats with the
		// previous offset.
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hash5(cv, betterShortTableBits)
			nextHashL := hash8(cv, betterLongTableBits)

			// We have at least a 4-byte match.
			// No need to check backwards; we come straight from a match.
			l := 4 + e.matchlen(s+4, o2+4, src)

			e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: e.longTable[nextHashL].offset}
			e.table[nextHashS] = tableEntry{offset: s + e.cur, val: uint32(cv)}
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litlen is always 0, this is offset 1.
			seq.offset = 1
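			// With zero literals the repcodes shift, so code 1 here refers to
			// the second-most-recent offset; the swap below keeps
			// offset1/offset2 consistent with that.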
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}

// EncodeNoHist will encode a block with no history and no following blocks.
// The most notable difference is that src will not be copied for history and
// we do not need to check for max match length.
func (e *betterFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	e.Encode(blk, src)
}