// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import "fmt"

const (
	dFastLongTableBits = 17                      // Bits used in the long match table
	dFastLongTableSize = 1 << dFastLongTableBits // Size of the table
	dFastLongTableMask = dFastLongTableSize - 1  // Mask for table indices. Redundant, but can eliminate bounds checks.

	dFastShortTableBits = tableBits                // Bits used in the short match table
	dFastShortTableSize = 1 << dFastShortTableBits // Size of the table
	dFastShortTableMask = dFastShortTableSize - 1  // Mask for table indices. Redundant, but can eliminate bounds checks.
)

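// doubleFastEncoder extends fastEncoder with a second table indexed by an
// 8-byte hash. Long match candidates are checked before the short (5-byte
// hash) table, trading a little speed for a better compression ratio.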
type doubleFastEncoder struct {
	fastEncoder
	longTable [dFastLongTableSize]tableEntry
}

// Encode mimics functionality in zstd_dfast.c
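// It searches for repeat, long and short matches, in that order, and
// updates both hash tables as it advances through src.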
func (e *doubleFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (2)
		inputMargin            = 8 + 2
		minNonLiteralBlockSize = 16
	)

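	// Table entries store positions as offset + e.cur. If e.cur grows
	// close to bufferReset, old entries could alias new positions, so
	// the tables are cleared (no history) or shifted down.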
	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			for i := range e.table[:] {
				e.table[i] = tableEntry{}
			}
			for i := range e.longTable[:] {
				e.longTable[i] = tableEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			if v < minOff {
				v = 0
			} else {
				v = v - e.cur + e.maxMatchOff
			}
			e.table[i].offset = v
		}
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			if v < minOff {
				v = 0
			} else {
				v = v - e.cur + e.maxMatchOff
			}
			e.longTable[i].offset = v
		}
		e.cur = e.maxMatchOff
		break
	}

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Override src
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	// stepSize is the number of bytes to skip on every main loop iteration.
	// It should be >= 1.
	const stepSize = 1

	const kSearchStrength = 8

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

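	// Main loop: each outer iteration emits one sequence. The inner
	// loop advances s until a repeat, long or short match is found.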
encodeLoop:
	for {
		var t int32
		// We allow the encoder to optionally turn off repeat offsets across blocks
		canRepeat := len(blk.sequences) > 2

		for {
			if debugAsserts && canRepeat && offset1 == 0 {
				panic("offset0 was 0")
			}

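			// Hash the bytes at s two ways: hash5 keys the short table
			// on the lowest 5 bytes of cv, hash8 keys the long table
			// on all 8.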
			nextHashS := hash5(cv, dFastShortTableBits)
			nextHashL := hash8(cv, dFastLongTableBits)
			candidateL := e.longTable[nextHashL]
			candidateS := e.table[nextHashS]

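			// Always store the current position in both tables, then
			// first try a repeat match one byte ahead (repOff) against
			// the most recent offset.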
			const repOff = 1
			repIndex := s - offset1 + repOff
			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.longTable[nextHashL] = entry
			e.table[nextHashS] = entry

			if canRepeat {
				if repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>(repOff*8)) {
					// Consider history as well.
					var seq seq
					length := 4 + e.matchlen(s+4+repOff, repIndex+4, src)

					seq.matchLen = uint32(length - zstdMinMatch)

					// We might be able to match backwards.
					// Extend as long as we can.
					start := s + repOff
					// We end the search early, so we don't risk 0 literals
					// and have to do special offset treatment.
					startLimit := nextEmit + 1

					tMin := s - e.maxMatchOff
					if tMin < 0 {
						tMin = 0
					}
					for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
						repIndex--
						start--
						seq.matchLen++
					}
					addLiterals(&seq, start)

					// rep 0
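					// seq.offset values 1-3 are repeat codes;
					// 1 selects the most recent offset.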
					seq.offset = 1
					if debugSequences {
						println("repeat sequence", seq, "next s:", s)
					}
					blk.sequences = append(blk.sequences, seq)
					s += length + repOff
					nextEmit = s
					if s >= sLimit {
						if debug {
							println("repeat ended", s, length)
						}
						break encodeLoop
					}
					cv = load6432(src, s)
					continue
				}
			}
			// Find the offsets of our two matches.
			coffsetL := s - (candidateL.offset - e.cur)
			coffsetS := s - (candidateS.offset - e.cur)

			// Check if we have a long match.
			if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
				// Found a long match, likely at least 8 bytes.
				// Reference encoder checks all 8 bytes, we only check 4,
				// but the likelihood of both the first 4 bytes and the hash matching should be enough.
				t = candidateL.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugMatches {
					println("long match")
				}
				break
			}

			// Check if we have a short match.
			if coffsetS < e.maxMatchOff && uint32(cv) == candidateS.val {
				// Found a regular match.
				// See if we can find a long match at s+1.
				const checkAt = 1
				cv := load6432(src, s+checkAt)
				nextHashL = hash8(cv, dFastLongTableBits)
				candidateL = e.longTable[nextHashL]
				coffsetL = s - (candidateL.offset - e.cur) + checkAt

				// We can store it, since we have at least a 4 byte match.
				e.longTable[nextHashL] = tableEntry{offset: s + checkAt + e.cur, val: uint32(cv)}
				if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
					// Found a long match, likely at least 8 bytes.
					// Reference encoder checks all 8 bytes, we only check 4,
					// but the likelihood of both the first 4 bytes and the hash matching should be enough.
					t = candidateL.offset - e.cur
					s += checkAt
					if debugMatches {
						println("long match (after short)")
					}
					break
				}

				t = candidateS.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic("t<0")
				}
				if debugMatches {
					println("short match")
				}
				break
			}

			// No match found, move forward in input.
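			// The step size grows with the distance since the last
			// emitted literal, so incompressible data is skipped
			// progressively faster.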
			s += stepSize + ((s - nextEmit) >> (kSearchStrength - 1))
			if s >= sLimit {
				break encodeLoop
			}
			cv = load6432(src, s)
		}

		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if more than 4 bytes.
		offset2 = offset1
		offset1 = s - t

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && canRepeat && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the 4-byte match as long as possible.
		l := e.matchlen(s+4, t+4, src) + 4

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
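		// Real offsets are stored with a +3 bias, since values 1-3
		// are reserved for repeat codes.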
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

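		// Seed both tables with positions inside the match just emitted,
		// so later overlapping matches can still be found.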
		// Index match start+1 (long) and start+2 (short)
		index0 := s - l + 1
		// Index match end-2 (long) and end-1 (short)
		index1 := s - 2

		cv0 := load6432(src, index0)
		cv1 := load6432(src, index1)
		te0 := tableEntry{offset: index0 + e.cur, val: uint32(cv0)}
		te1 := tableEntry{offset: index1 + e.cur, val: uint32(cv1)}
		e.longTable[hash8(cv0, dFastLongTableBits)] = te0
		e.longTable[hash8(cv1, dFastLongTableBits)] = te1
		cv0 >>= 8
		cv1 >>= 8
		te0.offset++
		te1.offset++
		te0.val = uint32(cv0)
		te1.val = uint32(cv1)
		e.table[hash5(cv0, dFastShortTableBits)] = te0
		e.table[hash5(cv1, dFastShortTableBits)] = te1

		cv = load6432(src, s)

		if !canRepeat {
			continue
		}

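		// Right after a match, check for another match at the second
		// recent offset. With zero literals, repeat code 1 refers to
		// that offset, so a hit here is very cheap to encode.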
		// Check offset 2
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hash5(cv, dFastShortTableBits)
			nextHashL := hash8(cv, dFastLongTableBits)

			// We have at least a 4 byte match.
			// No need to check backwards, we come straight from a match.
			l := 4 + e.matchlen(s+4, o2+4, src)

			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.longTable[nextHashL] = entry
			e.table[nextHashS] = entry
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litlen is always 0, this is offset 1.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}

// EncodeNoHist will encode a block with no history and no following blocks.
// The most notable differences are that src will not be copied for history,
// and we do not need to check for max match length.
func (e *doubleFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (2)
		inputMargin            = 8 + 2
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	if e.cur >= bufferReset {
		for i := range e.table[:] {
			e.table[i] = tableEntry{}
		}
		for i := range e.longTable[:] {
			e.longTable[i] = tableEntry{}
		}
		e.cur = e.maxMatchOff
	}

	s := int32(0)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	sLimit := int32(len(src)) - inputMargin
	// stepSize is the number of bytes to skip on every main loop iteration.
	// It should be >= 1.
	const stepSize = 1

	const kSearchStrength = 8

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

encodeLoop:
	for {
		var t int32
		for {
			nextHashS := hash5(cv, dFastShortTableBits)
			nextHashL := hash8(cv, dFastLongTableBits)
			candidateL := e.longTable[nextHashL]
			candidateS := e.table[nextHashS]

			const repOff = 1
			repIndex := s - offset1 + repOff
			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.longTable[nextHashL] = entry
			e.table[nextHashS] = entry

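			// Without history, the incoming recent offsets cannot be
			// assumed valid, so repeats are only attempted once this
			// block has produced a few sequences of its own.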
			if len(blk.sequences) > 2 {
				if load3232(src, repIndex) == uint32(cv>>(repOff*8)) {
					var seq seq
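					// matchLen is used directly here; the max match
					// length cap applied by e.matchlen is not needed
					// for a single block.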
					//length := 4 + e.matchlen(s+4+repOff, repIndex+4, src)
					length := 4 + int32(matchLen(src[s+4+repOff:], src[repIndex+4:]))

					seq.matchLen = uint32(length - zstdMinMatch)

					// We might be able to match backwards.
					// Extend as long as we can.
					start := s + repOff
					// We end the search early, so we don't risk 0 literals
					// and have to do special offset treatment.
					startLimit := nextEmit + 1

					tMin := s - e.maxMatchOff
					if tMin < 0 {
						tMin = 0
					}
					for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] {
						repIndex--
						start--
						seq.matchLen++
					}
					addLiterals(&seq, start)

					// rep 0
					seq.offset = 1
					if debugSequences {
						println("repeat sequence", seq, "next s:", s)
					}
					blk.sequences = append(blk.sequences, seq)
					s += length + repOff
					nextEmit = s
					if s >= sLimit {
						if debug {
							println("repeat ended", s, length)
						}
						break encodeLoop
					}
					cv = load6432(src, s)
					continue
				}
			}
			// Find the offsets of our two matches.
			coffsetL := s - (candidateL.offset - e.cur)
			coffsetS := s - (candidateS.offset - e.cur)

			// Check if we have a long match.
			if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
				// Found a long match, likely at least 8 bytes.
				// Reference encoder checks all 8 bytes, we only check 4,
				// but the likelihood of both the first 4 bytes and the hash matching should be enough.
				t = candidateL.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugMatches {
					println("long match")
				}
				break
			}

			// Check if we have a short match.
			if coffsetS < e.maxMatchOff && uint32(cv) == candidateS.val {
				// Found a regular match.
				// See if we can find a long match at s+1.
				const checkAt = 1
				cv := load6432(src, s+checkAt)
				nextHashL = hash8(cv, dFastLongTableBits)
				candidateL = e.longTable[nextHashL]
				coffsetL = s - (candidateL.offset - e.cur) + checkAt

				// We can store it, since we have at least a 4 byte match.
				e.longTable[nextHashL] = tableEntry{offset: s + checkAt + e.cur, val: uint32(cv)}
				if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
					// Found a long match, likely at least 8 bytes.
					// Reference encoder checks all 8 bytes, we only check 4,
					// but the likelihood of both the first 4 bytes and the hash matching should be enough.
					t = candidateL.offset - e.cur
					s += checkAt
					if debugMatches {
						println("long match (after short)")
					}
					break
				}

				t = candidateS.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic("t<0")
				}
				if debugMatches {
					println("short match")
				}
				break
			}

			// No match found, move forward in input.
			s += stepSize + ((s - nextEmit) >> (kSearchStrength - 1))
			if s >= sLimit {
				break encodeLoop
			}
			cv = load6432(src, s)
		}

		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if more than 4 bytes.
		offset2 = offset1
		offset1 = s - t

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		// Extend the 4-byte match as long as possible.
		//l := e.matchlen(s+4, t+4, src) + 4
		l := int32(matchLen(src[s+4:], src[t+4:])) + 4

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

		// Index match start+1 (long) and start+2 (short)
		index0 := s - l + 1
		// Index match end-2 (long) and end-1 (short)
		index1 := s - 2

		cv0 := load6432(src, index0)
		cv1 := load6432(src, index1)
		te0 := tableEntry{offset: index0 + e.cur, val: uint32(cv0)}
		te1 := tableEntry{offset: index1 + e.cur, val: uint32(cv1)}
		e.longTable[hash8(cv0, dFastLongTableBits)] = te0
		e.longTable[hash8(cv1, dFastLongTableBits)] = te1
		cv0 >>= 8
		cv1 >>= 8
		te0.offset++
		te1.offset++
		te0.val = uint32(cv0)
		te1.val = uint32(cv1)
		e.table[hash5(cv0, dFastShortTableBits)] = te0
		e.table[hash5(cv1, dFastShortTableBits)] = te1

		cv = load6432(src, s)

		if len(blk.sequences) <= 2 {
			continue
		}

		// Check offset 2
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hash5(cv, dFastShortTableBits)
			nextHashL := hash8(cv, dFastLongTableBits)

			// We have at least a 4 byte match.
			// No need to check backwards, we come straight from a match.
			//l := 4 + e.matchlen(s+4, o2+4, src)
			l := 4 + int32(matchLen(src[s+4:], src[o2+4:]))

			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.longTable[nextHashL] = entry
			e.table[nextHashS] = entry
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litlen is always 0, this is offset 1.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}

	// We do not store history, so we must offset e.cur to avoid false matches for next user.
	if e.cur < bufferReset {
		e.cur += int32(len(src))
	}
}