1// Copyright 2014-2017 Ulrich Kunitz. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package lzma
6
7import (
8	"bufio"
9	"errors"
10	"fmt"
11	"io"
12	"unicode"
13)
14
// node is one entry of the binary search tree. The links p, l and r are
// indices into the node ring buffer of binTree; the constant null stands
// for a missing link.
type node struct {
	// x is the search key, followed by the parent, left-child and
	// right-child indices.
	x, p, l, r uint32
}
26
// wordLen is the number of input bytes packed into the x field of a node,
// i.e. the width of the search key of the binary tree.
const wordLen = 4
29
30// binTree supports the identification of the next operation based on a
31// binary tree.
32//
33// Nodes will be identified by their index into the ring buffer.
34type binTree struct {
35	dict *encoderDict
36	// ring buffer of nodes
37	node []node
38	// absolute offset of the entry for the next node. Position 4
39	// byte larger.
40	hoff int64
41	// front position in the node ring buffer
42	front uint32
43	// index of the root node
44	root uint32
45	// current x value
46	x uint32
47	// preallocated array
48	data []byte
49}
50
// null is the index meaning "no node". Index zero cannot play this role
// because it is a valid slot of the node ring buffer; null is the largest
// uint32, which newBinTree never hands out as a slot.
const null = ^uint32(0)
55
56// newBinTree initializes the binTree structure. The capacity defines
57// the size of the buffer and defines the maximum distance for which
58// matches will be found.
59func newBinTree(capacity int) (t *binTree, err error) {
60	if capacity < 1 {
61		return nil, errors.New(
62			"newBinTree: capacity must be larger than zero")
63	}
64	if int64(capacity) >= int64(null) {
65		return nil, errors.New(
66			"newBinTree: capacity must less 2^{32}-1")
67	}
68	t = &binTree{
69		node: make([]node, capacity),
70		hoff: -int64(wordLen),
71		root: null,
72		data: make([]byte, maxMatchLen),
73	}
74	return t, nil
75}
76
77func (t *binTree) SetDict(d *encoderDict) { t.dict = d }
78
79// WriteByte writes a single byte into the binary tree.
80func (t *binTree) WriteByte(c byte) error {
81	t.x = (t.x << 8) | uint32(c)
82	t.hoff++
83	if t.hoff < 0 {
84		return nil
85	}
86	v := t.front
87	if int64(v) < t.hoff {
88		// We are overwriting old nodes stored in the tree.
89		t.remove(v)
90	}
91	t.node[v].x = t.x
92	t.add(v)
93	t.front++
94	if int64(t.front) >= int64(len(t.node)) {
95		t.front = 0
96	}
97	return nil
98}
99
100// Writes writes a sequence of bytes into the binTree structure.
101func (t *binTree) Write(p []byte) (n int, err error) {
102	for _, c := range p {
103		t.WriteByte(c)
104	}
105	return len(p), nil
106}
107
108// add puts the node v into the tree. The node must not be part of the
109// tree before.
110func (t *binTree) add(v uint32) {
111	vn := &t.node[v]
112	// Set left and right to null indices.
113	vn.l, vn.r = null, null
114	// If the binary tree is empty make v the root.
115	if t.root == null {
116		t.root = v
117		vn.p = null
118		return
119	}
120	x := vn.x
121	p := t.root
122	// Search for the right leave link and add the new node.
123	for {
124		pn := &t.node[p]
125		if x <= pn.x {
126			if pn.l == null {
127				pn.l = v
128				vn.p = p
129				return
130			}
131			p = pn.l
132		} else {
133			if pn.r == null {
134				pn.r = v
135				vn.p = p
136				return
137			}
138			p = pn.r
139		}
140	}
141}
142
143// parent returns the parent node index of v and the pointer to v value
144// in the parent.
145func (t *binTree) parent(v uint32) (p uint32, ptr *uint32) {
146	if t.root == v {
147		return null, &t.root
148	}
149	p = t.node[v].p
150	if t.node[p].l == v {
151		ptr = &t.node[p].l
152	} else {
153		ptr = &t.node[p].r
154	}
155	return
156}
157
158// Remove node v.
159func (t *binTree) remove(v uint32) {
160	vn := &t.node[v]
161	p, ptr := t.parent(v)
162	l, r := vn.l, vn.r
163	if l == null {
164		// Move the right child up.
165		*ptr = r
166		if r != null {
167			t.node[r].p = p
168		}
169		return
170	}
171	if r == null {
172		// Move the left child up.
173		*ptr = l
174		t.node[l].p = p
175		return
176	}
177
178	// Search the in-order predecessor u.
179	un := &t.node[l]
180	ur := un.r
181	if ur == null {
182		// In order predecessor is l. Move it up.
183		un.r = r
184		t.node[r].p = l
185		un.p = p
186		*ptr = l
187		return
188	}
189	var u uint32
190	for {
191		// Look for the max value in the tree where l is root.
192		u = ur
193		ur = t.node[u].r
194		if ur == null {
195			break
196		}
197	}
198	// replace u with ul
199	un = &t.node[u]
200	ul := un.l
201	up := un.p
202	t.node[up].r = ul
203	if ul != null {
204		t.node[ul].p = up
205	}
206
207	// replace v by u
208	un.l, un.r = l, r
209	t.node[l].p = u
210	t.node[r].p = u
211	*ptr = u
212	un.p = p
213}
214
215// search looks for the node that have the value x or for the nodes that
216// brace it. The node highest in the tree with the value x will be
217// returned. All other nodes with the same value live in left subtree of
218// the returned node.
219func (t *binTree) search(v uint32, x uint32) (a, b uint32) {
220	a, b = null, null
221	if v == null {
222		return
223	}
224	for {
225		vn := &t.node[v]
226		if x <= vn.x {
227			if x == vn.x {
228				return v, v
229			}
230			b = v
231			if vn.l == null {
232				return
233			}
234			v = vn.l
235		} else {
236			a = v
237			if vn.r == null {
238				return
239			}
240			v = vn.r
241		}
242	}
243}
244
245// max returns the node with maximum value in the subtree with v as
246// root.
247func (t *binTree) max(v uint32) uint32 {
248	if v == null {
249		return null
250	}
251	for {
252		r := t.node[v].r
253		if r == null {
254			return v
255		}
256		v = r
257	}
258}
259
260// min returns the node with the minimum value in the subtree with v as
261// root.
262func (t *binTree) min(v uint32) uint32 {
263	if v == null {
264		return null
265	}
266	for {
267		l := t.node[v].l
268		if l == null {
269			return v
270		}
271		v = l
272	}
273}
274
275// pred returns the in-order predecessor of node v.
276func (t *binTree) pred(v uint32) uint32 {
277	if v == null {
278		return null
279	}
280	u := t.max(t.node[v].l)
281	if u != null {
282		return u
283	}
284	for {
285		p := t.node[v].p
286		if p == null {
287			return null
288		}
289		if t.node[p].r == v {
290			return p
291		}
292		v = p
293	}
294}
295
296// succ returns the in-order successor of node v.
297func (t *binTree) succ(v uint32) uint32 {
298	if v == null {
299		return null
300	}
301	u := t.min(t.node[v].r)
302	if u != null {
303		return u
304	}
305	for {
306		p := t.node[v].p
307		if p == null {
308			return null
309		}
310		if t.node[p].l == v {
311			return p
312		}
313		v = p
314	}
315}
316
// xval packs up to the first four bytes of a into a uint32 in big-endian
// order. Missing bytes of a shorter slice count as zero in the low-order
// positions; bytes beyond the fourth are ignored.
func xval(a []byte) uint32 {
	n := len(a)
	if n > 4 {
		n = 4
	}
	var x uint32
	for i, c := range a[:n] {
		x |= uint32(c) << uint(24-8*i)
	}
	return x
}
337
// dumpX renders the four bytes of x, most significant first, as a
// four-character string for debug output. Printable ASCII bytes are kept
// and every other byte is replaced by '.', so the result is always valid
// UTF-8 of exactly four bytes.
func dumpX(x uint32) string {
	var a [4]byte
	for i := range a {
		c := byte(x >> uint((3-i)*8))
		// Only bytes below 0x80 are complete runes on their own. Higher
		// bytes (e.g. 0xE9, which IsGraphic accepts as 'é') would be
		// copied raw and make the string invalid UTF-8.
		if c < 0x80 && unicode.IsGraphic(rune(c)) {
			a[i] = c
		} else {
			a[i] = '.'
		}
	}
	return string(a[:])
}
351
352// dumpNode writes a representation of the node v into the io.Writer.
353func (t *binTree) dumpNode(w io.Writer, v uint32, indent int) {
354	if v == null {
355		return
356	}
357
358	vn := &t.node[v]
359
360	t.dumpNode(w, vn.r, indent+2)
361
362	for i := 0; i < indent; i++ {
363		fmt.Fprint(w, " ")
364	}
365	if vn.p == null {
366		fmt.Fprintf(w, "node %d %q parent null\n", v, dumpX(vn.x))
367	} else {
368		fmt.Fprintf(w, "node %d %q parent %d\n", v, dumpX(vn.x), vn.p)
369	}
370
371	t.dumpNode(w, vn.l, indent+2)
372}
373
374// dump prints a representation of the binary tree into the writer.
375func (t *binTree) dump(w io.Writer) error {
376	bw := bufio.NewWriter(w)
377	t.dumpNode(bw, t.root, 0)
378	return bw.Flush()
379}
380
381func (t *binTree) distance(v uint32) int {
382	dist := int(t.front) - int(v)
383	if dist <= 0 {
384		dist += len(t.node)
385	}
386	return dist
387}
388
// matchParams configures one run of binTree.match.
type matchParams struct {
	// rep holds the repeat distances. match only consults rep[0], to
	// decide whether a match of length one is admissible.
	rep [4]uint32
	// nAccept is the match length at which the search ends early.
	nAccept int
	// check is the maximum number of candidates to test.
	check int
	// stopShorter makes match give up instead of skipping a candidate
	// when the candidate fails the quick byte test or shares no prefix
	// with the lookahead.
	stopShorter bool
}
398
// match tests the candidate distances produced by distIter against the
// lookahead bytes in t.data (filled by NextOp) and returns the best match
// found, starting from the current best m.
//
// A candidate replaces the current best if it is longer, or equally long
// with a smaller distance. A match of length 1 is only admitted when
// dist-minDistance equals p.rep[0] (presumably an LZMA short rep —
// confirm).
//
// checked is the number of candidates taken from distIter. accepted is
// true when p.check candidates have been tested or a match of at least
// p.nAccept bytes was found; NextOp stops searching in that case. It is
// false when distIter is exhausted or, with p.stopShorter set, when a
// candidate fails the quick byte test or has a match length of zero.
// The named result r is not used; results are returned explicitly.
func (t *binTree) match(m match, distIter func() (int, bool), p matchParams,
) (r match, checked int, accepted bool) {
	buf := &t.dict.buf
	for {
		// The candidate budget is used up.
		if checked >= p.check {
			return m, checked, true
		}
		dist, ok := distIter()
		if !ok {
			return m, checked, false
		}
		checked++
		if m.n > 0 {
			// Quick rejection: the candidate can only reach the
			// length of the current best if its byte at position
			// m.n-1 matches. i is that byte's index in buf.data,
			// computed relative to buf.rear with wrap-around.
			i := buf.rear - dist + m.n - 1
			if i < 0 {
				i += len(buf.data)
			} else if i >= len(buf.data) {
				i -= len(buf.data)
			}
			if buf.data[i] != t.data[m.n-1] {
				if p.stopShorter {
					return m, checked, false
				}
				continue
			}
		}
		n := buf.matchLen(dist, t.data)
		switch n {
		case 0:
			if p.stopShorter {
				return m, checked, false
			}
			continue
		case 1:
			// Single-byte matches are only useful at the first
			// repeat distance.
			if uint32(dist-minDistance) != p.rep[0] {
				continue
			}
		}
		// Keep the longer match; on equal length keep the closer one.
		if n < m.n || (n == m.n && int64(dist) >= m.distance) {
			continue
		}
		m = match{int64(dist), n}
		if n >= p.nAccept {
			return m, checked, true
		}
	}
}
446
// NextOp chooses the next operation for the encoder. It returns the best
// match found for the lookahead peeked from the dictionary buffer, or a
// literal for the first lookahead byte if no match was found. rep is
// passed on to match (see matchParams). It panics if the buffer yields no
// data.
//
// The search shares a budget of 32 checked candidates across its phases:
//  1. the distances 3, 2 and 1 are tried directly;
//  2. if the tree holds a node whose key equals the first four lookahead
//     bytes and the lookahead is exactly four bytes long, only the nodes
//     with that key are tried;
//  3. otherwise the successors and then the predecessors of the search
//     result are tried, with stopShorter set.
func (t *binTree) NextOp(rep [4]uint32) operation {
	// retrieve maxMatchLen data
	// t.data keeps capacity maxMatchLen, so reslicing it back to full
	// length is valid even after the truncation below.
	n, _ := t.dict.buf.Peek(t.data[:maxMatchLen])
	if n == 0 {
		panic("no data in buffer")
	}
	t.data = t.data[:n]

	var (
		m                  match
		x, u, v            uint32
		iterPred, iterSucc func() (int, bool)
	)
	p := matchParams{
		rep:     rep,
		nAccept: maxMatchLen,
		check:   32,
	}
	// Phase 1: the distances 3, 2, 1.
	i := 4
	iterSmall := func() (dist int, ok bool) {
		i--
		if i <= 0 {
			return 0, false
		}
		return i, true
	}
	// m is the variable declared above (same scope); only checked and
	// accepted are new here.
	m, checked, accepted := t.match(m, iterSmall, p)
	if accepted {
		goto end
	}
	p.check -= checked
	// Phase 2: locate the key of the lookahead in the tree. u == v means
	// either an exact hit or an empty tree (both null).
	x = xval(t.data)
	u, v = t.search(t.root, x)
	if u == v && len(t.data) == 4 {
		// Walk the nodes with key x from the top down: equal keys are
		// in the left subtree, so each step searches there again and
		// ends as soon as no further equal key is found (u != v).
		iter := func() (dist int, ok bool) {
			if u == null {
				return 0, false
			}
			dist = t.distance(u)
			u, v = t.search(t.node[u].l, x)
			if u != v {
				u = null
			}
			return dist, true
		}
		m, _, _ = t.match(m, iter, p)
		goto end
	}
	// Phase 3: walk outwards from the search result in key order.
	// NOTE(review): after an exact hit (u == v) the same node is tested
	// by both iterSucc and iterPred, spending one extra check.
	p.stopShorter = true
	iterSucc = func() (dist int, ok bool) {
		if v == null {
			return 0, false
		}
		dist = t.distance(v)
		v = t.succ(v)
		return dist, true
	}
	m, checked, accepted = t.match(m, iterSucc, p)
	if accepted {
		goto end
	}
	p.check -= checked
	iterPred = func() (dist int, ok bool) {
		if u == null {
			return 0, false
		}
		dist = t.distance(u)
		u = t.pred(u)
		return dist, true
	}
	m, _, _ = t.match(m, iterPred, p)
end:
	if m.n == 0 {
		return lit{t.data[0]}
	}
	return m
}
524