// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import "fmt"

const (
	dFastLongTableBits = 17                      // Bits used in the long match table
	dFastLongTableSize = 1 << dFastLongTableBits // Size of the table
	dFastLongTableMask = dFastLongTableSize - 1  // Mask for table indices. Redundant, but can eliminate bounds checks.

	dFastShortTableBits = tableBits                // Bits used in the short match table
	dFastShortTableSize = 1 << dFastShortTableBits // Size of the table
	dFastShortTableMask = dFastShortTableSize - 1  // Mask for table indices. Redundant, but can eliminate bounds checks.
)

// doubleFastEncoder mirrors fastEncoder, but adds a second table keyed on
// 8-byte hashes for "long" matches. The long table is checked first; the
// short 5-byte table embedded from fastEncoder serves as the fallback.
type doubleFastEncoder struct {
	fastEncoder
	longTable     [dFastLongTableSize]tableEntry
	dictLongTable []tableEntry
}

// Encode mimics functionality in zstd_dfast.c
func (e *doubleFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (2)
		inputMargin            = 8 + 2
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			for i := range e.table[:] {
				e.table[i] = tableEntry{}
			}
			for i := range e.longTable[:] {
				e.longTable[i] = tableEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			if v < minOff {
				v = 0
			} else {
				v = v - e.cur + e.maxMatchOff
			}
			e.table[i].offset = v
		}
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			if v < minOff {
				v = 0
			} else {
				v = v - e.cur + e.maxMatchOff
			}
			e.longTable[i].offset = v
		}
		e.cur = e.maxMatchOff
		break
	}

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Override src
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	// stepSize is the number of bytes to skip on every main loop iteration.
	// It should be >= 1.
	const stepSize = 1

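	// kSearchStrength controls how aggressively the search skips ahead when
	// no matches are found: the step size grows by one byte for every
	// 1<<(kSearchStrength-1) = 128 bytes scanned since the last emit.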
	const kSearchStrength = 8

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

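	// Each iteration of encodeLoop emits one sequence: try a repeat match at
	// offset1 first, then a long match, then a short match (with an attempted
	// long-match upgrade at s+1). The winning match is extended both ways,
	// emitted, and the tables are updated before checking offset2 repeats.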
encodeLoop:
	for {
		var t int32
		// We allow the encoder to optionally turn off repeat offsets across blocks
		canRepeat := len(blk.sequences) > 2

		for {
			if debugAsserts && canRepeat && offset1 == 0 {
				panic("offset0 was 0")
			}

			nextHashS := hash5(cv, dFastShortTableBits)
			nextHashL := hash8(cv, dFastLongTableBits)
			candidateL := e.longTable[nextHashL]
			candidateS := e.table[nextHashS]

			const repOff = 1
			repIndex := s - offset1 + repOff
			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.longTable[nextHashL] = entry
			e.table[nextHashS] = entry

			if canRepeat {
				if repIndex >= 0 && load3232(src, repIndex) == uint32(cv>>(repOff*8)) {
					// Consider history as well.
					var seq seq
					length := 4 + e.matchlen(s+4+repOff, repIndex+4, src)

					seq.matchLen = uint32(length - zstdMinMatch)

					// We might be able to match backwards.
					// Extend as long as we can.
					start := s + repOff
					// We end the search early, so we don't risk 0 literals
					// and have to do special offset treatment.
					startLimit := nextEmit + 1

					tMin := s - e.maxMatchOff
					if tMin < 0 {
						tMin = 0
					}
					for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
						repIndex--
						start--
						seq.matchLen++
					}
					addLiterals(&seq, start)

					// rep 0
					seq.offset = 1
					if debugSequences {
						println("repeat sequence", seq, "next s:", s)
					}
					blk.sequences = append(blk.sequences, seq)
					s += length + repOff
					nextEmit = s
					if s >= sLimit {
						if debug {
							println("repeat ended", s, length)
						}
						break encodeLoop
					}
					cv = load6432(src, s)
					continue
				}
			}
			// Find the offsets of our two matches.
			coffsetL := s - (candidateL.offset - e.cur)
			coffsetS := s - (candidateS.offset - e.cur)

			// Check if we have a long match.
			if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
				// Found a long match, likely at least 8 bytes.
				// Reference encoder checks all 8 bytes, we only check 4,
				// but the likelihood of both the first 4 bytes and the hash matching should be enough.
				t = candidateL.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugMatches {
					println("long match")
				}
				break
			}

			// Check if we have a short match.
			if coffsetS < e.maxMatchOff && uint32(cv) == candidateS.val {
				// found a regular match
				// See if we can find a long match at s+1
				const checkAt = 1
				cv := load6432(src, s+checkAt)
				nextHashL = hash8(cv, dFastLongTableBits)
				candidateL = e.longTable[nextHashL]
				coffsetL = s - (candidateL.offset - e.cur) + checkAt

				// We can store it, since we have at least a 4 byte match.
				e.longTable[nextHashL] = tableEntry{offset: s + checkAt + e.cur, val: uint32(cv)}
				if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
					// Found a long match, likely at least 8 bytes.
					// Reference encoder checks all 8 bytes, we only check 4,
					// but the likelihood of both the first 4 bytes and the hash matching should be enough.
					t = candidateL.offset - e.cur
					s += checkAt
					if debugMatches {
						println("long match (after short)")
					}
					break
				}

				t = candidateS.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic("t<0")
				}
				if debugMatches {
					println("short match")
				}
				break
			}

			// No match found, move forward in input.
			s += stepSize + ((s - nextEmit) >> (kSearchStrength - 1))
			if s >= sLimit {
				break encodeLoop
			}
			cv = load6432(src, s)
		}

		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if the match extends beyond 4 bytes.
		offset2 = offset1
		offset1 = s - t

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && canRepeat && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the 4-byte match as long as possible.
		l := e.matchlen(s+4, t+4, src) + 4

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		// Stored offsets are biased by +3, since values 1-3 encode repeat offsets.
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

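		// Only the edges of the match are indexed below; interior positions
		// are skipped, trading a little match-finding ability for speed.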
		// Index match start+1 (long) and start+2 (short)
		index0 := s - l + 1
		// Index match end-2 (long) and end-1 (short)
		index1 := s - 2

		cv0 := load6432(src, index0)
		cv1 := load6432(src, index1)
		te0 := tableEntry{offset: index0 + e.cur, val: uint32(cv0)}
		te1 := tableEntry{offset: index1 + e.cur, val: uint32(cv1)}
		e.longTable[hash8(cv0, dFastLongTableBits)] = te0
		e.longTable[hash8(cv1, dFastLongTableBits)] = te1
		cv0 >>= 8
		cv1 >>= 8
		te0.offset++
		te1.offset++
		te0.val = uint32(cv0)
		te1.val = uint32(cv1)
		e.table[hash5(cv0, dFastShortTableBits)] = te0
		e.table[hash5(cv1, dFastShortTableBits)] = te1

		cv = load6432(src, s)

		if !canRepeat {
			continue
		}

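		// While the bytes at offset2 keep matching, emit repeat sequences
		// with zero literals. offset1 and offset2 are swapped on each hit,
		// so the repeat is always encoded as offset 1.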
		// Check offset 2
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hash5(cv, dFastShortTableBits)
			nextHashL := hash8(cv, dFastLongTableBits)

			// We have at least a 4-byte match.
			// No need to check backwards. We come straight from a match.
			l := 4 + e.matchlen(s+4, o2+4, src)

			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.longTable[nextHashL] = entry
			e.table[nextHashS] = entry
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litlen is always 0, this is offset 1.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}

// EncodeNoHist will encode a block with no history and no following blocks.
// The most notable difference is that src will not be copied for history and
// we do not need to check for max match length.
func (e *doubleFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (2)
		inputMargin            = 8 + 2
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	if e.cur >= bufferReset {
		for i := range e.table[:] {
			e.table[i] = tableEntry{}
		}
		for i := range e.longTable[:] {
			e.longTable[i] = tableEntry{}
		}
		e.cur = e.maxMatchOff
	}

	s := int32(0)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	sLimit := int32(len(src)) - inputMargin
	// stepSize is the number of bytes to skip on every main loop iteration.
	// It should be >= 1.
	const stepSize = 1

	const kSearchStrength = 8

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])

	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	if debug {
		println("recent offsets:", blk.recentOffsets)
	}

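	// The search below mirrors Encode, but uses matchLen directly and omits
	// the max match length checks: a block encoded without history can never
	// contain a match long enough to exceed the limit.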
encodeLoop:
	for {
		var t int32
		for {
			nextHashS := hash5(cv, dFastShortTableBits)
			nextHashL := hash8(cv, dFastLongTableBits)
			candidateL := e.longTable[nextHashL]
			candidateS := e.table[nextHashS]

			const repOff = 1
			repIndex := s - offset1 + repOff
			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.longTable[nextHashL] = entry
			e.table[nextHashS] = entry

			if len(blk.sequences) > 2 {
				if load3232(src, repIndex) == uint32(cv>>(repOff*8)) {
					var seq seq
					//length := 4 + e.matchlen(s+4+repOff, repIndex+4, src)
					length := 4 + int32(matchLen(src[s+4+repOff:], src[repIndex+4:]))

					seq.matchLen = uint32(length - zstdMinMatch)

					// We might be able to match backwards.
					// Extend as long as we can.
					start := s + repOff
					// We end the search early, so we don't risk 0 literals
					// and have to do special offset treatment.
					startLimit := nextEmit + 1

					tMin := s - e.maxMatchOff
					if tMin < 0 {
						tMin = 0
					}
					for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] {
						repIndex--
						start--
						seq.matchLen++
					}
					addLiterals(&seq, start)

					// rep 0
					seq.offset = 1
					if debugSequences {
						println("repeat sequence", seq, "next s:", s)
					}
					blk.sequences = append(blk.sequences, seq)
					s += length + repOff
					nextEmit = s
					if s >= sLimit {
						if debug {
							println("repeat ended", s, length)
						}
						break encodeLoop
					}
					cv = load6432(src, s)
					continue
				}
			}
			// Find the offsets of our two matches.
			coffsetL := s - (candidateL.offset - e.cur)
			coffsetS := s - (candidateS.offset - e.cur)

			// Check if we have a long match.
			if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
				// Found a long match, likely at least 8 bytes.
				// Reference encoder checks all 8 bytes, we only check 4,
				// but the likelihood of both the first 4 bytes and the hash matching should be enough.
				t = candidateL.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d). cur: %d", s, t, e.cur))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugMatches {
					println("long match")
				}
				break
			}

			// Check if we have a short match.
			if coffsetS < e.maxMatchOff && uint32(cv) == candidateS.val {
				// found a regular match
				// See if we can find a long match at s+1
				const checkAt = 1
				cv := load6432(src, s+checkAt)
				nextHashL = hash8(cv, dFastLongTableBits)
				candidateL = e.longTable[nextHashL]
				coffsetL = s - (candidateL.offset - e.cur) + checkAt

				// We can store it, since we have at least a 4 byte match.
				e.longTable[nextHashL] = tableEntry{offset: s + checkAt + e.cur, val: uint32(cv)}
				if coffsetL < e.maxMatchOff && uint32(cv) == candidateL.val {
					// Found a long match, likely at least 8 bytes.
					// Reference encoder checks all 8 bytes, we only check 4,
					// but the likelihood of both the first 4 bytes and the hash matching should be enough.
					t = candidateL.offset - e.cur
					s += checkAt
					if debugMatches {
						println("long match (after short)")
					}
					break
				}

				t = candidateS.offset - e.cur
				if debugAsserts && s <= t {
					panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
				}
				if debugAsserts && s-t > e.maxMatchOff {
					panic("s - t > e.maxMatchOff")
				}
				if debugAsserts && t < 0 {
					panic("t<0")
				}
				if debugMatches {
					println("short match")
				}
				break
			}

			// No match found, move forward in input.
			s += stepSize + ((s - nextEmit) >> (kSearchStrength - 1))
			if s >= sLimit {
				break encodeLoop
			}
			cv = load6432(src, s)
		}

		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if the match extends beyond 4 bytes.
		offset2 = offset1
		offset1 = s - t

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		// Extend the 4-byte match as long as possible.
		//l := e.matchlen(s+4, t+4, src) + 4
		l := int32(matchLen(src[s+4:], src[t+4:])) + 4

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

		// Index match start+1 (long) and start+2 (short)
		index0 := s - l + 1
		// Index match end-2 (long) and end-1 (short)
		index1 := s - 2

		cv0 := load6432(src, index0)
		cv1 := load6432(src, index1)
		te0 := tableEntry{offset: index0 + e.cur, val: uint32(cv0)}
		te1 := tableEntry{offset: index1 + e.cur, val: uint32(cv1)}
		e.longTable[hash8(cv0, dFastLongTableBits)] = te0
		e.longTable[hash8(cv1, dFastLongTableBits)] = te1
		cv0 >>= 8
		cv1 >>= 8
		te0.offset++
		te1.offset++
		te0.val = uint32(cv0)
		te1.val = uint32(cv1)
		e.table[hash5(cv0, dFastShortTableBits)] = te0
		e.table[hash5(cv1, dFastShortTableBits)] = te1

		cv = load6432(src, s)

		if len(blk.sequences) <= 2 {
			continue
		}


		// Check offset 2
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hash5(cv, dFastShortTableBits)
			nextHashL := hash8(cv, dFastLongTableBits)

			// We have at least a 4-byte match.
			// No need to check backwards. We come straight from a match.
			//l := 4 + e.matchlen(s+4, o2+4, src)
			l := 4 + int32(matchLen(src[s+4:], src[o2+4:]))

			entry := tableEntry{offset: s + e.cur, val: uint32(cv)}
			e.longTable[nextHashL] = entry
			e.table[nextHashS] = entry
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litlen is always 0, this is offset 1.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	if debug {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}

	// We do not store history, so we must offset e.cur to avoid false matches for the next user.
	if e.cur < bufferReset {
		e.cur += int32(len(src))
	}
}

// Reset will reset and set a dictionary if not nil.
func (e *doubleFastEncoder) Reset(d *dict, singleBlock bool) {
	e.fastEncoder.Reset(d, singleBlock)
	if d == nil {
		return
	}

	// Init or copy dict table
	if len(e.dictLongTable) != len(e.longTable) || d.id != e.lastDictID {
		if len(e.dictLongTable) != len(e.longTable) {
			e.dictLongTable = make([]tableEntry, len(e.longTable))
		}
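		// Hash every 8-byte window of the dictionary content into the long
		// table, so early blocks can match against the dictionary. The result
		// is cached in dictLongTable and copied into longTable on each Reset
		// with the same dictionary.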
		if len(d.content) >= 8 {
			cv := load6432(d.content, 0)
			e.dictLongTable[hash8(cv, dFastLongTableBits)] = tableEntry{
				val:    uint32(cv),
				offset: e.maxMatchOff,
			}
			end := int32(len(d.content)) - 8 + e.maxMatchOff
			for i := e.maxMatchOff + 1; i < end; i++ {
				cv = cv>>8 | (uint64(d.content[i-e.maxMatchOff+7]) << 56)
				e.dictLongTable[hash8(cv, dFastLongTableBits)] = tableEntry{
					val:    uint32(cv),
					offset: i,
				}
			}
		}
		e.lastDictID = d.id
	}
	// Reset table to initial state
	e.cur = e.maxMatchOff
	copy(e.longTable[:], e.dictLongTable)
}