1// Copyright 2016 The Snappy-Go Authors. All rights reserved.
2// Copyright (c) 2019 Klaus Post. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6package s2
7
8import (
9	"fmt"
10	"math/bits"
11)
12
// encodeBlockBest encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.
//
// Unlike the faster encoders, several candidate matches are scored for each
// input position (long hash, short hash, the previous entry of each, and
// repeat offsets) and the best-scoring one is emitted, trading speed for
// compression ratio.
//
// It also assumes that:
//	len(dst) >= MaxEncodedLen(len(src)) &&
// 	minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlockBest(dst, src []byte) (d int) {
	// Initialize the hash tables.
	const (
		// Long hash matches.
		lTableBits    = 19
		maxLTableSize = 1 << lTableBits

		// Short hash matches.
		sTableBits    = 16
		maxSTableSize = 1 << sTableBits

		inputMargin = 8 + 2
	)

	// sLimit is when to stop looking for offset/length copies. The inputMargin
	// lets us use a fast path for emitLiteral in the main loop, while we are
	// looking for copies.
	sLimit := len(src) - inputMargin
	if len(src) < minNonLiteralBlockSize {
		return 0
	}

	// Each table entry packs two candidate positions: the most recent match
	// position in the low 32 bits and the previous one in the high 32 bits
	// (extracted via getCur/getPrev below).
	var lTable [maxLTableSize]uint64
	var sTable [maxSTableSize]uint64

	// Bail if we can't compress to at least this.
	dstLimit := len(src) - 5

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := 0

	// The encoded form must start with a literal, as there are no previous
	// bytes to copy, so we start looking for hash matches at s == 1.
	s := 1
	cv := load64(src, s)

	// We search for a repeat at -1, but don't output repeats when nextEmit == 0
	repeat := 1
	const lowbitMask = 0xffffffff
	// getCur returns the most recent position stored in a table entry.
	getCur := func(x uint64) int {
		return int(x & lowbitMask)
	}
	// getPrev returns the second-most-recent position stored in a table entry.
	getPrev := func(x uint64) int {
		return int(x >> 32)
	}
	const maxSkip = 64

	for {
		type match struct {
			offset int // absolute position in src where the matched bytes start
			s      int // position in src where the match begins
			length int // match length in bytes
			score  int // estimated bytes saved by emitting this match
			rep    bool // true if this match continues the current repeat offset
		}
		var best match
		for {
			// Next src position to check.
			// Skip ahead faster the longer we go without finding a match,
			// capped at maxSkip bytes per step.
			nextS := (s-nextEmit)>>8 + 1
			if nextS > maxSkip {
				nextS = s + maxSkip
			} else {
				nextS += s
			}
			if nextS > sLimit {
				goto emitRemainder
			}
			hashL := hash8(cv, lTableBits)
			hashS := hash4(cv, sTableBits)
			candidateL := lTable[hashL]
			candidateS := sTable[hashS]

			// score estimates the net bytes saved by emitting match m:
			// the match length, minus a penalty for literals that must be
			// emitted to reach m.s, minus the encoded size of the copy/repeat.
			score := func(m match) int {
				// Matches that are longer forward are penalized since we must emit it as a literal.
				score := m.length - m.s
				if nextEmit == m.s {
					// If we do not have to emit literals, we save 1 byte
					score++
				}
				offset := m.s - m.offset
				if m.rep {
					return score - emitRepeatSize(offset, m.length)
				}
				return score - emitCopySize(offset, m.length)
			}

			// matchAt verifies, extends and scores a candidate match starting
			// at src[offset] against input position s. A zero length in the
			// returned match means "no usable match".
			matchAt := func(offset, s int, first uint32, rep bool) match {
				if best.length != 0 && best.s-best.offset == s-offset {
					// Don't retest if we have the same offset.
					return match{offset: offset, s: s}
				}
				if load32(src, offset) != first {
					return match{offset: offset, s: s}
				}
				// m.length temporarily holds the end position on the offset
				// side while extending; the "m.length -= offset" below turns
				// it into the actual match length.
				m := match{offset: offset, s: s, length: 4 + offset, rep: rep}
				s += 4
				for s <= sLimit {
					if diff := load64(src, s) ^ load64(src, m.length); diff != 0 {
						// First differing byte found via the trailing zero count.
						m.length += bits.TrailingZeros64(diff) >> 3
						break
					}
					s += 8
					m.length += 8
				}
				m.length -= offset
				m.score = score(m)
				if m.score <= -m.s {
					// Eliminate if no savings, we might find a better one.
					m.length = 0
				}
				return m
			}

			// bestOf returns the better of two matches, normalizing the
			// scores for their (possibly different) start positions.
			bestOf := func(a, b match) match {
				if b.length == 0 {
					return a
				}
				if a.length == 0 {
					return b
				}
				as := a.score + b.s
				bs := b.score + a.s
				if as >= bs {
					return a
				}
				return b
			}

			// Probe both stored positions of the long and short hash tables.
			best = bestOf(matchAt(getCur(candidateL), s, uint32(cv), false), matchAt(getPrev(candidateL), s, uint32(cv), false))
			best = bestOf(best, matchAt(getCur(candidateS), s, uint32(cv), false))
			best = bestOf(best, matchAt(getPrev(candidateS), s, uint32(cv), false))

			{
				// Try continuing the current repeat offset at s+1, and if any
				// match exists, probe positions s+1 and s+2 as well.
				best = bestOf(best, matchAt(s-repeat+1, s+1, uint32(cv>>8), true))
				if best.length > 0 {
					// s+1 (note: s and cv are shadowed inside this block).
					nextShort := sTable[hash4(cv>>8, sTableBits)]
					s := s + 1
					cv := load64(src, s)
					nextLong := lTable[hash8(cv, lTableBits)]
					best = bestOf(best, matchAt(getCur(nextShort), s, uint32(cv), false))
					best = bestOf(best, matchAt(getPrev(nextShort), s, uint32(cv), false))
					best = bestOf(best, matchAt(getCur(nextLong), s, uint32(cv), false))
					best = bestOf(best, matchAt(getPrev(nextLong), s, uint32(cv), false))
					// Repeat at + 2
					best = bestOf(best, matchAt(s-repeat+1, s+1, uint32(cv>>8), true))

					// s+2
					if true {
						nextShort = sTable[hash4(cv>>8, sTableBits)]
						s++
						cv = load64(src, s)
						nextLong = lTable[hash8(cv, lTableBits)]
						best = bestOf(best, matchAt(getCur(nextShort), s, uint32(cv), false))
						best = bestOf(best, matchAt(getPrev(nextShort), s, uint32(cv), false))
						best = bestOf(best, matchAt(getCur(nextLong), s, uint32(cv), false))
						best = bestOf(best, matchAt(getPrev(nextLong), s, uint32(cv), false))
					}
					// Search for a match at best match end, see if that is better.
					if sAt := best.s + best.length; sAt < sLimit {
						sBack := best.s
						backL := best.length
						// Load initial values
						cv = load64(src, sBack)
						// Search for mismatch
						next := lTable[hash8(load64(src, sAt), lTableBits)]
						//next := sTable[hash4(load64(src, sAt), sTableBits)]

						if checkAt := getCur(next) - backL; checkAt > 0 {
							best = bestOf(best, matchAt(checkAt, sBack, uint32(cv), false))
						}
						if checkAt := getPrev(next) - backL; checkAt > 0 {
							best = bestOf(best, matchAt(checkAt, sBack, uint32(cv), false))
						}
					}
				}
			}

			// Update table: shift the current position into "prev" and store s
			// as the most recent position for both hashes.
			lTable[hashL] = uint64(s) | candidateL<<32
			sTable[hashS] = uint64(s) | candidateS<<32

			if best.length > 0 {
				break
			}

			cv = load64(src, nextS)
			s = nextS
		}

		// Extend backwards, not needed for repeats...
		s = best.s
		if !best.rep {
			for best.offset > 0 && s > nextEmit && src[best.offset-1] == src[s-1] {
				best.offset--
				best.length++
				s--
			}
		}
		// Debug-only sanity check; disabled in production builds.
		if false && best.offset >= s {
			panic(fmt.Errorf("t %d >= s %d", best.offset, s))
		}
		// Bail if we exceed the maximum size.
		if d+(s-nextEmit) > dstLimit {
			return 0
		}

		base := s
		offset := s - best.offset

		s += best.length

		if offset > 65535 && s-base <= 5 && !best.rep {
			// Bail if the match is equal or worse to the encoding.
			// A 4-byte-offset copy costs 5 bytes, so a <=5 byte match saves nothing.
			s = best.s + 1
			if s >= sLimit {
				goto emitRemainder
			}
			cv = load64(src, s)
			continue
		}
		d += emitLiteral(dst[d:], src[nextEmit:base])
		if best.rep {
			if nextEmit > 0 {
				// same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset.
				d += emitRepeat(dst[d:], offset, best.length)
			} else {
				// First match, cannot be repeat.
				d += emitCopy(dst[d:], offset, best.length)
			}
		} else {
			d += emitCopy(dst[d:], offset, best.length)
		}
		// Remember the offset so subsequent matches can be coded as repeats.
		repeat = offset

		nextEmit = s
		if s >= sLimit {
			goto emitRemainder
		}

		if d > dstLimit {
			// Do we have space for more, if not bail.
			return 0
		}
		// Fill tables with every position covered by the match, so future
		// candidates can point into it.
		for i := best.s + 1; i < s; i++ {
			cv0 := load64(src, i)
			long0 := hash8(cv0, lTableBits)
			short0 := hash4(cv0, sTableBits)
			lTable[long0] = uint64(i) | lTable[long0]<<32
			sTable[short0] = uint64(i) | sTable[short0]<<32
		}
		cv = load64(src, s)
	}

