1// Copyright 2014-2019 Ulrich Kunitz. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package lzma
6
7import (
8	"errors"
9	"fmt"
10
11	"github.com/ulikunitz/xz/internal/hash"
12)
13
14/* For compression we need to find byte sequences that match the byte
15 * sequence at the dictionary head. A hash table is a simple method to
16 * provide this capability.
17 */
18
// Tuning constants for the hash table match finder.
const (
	// maxMatches limits the number of matches requested from the Matches
	// function. This controls the speed of the overall encoding.
	maxMatches = 16

	// shortDists is the number of short distances supported by the
	// implementation; NextOp always tries the distances 1 through 8 in
	// addition to the distances found via the hash table.
	shortDists = 8

	// minTableExponent and maxTableExponent bound the base-2 exponent of
	// the hash table size. The minimum is somewhat arbitrary, but the
	// maximum is limited by the memory requirements of the hash table.
	minTableExponent = 9
	maxTableExponent = 20
)
33
34// newRoller contains the function used to create an instance of the
35// hash.Roller.
36var newRoller = func(n int) hash.Roller { return hash.NewCyclicPoly(n) }
37
// hashTable stores the hash table including the rolling hash method.
//
// We implement chained hashing into a circular buffer. Each entry in
// the circular buffer stores the delta distance to the next position with a
// word that has the same hash value. The "next" position is the previous
// (older) occurrence, so chains are walked backwards through the input; a
// delta of 0 ends a chain.
type hashTable struct {
	// dictionary consulted by NextOp; attached with SetDict
	dict *encoderDict
	// actual hash table; a slot holds the most recent position whose
	// masked hash selects it, plus one, so that 0 marks an empty slot
	t []int64
	// circular list data with the offset to the next word; front is the
	// index in data where the next delta will be written
	data  []uint32
	front int
	// mask for computing the index for the hash table
	mask uint64
	// hash offset; initial value is -int64(wordLen). It is the position
	// assigned to the most recently written byte and stays negative until
	// wordLen bytes have been written.
	hoff int64
	// length of the hashed word
	wordLen int
	// hash roller for computing the hash values for the Write
	// method
	wr hash.Roller
	// hash roller for computing arbitrary hashes (used by Matches)
	hr hash.Roller
	// preallocated slices used by NextOp to avoid allocations; distances
	// has room for the fixed short distances plus maxMatches hash matches
	p         [maxMatches]int64
	distances [maxMatches + shortDists]int
}
65
66// hashTableExponent derives the hash table exponent from the dictionary
67// capacity.
68func hashTableExponent(n uint32) int {
69	e := 30 - nlz32(n)
70	switch {
71	case e < minTableExponent:
72		e = minTableExponent
73	case e > maxTableExponent:
74		e = maxTableExponent
75	}
76	return e
77}
78
79// newHashTable creates a new hash table for words of length wordLen
80func newHashTable(capacity int, wordLen int) (t *hashTable, err error) {
81	if !(0 < capacity) {
82		return nil, errors.New(
83			"newHashTable: capacity must not be negative")
84	}
85	exp := hashTableExponent(uint32(capacity))
86	if !(1 <= wordLen && wordLen <= 4) {
87		return nil, errors.New("newHashTable: " +
88			"argument wordLen out of range")
89	}
90	n := 1 << uint(exp)
91	if n <= 0 {
92		panic("newHashTable: exponent is too large")
93	}
94	t = &hashTable{
95		t:       make([]int64, n),
96		data:    make([]uint32, capacity),
97		mask:    (uint64(1) << uint(exp)) - 1,
98		hoff:    -int64(wordLen),
99		wordLen: wordLen,
100		wr:      newRoller(wordLen),
101		hr:      newRoller(wordLen),
102	}
103	return t, nil
104}
105
106func (t *hashTable) SetDict(d *encoderDict) { t.dict = d }
107
108// buffered returns the number of bytes that are currently hashed.
109func (t *hashTable) buffered() int {
110	n := t.hoff + 1
111	switch {
112	case n <= 0:
113		return 0
114	case n >= int64(len(t.data)):
115		return len(t.data)
116	}
117	return int(n)
118}
119
120// addIndex adds n to an index ensuring that is stays inside the
121// circular buffer for the hash chain.
122func (t *hashTable) addIndex(i, n int) int {
123	i += n - len(t.data)
124	if i < 0 {
125		i += len(t.data)
126	}
127	return i
128}
129
130// putDelta puts the delta instance at the current front of the circular
131// chain buffer.
132func (t *hashTable) putDelta(delta uint32) {
133	t.data[t.front] = delta
134	t.front = t.addIndex(t.front, 1)
135}
136
137// putEntry puts a new entry into the hash table. If there is already a
138// value stored it is moved into the circular chain buffer.
139func (t *hashTable) putEntry(h uint64, pos int64) {
140	if pos < 0 {
141		return
142	}
143	i := h & t.mask
144	old := t.t[i] - 1
145	t.t[i] = pos + 1
146	var delta int64
147	if old >= 0 {
148		delta = pos - old
149		if delta > 1<<32-1 || delta > int64(t.buffered()) {
150			delta = 0
151		}
152	}
153	t.putDelta(uint32(delta))
154}
155
156// WriteByte converts a single byte into a hash and puts them into the hash
157// table.
158func (t *hashTable) WriteByte(b byte) error {
159	h := t.wr.RollByte(b)
160	t.hoff++
161	t.putEntry(h, t.hoff)
162	return nil
163}
164
165// Write converts the bytes provided into hash tables and stores the
166// abbreviated offsets into the hash table. The method will never return an
167// error.
168func (t *hashTable) Write(p []byte) (n int, err error) {
169	for _, b := range p {
170		// WriteByte doesn't generate an error.
171		t.WriteByte(b)
172	}
173	return len(p), nil
174}
175
// getMatches collects the positions stored for the hash h. It starts with
// the most recent position held in the hash table slot for h and follows
// the delta chain in the circular buffer backwards, writing the positions
// into positions, most recent first. It stops when positions is full, when
// a delta of 0 ends the chain, or when the next position has already
// dropped out of the circular buffer. The function returns the number of
// positions found.
//
// Because the slot is selected by h&t.mask, different hashes can share a
// chain; the returned positions are only candidates that callers have to
// verify against the actual bytes.
//
// TODO: Make a getDistances, because distances are what we are actually
// interested in.
func (t *hashTable) getMatches(h uint64, positions []int64) (n int) {
	// Nothing has been hashed yet, or there is no room for results.
	if t.hoff < 0 || len(positions) == 0 {
		return 0
	}
	buffered := t.buffered()
	// tailPos is the oldest position still covered by the circular buffer.
	tailPos := t.hoff + 1 - int64(buffered)
	// rear is the data index of tailPos, made negative so that
	// rear+delta below needs at most one wrap-around correction.
	rear := t.front - buffered
	if rear >= 0 {
		rear -= len(t.data)
	}
	// get the slot for the hash; an empty slot yields pos == -1
	pos := t.t[h&t.mask] - 1
	// delta is the current position relative to tailPos; a negative value
	// means the position is not (or no longer) in the buffer.
	delta := pos - tailPos
	for {
		if delta < 0 {
			return n
		}
		positions[n] = tailPos + delta
		n++
		if n >= len(positions) {
			return n
		}
		// The chain entry of the current position holds the distance to
		// the previous position with the same table slot.
		i := rear + int(delta)
		if i < 0 {
			i += len(t.data)
		}
		u := t.data[i]
		// A delta of 0 terminates the chain.
		if u == 0 {
			return n
		}
		delta -= int64(u)
	}
}
213
214// hash computes the rolling hash for the word stored in p. For correct
215// results its length must be equal to t.wordLen.
216func (t *hashTable) hash(p []byte) uint64 {
217	var h uint64
218	for _, b := range p {
219		h = t.hr.RollByte(b)
220	}
221	return h
222}
223
224// Matches fills the positions slice with potential matches. The
225// functions returns the number of positions filled into positions. The
226// byte slice p must have word length of the hash table.
227func (t *hashTable) Matches(p []byte, positions []int64) int {
228	if len(p) != t.wordLen {
229		panic(fmt.Errorf(
230			"byte slice must have length %d", t.wordLen))
231	}
232	h := t.hash(p)
233	return t.getMatches(h, positions)
234}
235
// NextOp identifies the next operation using the hash table.
//
// It peeks up to maxMatchLen lookahead bytes from the dictionary buffer.
// It then builds a list of candidate distances: the short distances 1 to 8,
// followed by the distances of the positions that Matches returns for the
// first word of the lookahead. The first longest match among those
// candidates is returned. A match of length 1 is only accepted when
// dist-minDistance equals rep[0]. If no candidate qualifies, a literal for
// the first lookahead byte is returned.
//
// NOTE(review): data[0] is read unconditionally (in the distance loop and
// for the literal), so this presumably must only be called while the
// buffer holds at least one byte; confirm with callers.
//
// TODO: Use all repetitions to find matches.
func (t *hashTable) NextOp(rep [4]uint32) operation {
	// get positions
	// NOTE(review): the error from Peek is discarded; only the count n of
	// peeked bytes is used.
	data := t.dict.data[:maxMatchLen]
	n, _ := t.dict.buf.Peek(data)
	data = data[:n]
	var p []int64
	if n < t.wordLen {
		// Too few lookahead bytes to form a word: only the short
		// distances are tried.
		p = t.p[:0]
	} else {
		p = t.p[:maxMatches]
		n = t.Matches(data[:t.wordLen], p)
		p = p[:n]
	}

	// convert positions into potential distances; the short distances are
	// always tried and are not added a second time. t.distances has room
	// for all of them, so the appends do not allocate.
	head := t.dict.head
	dists := append(t.distances[:0], 1, 2, 3, 4, 5, 6, 7, 8)
	for _, pos := range p {
		dis := int(head - pos)
		if dis > shortDists {
			dists = append(dists, dis)
		}
	}

	// check distances
	var m match
	dictLen := t.dict.DictLen()
	for _, dist := range dists {
		// The distance must not reach beyond the dictionary content.
		if dist > dictLen {
			continue
		}

		// Here comes a trick. We are only interested in matches
		// that are longer than the matches we have found
		// before. So before we test the whole byte sequence at
		// the given distance, we test the first byte that would
		// make the match longer. If it doesn't match the byte
		// to match, we don't need to care any longer.
		i := t.dict.buf.rear - dist + m.n
		if i < 0 {
			i += len(t.dict.buf.data)
		}
		if t.dict.buf.data[i] != data[m.n] {
			// We can't get a longer match. Jump to the next
			// distance.
			continue
		}

		n := t.dict.buf.matchLen(dist, data)
		switch n {
		case 0:
			continue
		case 1:
			// Length-1 matches are only accepted at the distance
			// given by rep[0] (presumably because only a short
			// repetition can encode them cheaply).
			if uint32(dist-minDistance) != rep[0] {
				continue
			}
		}
		if n > m.n {
			m = match{int64(dist), n}
			if n == len(data) {
				// No better match will be found. This break leaves
				// the for loop; the switch above has already ended.
				break
			}
		}
	}

	// No usable match: emit the first lookahead byte as a literal.
	if m.n == 0 {
		return lit{data[0]}
	}
	return m
}
310