// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import "fmt"

const (
	betterLongTableBits = 19                       // Bits used in the long match table
	betterLongTableSize = 1 << betterLongTableBits // Size of the table

	// Note: Increasing the short table bits or making the hash shorter
	// can actually lead to compression degradation since it will 'steal' more from the
	// long match table and match offsets are quite big.
	// This greatly depends on the type of input.
	betterShortTableBits = 13                        // Bits used in the short match table
	betterShortTableSize = 1 << betterShortTableBits // Size of the table
)

type prevEntry struct {
	offset int32
	prev   int32
}
// betterFastEncoder uses 2 tables, one for short matches (5 bytes) and one for long matches (8 bytes).
// The long match table contains the previous entry with the same hash,
// effectively making it a "chain" of length 2.
// When we find a long match we compare both chain entries and pick the longer one.
// When we find a short match, after checking the long table, we also check for a long match
// at s+1 and prefer it if it is longer (lazy matching).
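//
// Roughly, each long-table probe sees both chain entries before the slot is
// overwritten (a sketch of the idea, not the exact control flow below):
//
//	cand := e.longTable[h]
//	// candidates: cand.offset (newest) and cand.prev (older entry, same hash)
//	e.longTable[h] = prevEntry{offset: s + e.cur, prev: cand.offset}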
type betterFastEncoder struct {
	fastBase
	table         [betterShortTableSize]tableEntry
	longTable     [betterLongTableSize]prevEntry
	dictTable     []tableEntry
	dictLongTable []prevEntry
}

// Encode improves compression compared to the fast encoder, at some cost in speed.
func (e *betterFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (2)
		inputMargin            = 8 + 2
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			for i := range e.table[:] {
				e.table[i] = tableEntry{}
			}
			for i := range e.longTable[:] {
				e.longTable[i] = prevEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			if v < minOff {
				v = 0
			} else {
				v = v - e.cur + e.maxMatchOff
			}
			e.table[i].offset = v
		}
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			v2 := e.longTable[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.longTable[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		e.cur = e.maxMatchOff
		break
	}

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Override src: from here on we operate on the full history buffer.
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	// stepSize is the number of bytes to skip on every main loop iteration.
	// It should be >= 1.
	const stepSize = 1

	const kSearchStrength = 9

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])
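	// blk.recentOffsets carries the repeat offsets across blocks: offset1 is
	// the most recent match distance, offset2 the one before it.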

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		var t int32
		// We allow the encoder to optionally turn off repeat offsets across blocks
		canRepeat := len(blk.sequences) > 2
		var matched int32

		for {
			if debugAsserts && canRepeat && offset1 == 0 {
				panic("offset0 was 0")
			}

			nextHashS := hash5(cv, betterShortTableBits)
			nextHashL := hash8(cv, betterLongTableBits)
			candidateL := e.longTable[nextHashL]
			candidateS := e.table[nextHashS]

			const repOff = 1
			repIndex := s - offset1 + repOff
			off := s + e.cur
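			// Update both tables before matching: the long table pushes the
			// displaced entry onto the length-2 chain via prev, the short
			// table keeps the low 32 bits of cv for a cheap equality check.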
			e.longTable[nextHashL] = prevEntry{offset: off, prev: candidateL.offset}
			e.table[nextHashS] = tableEntry{offset: off, val: uint32(cv)}

			if canRepeat {
				if repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>(repOff*8)) {
					// Consider history as well.
					var seq seq
					length := 4 + e.matchlen(s+4+repOff, repIndex+4, src)

					seq.matchLen = uint32(length - zstdMinMatch)

					// We might be able to match backwards.
					// Extend as long as we can.
					start := s + repOff
					// We end the search early, so we don't risk 0 literals,
					// which would require special offset treatment.
					startLimit := nextEmit + 1

					tMin := s - e.maxMatchOff
					if tMin < 0 {
						tMin = 0
					}
					for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
						repIndex--
						start--
						seq.matchLen++
					}
					addLiterals(&seq, start)

					// rep 0
					seq.offset = 1
					if debugSequences {
						println("repeat sequence", seq, "next s:", s)
					}
					blk.sequences = append(blk.sequences, seq)

					// Index match start+1 (long) -> s - 1
					index0 := s + repOff
					s += length + repOff

					nextEmit = s
					if s >= sLimit {
						if debug {
							println("repeat ended", s, length)
						}
						break encodeLoop
					}
					// Index the positions we skipped: every other byte gets a
					// long-table entry (index0) and a short-table entry (index0+1).
					for index0 < s-1 {
						cv0 := load6432(src, index0)
						cv1 := cv0 >> 8
						h0 := hash8(cv0, betterLongTableBits)
						off := index0 + e.cur
						e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
						e.table[hash5(cv1, betterShortTableBits)] = tableEntry{offset: off + 1, val: uint32(cv1)}
						index0 += 2
					}
					cv = load6432(src, s)
					continue
				}
				const repOff2 = 1

				// We deviate from the reference encoder and also check offset 2.
				// Still slower and not much better, so disabled.
				// repIndex = s - offset2 + repOff2
				if false && repIndex >= 0 && load6432(src, repIndex) == load6432(src, s+repOff) {
					// Consider history as well.
					var seq seq
					length := 8 + e.matchlen(s+8+repOff2, repIndex+8, src)

					seq.matchLen = uint32(length - zstdMinMatch)

					// We might be able to match backwards.
					// Extend as long as we can.
					start := s + repOff2
					// We end the search early, so we don't risk 0 literals,
					// which would require special offset treatment.
					startLimit := nextEmit + 1

					tMin := s - e.maxMatchOff
					if tMin < 0 {
						tMin = 0
					}
					for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
						repIndex--
						start--
						seq.matchLen++
					}
					addLiterals(&seq, start)

					// rep 2
					seq.offset = 2
					if debugSequences {
						println("repeat sequence 2", seq, "next s:", s)
					}
					blk.sequences = append(blk.sequences, seq)

					index0 := s + repOff2
					s += length + repOff2
					nextEmit = s
					if s >= sLimit {
						if debug {
							println("repeat ended", s, length)
						}
						break encodeLoop
					}

					// Index the positions we skipped: every other byte gets a
					// long-table entry (index0) and a short-table entry (index0+1).
					for index0 < s-1 {
						cv0 := load6432(src, index0)
						cv1 := cv0 >> 8
						h0 := hash8(cv0, betterLongTableBits)
						off := index0 + e.cur
						e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
						e.table[hash5(cv1, betterShortTableBits)] = tableEntry{offset: off + 1, val: uint32(cv1)}
						index0 += 2
					}
					cv = load6432(src, s)
					// Swap offsets
					offset1, offset2 = offset2, offset1
					continue
				}
			}
			// Find the offsets of our two matches.
			coffsetL := candidateL.offset - e.cur
			coffsetLP := candidateL.prev - e.cur
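			// Offsets in the tables are stored biased by e.cur; subtracting
			// e.cur maps them back to positions in the current src/history.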

			// Check if we have a long match.
			if s-coffsetL < e.maxMatchOff && cv == load6432(src, coffsetL) {
				// Found a long match, at least 8 bytes.
				matched = e.matchlen(s+8, coffsetL+8, src) + 8
				t = coffsetL
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugMatches {
					println("long match")
				}

				if s-coffsetLP < e.maxMatchOff && cv == load6432(src, coffsetLP) {
					// Found a long match, at least 8 bytes.
					prevMatch := e.matchlen(s+8, coffsetLP+8, src) + 8
					if prevMatch > matched {
						matched = prevMatch
						t = coffsetLP
					}
					if debugAsserts && s <= t {
						panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
					}
					if debugAsserts && s-t > e.maxMatchOff {
						panic("s - t > e.maxMatchOff")
					}
					if debugMatches {
						println("long match")
					}
				}
				break
			}

			// Check if we have a long match on prev.
			if s-coffsetLP < e.maxMatchOff && cv == load6432(src, coffsetLP) {
				// Found a long match, at least 8 bytes.
				matched = e.matchlen(s+8, coffsetLP+8, src) + 8
				t = coffsetLP
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugMatches {
					println("long match")
				}
				break
			}

			coffsetS := candidateS.offset - e.cur

			// Check if we have a short match.
			if s-coffsetS < e.maxMatchOff && uint32(cv) == candidateS.val {
				// found a regular match
				matched = e.matchlen(s+4, coffsetS+4, src) + 4

				// See if we can find a long match at s+1
				const checkAt = 1
				cv := load6432(src, s+checkAt)
				nextHashL = hash8(cv, betterLongTableBits)
				candidateL = e.longTable[nextHashL]
				coffsetL = candidateL.offset - e.cur

				// We can store it, since we have at least a 4 byte match.
				e.longTable[nextHashL] = prevEntry{offset: s + checkAt + e.cur, prev: candidateL.offset}
				if s-coffsetL < e.maxMatchOff && cv == load6432(src, coffsetL) {
					// Found a long match, at least 8 bytes.
					matchedNext := e.matchlen(s+8+checkAt, coffsetL+8, src) + 8
					if matchedNext > matched {
						t = coffsetL
						s += checkAt
						matched = matchedNext
						if debugMatches {
							println("long match (after short)")
						}
						break
					}
				}

				// Check prev long...
				coffsetL = candidateL.prev - e.cur
				if s-coffsetL < e.maxMatchOff && cv == load6432(src, coffsetL) {
					// Found a long match, at least 8 bytes.
					matchedNext := e.matchlen(s+8+checkAt, coffsetL+8, src) + 8
					if matchedNext > matched {
						t = coffsetL
						s += checkAt
						matched = matchedNext
						if debugMatches {
							println("prev long match (after short)")
						}
						break
					}
				}
				t = coffsetS
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic("t<0")
				}
				if debugMatches {
					println("short match")
				}
				break
			}

			// No match found, move forward in input.
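			// Skip faster through non-matching data: the step grows by one
			// for every 2^(kSearchStrength-1) = 256 bytes since the last
			// emit, e.g. at s-nextEmit == 512 the step is 1 + (512>>8) = 3.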
			s += stepSize + ((s - nextEmit) >> (kSearchStrength - 1))
			if s >= sLimit {
				break encodeLoop
			}
			cv = load6432(src, s)
		}

		// A match of at least 4 bytes has been found. Update recent offsets.
		// We may extend it further below.
		offset2 = offset1
		offset1 = s - t

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && canRepeat && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the match as long as possible.
		l := matched

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
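		// Offset values 1-3 are reserved for repeat codes, so real offsets
		// are stored with a +3 bias.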
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

		// Index match start+1 (long) -> s - 1
		index0 := s - l + 1
		for index0 < s-1 {
			cv0 := load6432(src, index0)
			cv1 := cv0 >> 8
			h0 := hash8(cv0, betterLongTableBits)
			off := index0 + e.cur
			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
			e.table[hash5(cv1, betterShortTableBits)] = tableEntry{offset: off + 1, val: uint32(cv1)}
			index0 += 2
		}

		cv = load6432(src, s)
		if !canRepeat {
			continue
		}
		// Check offset 2: try to chain repeat matches directly after this one.
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hash5(cv, betterShortTableBits)
			nextHashL := hash8(cv, betterLongTableBits)

			// We have at least a 4 byte match.
			// No need to check backwards. We come straight from a match.
			l := 4 + e.matchlen(s+4, o2+4, src)

			e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: e.longTable[nextHashL].offset}
			e.table[nextHashS] = tableEntry{offset: s + e.cur, val: uint32(cv)}
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litLen is always 0, offset code 1 refers to offset2 here
			// (the repeat codes shift down when litLen == 0).
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}

// EncodeNoHist will encode a block with no history and no following blocks.
// The most notable difference is that src will not be copied for history and
// we do not need to check for max match length.
func (e *betterFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	e.Encode(blk, src)
}

// Reset will reset and set a dictionary if not nil.
func (e *betterFastEncoder) Reset(d *dict, singleBlock bool) {
	e.resetBase(d, singleBlock)
	if d == nil {
		return
	}
	// Init or copy dict table
	if len(e.dictTable) != len(e.table) || d.id != e.lastDictID {
		if len(e.dictTable) != len(e.table) {
			e.dictTable = make([]tableEntry, len(e.table))
		}
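		// Dictionary positions are stored with a +e.maxMatchOff bias,
		// matching e.cur after reset, so a zero offset can still mean
		// an empty table slot.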
		end := int32(len(d.content)) - 8 + e.maxMatchOff
		for i := e.maxMatchOff; i < end; i += 4 {
			const hashLog = betterShortTableBits

			cv := load6432(d.content, i-e.maxMatchOff)
			nextHash := hash5(cv, hashLog)      // 0 -> 4
			nextHash1 := hash5(cv>>8, hashLog)  // 1 -> 5
			nextHash2 := hash5(cv>>16, hashLog) // 2 -> 6
			nextHash3 := hash5(cv>>24, hashLog) // 3 -> 7
			e.dictTable[nextHash] = tableEntry{
				val:    uint32(cv),
				offset: i,
			}
			e.dictTable[nextHash1] = tableEntry{
				val:    uint32(cv >> 8),
				offset: i + 1,
			}
			e.dictTable[nextHash2] = tableEntry{
				val:    uint32(cv >> 16),
				offset: i + 2,
			}
			e.dictTable[nextHash3] = tableEntry{
				val:    uint32(cv >> 24),
				offset: i + 3,
			}
		}
		e.lastDictID = d.id
	}

	// Init or copy dict long table
	if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
		if len(e.dictLongTable) != len(e.longTable) {
			e.dictLongTable = make([]prevEntry, len(e.longTable))
		}
		if len(d.content) >= 8 {
			cv := load6432(d.content, 0)
			h := hash8(cv, betterLongTableBits)
			e.dictLongTable[h] = prevEntry{
				offset: e.maxMatchOff,
				prev:   e.dictLongTable[h].offset,
			}

			end := int32(len(d.content)) - 8 + e.maxMatchOff
			off := 8 // Next byte to read
			for i := e.maxMatchOff + 1; i < end; i++ {
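				// Roll the 8-byte window one byte forward: drop the lowest
				// byte and shift in the next dictionary byte at the top.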
				cv = cv>>8 | (uint64(d.content[off]) << 56)
				h := hash8(cv, betterLongTableBits)
				e.dictLongTable[h] = prevEntry{
					offset: i,
					prev:   e.dictLongTable[h].offset,
				}
				off++
			}
		}
		e.lastDictID = d.id
	}
	// Reset the long table to its initial (dictionary) state.
	copy(e.longTable[:], e.dictLongTable)

	e.cur = e.maxMatchOff
	// Reset the short table to its initial (dictionary) state.
	copy(e.table[:], e.dictTable)
}