emitRemainder:
	if nextEmit < len(src) {
		// Bail if we exceed the maximum size.
		if d+len(src)-nextEmit > dstLimit {
			return 0
		}
		d += emitLiteral(dst[d:], src[nextEmit:])
	}
	return d
}
285
// encodeBlockBestSnappy encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.
//
// This is the Snappy-compatible twin of encodeBlockBest: the search is the
// same, but output never uses repeat codes (emitCopyNoRepeat), since plain
// Snappy has no repeat operation. The repeat offset is still tracked to seed
// candidate matches.
//
// It also assumes that:
//	len(dst) >= MaxEncodedLen(len(src)) &&
// 	minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlockBestSnappy(dst, src []byte) (d int) {
	// Initialize the hash tables.
	const (
		// Long hash matches.
		lTableBits    = 19
		maxLTableSize = 1 << lTableBits

		// Short hash matches.
		sTableBits    = 16
		maxSTableSize = 1 << sTableBits

		inputMargin = 8 + 2
	)

	// sLimit is when to stop looking for offset/length copies. The inputMargin
	// lets us use a fast path for emitLiteral in the main loop, while we are
	// looking for copies.
	sLimit := len(src) - inputMargin
	if len(src) < minNonLiteralBlockSize {
		return 0
	}

	// Each table entry packs two candidate positions: the most recent match
	// position in the low 32 bits and the previous one in the high 32 bits
	// (extracted via getCur/getPrev below).
	var lTable [maxLTableSize]uint64
	var sTable [maxSTableSize]uint64

	// Bail if we can't compress to at least this.
	dstLimit := len(src) - 5

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := 0

	// The encoded form must start with a literal, as there are no previous
	// bytes to copy, so we start looking for hash matches at s == 1.
	s := 1
	cv := load64(src, s)

	// We search for a repeat at -1, but don't output repeats when nextEmit == 0
	repeat := 1
	const lowbitMask = 0xffffffff
	// getCur returns the most recent position stored in a table entry.
	getCur := func(x uint64) int {
		return int(x & lowbitMask)
	}
	// getPrev returns the second-most-recent position stored in a table entry.
	getPrev := func(x uint64) int {
		return int(x >> 32)
	}
	const maxSkip = 64

	for {
		type match struct {
			offset int // absolute position in src where the matched bytes start
			s      int // position in src where the match begins
			length int // match length in bytes
			score  int // estimated bytes saved by emitting this match
		}
		var best match
		for {
			// Next src position to check.
			// Skip ahead faster the longer we go without finding a match,
			// capped at maxSkip bytes per step.
			nextS := (s-nextEmit)>>8 + 1
			if nextS > maxSkip {
				nextS = s + maxSkip
			} else {
				nextS += s
			}
			if nextS > sLimit {
				goto emitRemainder
			}
			hashL := hash8(cv, lTableBits)
			hashS := hash4(cv, sTableBits)
			candidateL := lTable[hashL]
			candidateS := sTable[hashS]

			// score estimates the net bytes saved by emitting match m; all
			// matches are costed as plain copies (no repeat encoding here).
			score := func(m match) int {
				// Matches that are longer forward are penalized since we must emit it as a literal.
				score := m.length - m.s
				if nextEmit == m.s {
					// If we do not have to emit literals, we save 1 byte
					score++
				}
				offset := m.s - m.offset

				return score - emitCopySize(offset, m.length)
			}

			// matchAt verifies, extends and scores a candidate match starting
			// at src[offset] against input position s. A zero length in the
			// returned match means "no usable match".
			matchAt := func(offset, s int, first uint32) match {
				if best.length != 0 && best.s-best.offset == s-offset {
					// Don't retest if we have the same offset.
					return match{offset: offset, s: s}
				}
				if load32(src, offset) != first {
					return match{offset: offset, s: s}
				}
				// m.length temporarily holds the end position on the offset
				// side while extending; the "m.length -= offset" below turns
				// it into the actual match length.
				m := match{offset: offset, s: s, length: 4 + offset}
				s += 4
				for s <= sLimit {
					if diff := load64(src, s) ^ load64(src, m.length); diff != 0 {
						// First differing byte found via the trailing zero count.
						m.length += bits.TrailingZeros64(diff) >> 3
						break
					}
					s += 8
					m.length += 8
				}
				m.length -= offset
				m.score = score(m)
				if m.score <= -m.s {
					// Eliminate if no savings, we might find a better one.
					m.length = 0
				}
				return m
			}

			// bestOf returns the better of two matches, normalizing the
			// scores for their (possibly different) start positions.
			bestOf := func(a, b match) match {
				if b.length == 0 {
					return a
				}
				if a.length == 0 {
					return b
				}
				as := a.score + b.s
				bs := b.score + a.s
				if as >= bs {
					return a
				}
				return b
			}

			// Probe both stored positions of the long and short hash tables.
			best = bestOf(matchAt(getCur(candidateL), s, uint32(cv)), matchAt(getPrev(candidateL), s, uint32(cv)))
			best = bestOf(best, matchAt(getCur(candidateS), s, uint32(cv)))
			best = bestOf(best, matchAt(getPrev(candidateS), s, uint32(cv)))

			{
				// Try continuing the tracked repeat offset at s+1, and if any
				// match exists, probe positions s+1 and s+2 as well.
				best = bestOf(best, matchAt(s-repeat+1, s+1, uint32(cv>>8)))
				if best.length > 0 {
					// s+1 (note: s and cv are shadowed inside this block).
					nextShort := sTable[hash4(cv>>8, sTableBits)]
					s := s + 1
					cv := load64(src, s)
					nextLong := lTable[hash8(cv, lTableBits)]
					best = bestOf(best, matchAt(getCur(nextShort), s, uint32(cv)))
					best = bestOf(best, matchAt(getPrev(nextShort), s, uint32(cv)))
					best = bestOf(best, matchAt(getCur(nextLong), s, uint32(cv)))
					best = bestOf(best, matchAt(getPrev(nextLong), s, uint32(cv)))
					// Repeat at + 2
					best = bestOf(best, matchAt(s-repeat+1, s+1, uint32(cv>>8)))

					// s+2
					if true {
						nextShort = sTable[hash4(cv>>8, sTableBits)]
						s++
						cv = load64(src, s)
						nextLong = lTable[hash8(cv, lTableBits)]
						best = bestOf(best, matchAt(getCur(nextShort), s, uint32(cv)))
						best = bestOf(best, matchAt(getPrev(nextShort), s, uint32(cv)))
						best = bestOf(best, matchAt(getCur(nextLong), s, uint32(cv)))
						best = bestOf(best, matchAt(getPrev(nextLong), s, uint32(cv)))
					}
					// Search for a match at best match end, see if that is better.
					if sAt := best.s + best.length; sAt < sLimit {
						sBack := best.s
						backL := best.length
						// Load initial values
						cv = load64(src, sBack)
						// Search for mismatch
						next := lTable[hash8(load64(src, sAt), lTableBits)]
						//next := sTable[hash4(load64(src, sAt), sTableBits)]

						if checkAt := getCur(next) - backL; checkAt > 0 {
							best = bestOf(best, matchAt(checkAt, sBack, uint32(cv)))
						}
						if checkAt := getPrev(next) - backL; checkAt > 0 {
							best = bestOf(best, matchAt(checkAt, sBack, uint32(cv)))
						}
					}
				}
			}

			// Update table: shift the current position into "prev" and store s
			// as the most recent position for both hashes.
			lTable[hashL] = uint64(s) | candidateL<<32
			sTable[hashS] = uint64(s) | candidateS<<32

			if best.length > 0 {
				break
			}

			cv = load64(src, nextS)
			s = nextS
		}

		// Extend backwards, not needed for repeats...
		// (This variant has no repeat matches, so backward extension always
		// runs; the `if true` is a vestige of the repeat-aware version.)
		s = best.s
		if true {
			for best.offset > 0 && s > nextEmit && src[best.offset-1] == src[s-1] {
				best.offset--
				best.length++
				s--
			}
		}
		// Debug-only sanity check; disabled in production builds.
		if false && best.offset >= s {
			panic(fmt.Errorf("t %d >= s %d", best.offset, s))
		}
		// Bail if we exceed the maximum size.
		if d+(s-nextEmit) > dstLimit {
			return 0
		}

		base := s
		offset := s - best.offset

		s += best.length

		if offset > 65535 && s-base <= 5 {
			// Bail if the match is equal or worse to the encoding.
			// A 4-byte-offset copy costs 5 bytes, so a <=5 byte match saves nothing.
			s = best.s + 1
			if s >= sLimit {
				goto emitRemainder
			}
			cv = load64(src, s)
			continue
		}
		d += emitLiteral(dst[d:], src[nextEmit:base])
		d += emitCopyNoRepeat(dst[d:], offset, best.length)
		// Track the offset purely for candidate generation above; no repeat
		// codes are ever written in Snappy-compatible output.
		repeat = offset

		nextEmit = s
		if s >= sLimit {
			goto emitRemainder
		}

		if d > dstLimit {
			// Do we have space for more, if not bail.
			return 0
		}
		// Fill tables with every position covered by the match, so future
		// candidates can point into it.
		for i := best.s + 1; i < s; i++ {
			cv0 := load64(src, i)
			long0 := hash8(cv0, lTableBits)
			short0 := hash4(cv0, sTableBits)
			lTable[long0] = uint64(i) | lTable[long0]<<32
			sTable[short0] = uint64(i) | sTable[short0]<<32
		}
		cv = load64(src, s)
	}

emitRemainder:
	if nextEmit < len(src) {
		// Bail if we exceed the maximum size.
		if d+len(src)-nextEmit > dstLimit {
			return 0
		}
		d += emitLiteral(dst[d:], src[nextEmit:])
	}
	return d
}
545
546// emitCopySize returns the size to encode the offset+length
547//
548// It assumes that:
549//	1 <= offset && offset <= math.MaxUint32
550//	4 <= length && length <= 1 << 24
551func emitCopySize(offset, length int) int {
552	if offset >= 65536 {
553		i := 0
554		if length > 64 {
555			length -= 64
556			if length >= 4 {
557				// Emit remaining as repeats
558				return 5 + emitRepeatSize(offset, length)
559			}
560			i = 5
561		}
562		if length == 0 {
563			return i
564		}
565		return i + 5
566	}
567
568	// Offset no more than 2 bytes.
569	if length > 64 {
570		// Emit remaining as repeats, at least 4 bytes remain.
571		return 3 + emitRepeatSize(offset, length-60)
572	}
573	if length >= 12 || offset >= 2048 {
574		return 3
575	}
576	// Emit the remaining copy, encoded as 2 bytes.
577	return 2
578}
579
// emitRepeatSize returns the number of bytes required to encode a repeat
// with the given offset and length, without writing anything.
// Length must be at least 4 and < 1<<24.
func emitRepeatSize(offset, length int) int {
	// Short repeats get the cheapest encodings; near offsets extend the
	// 2-byte form a little further.
	switch {
	case length <= 8, length < 12 && offset < 2048:
		return 2
	case length < (1<<8)+8:
		return 3
	case length < (1<<16)+(1<<8)+4:
		return 4
	}
	const maxRepeat = (1 << 24) - 1
	remaining := length - ((1 << 16) - 4)
	if remaining > maxRepeat {
		// Too long for a single 5-byte code; encode the excess as another
		// repeat (recursively sized).
		return 5 + emitRepeatSize(offset, remaining-maxRepeat+4)
	}
	return 5
}
605