1package critbitgo
2
import (
	"bytes"
	"encoding/hex"
	"fmt"
	"io"
	"os"
	"strconv"
	"unicode/utf8"
)
11
// msbMatrix maps every byte value to a mask containing only its most
// significant set bit; msbMatrix[0] is 0. It is filled by buildMsbMatrix.
var msbMatrix [256]byte

// buildMsbMatrix populates msbMatrix. Scanning upwards, the most significant
// bit only changes at powers of two, so each entry reuses the most recent
// power of two seen so far.
func buildMsbMatrix() {
	msbMatrix[0] = 0
	var top byte
	for v := 1; v < len(msbMatrix); v++ {
		if v&(v-1) == 0 { // v is a power of two
			top = byte(v)
		}
		msbMatrix[v] = top
	}
}
24
type (
	// node is a slot in the tree. Insert and Delete keep at most one of
	// internal/external set; both are nil only in the root of an empty trie.
	node struct {
		internal *internal
		external *external
	}

	// internal is a branch that routes keys by one bit of one byte.
	internal struct {
		child  [2]node
		offset int  // index of the byte this branch inspects
		bit    byte // mask of the critical bit in that byte (may be 0 when cont is set)
		cont   bool // if true, key of child[1] contains key of child[0]
	}

	// external is a leaf holding a stored key and its associated value.
	external struct {
		key   []byte
		value interface{}
	}
)
41
42// finding the critical bit.
43func (n *external) criticalBit(key []byte) (offset int, bit byte, cont bool) {
44	nlen := len(n.key)
45	klen := len(key)
46	mlen := nlen
47	if nlen > klen {
48		mlen = klen
49	}
50
51	// find first differing byte and bit
52	for offset = 0; offset < mlen; offset++ {
53		if a, b := key[offset], n.key[offset]; a != b {
54			bit = msbMatrix[a^b]
55			return
56		}
57	}
58
59	if nlen < klen {
60		bit = msbMatrix[key[offset]]
61	} else if nlen > klen {
62		bit = msbMatrix[n.key[offset]]
63	} else {
64		// two keys are equal
65		offset = -1
66	}
67	return offset, bit, true
68}
69
70// calculate direction.
71func (n *internal) direction(key []byte) int {
72	if n.offset < len(key) && (key[n.offset]&n.bit != 0 || n.cont) {
73		return 1
74	}
75	return 0
76}
77
78// Crit-bit Tree
79type Trie struct {
80	root node
81	size int
82}
83
84// searching the tree.
85func (t *Trie) search(key []byte) *node {
86	n := &t.root
87	for n.internal != nil {
88		n = &n.internal.child[n.internal.direction(key)]
89	}
90	return n
91}
92
93// membership testing.
94func (t *Trie) Contains(key []byte) bool {
95	if n := t.search(key); n.external != nil && bytes.Equal(n.external.key, key) {
96		return true
97	}
98	return false
99}
100
101// get member.
102// if `key` is in Trie, `ok` is true.
103func (t *Trie) Get(key []byte) (value interface{}, ok bool) {
104	if n := t.search(key); n.external != nil && bytes.Equal(n.external.key, key) {
105		return n.external.value, true
106	}
107	return
108}
109
110// insert into the tree (replaceable).
111func (t *Trie) insert(key []byte, value interface{}, replace bool) bool {
112	// an empty tree
113	if t.size == 0 {
114		t.root.external = &external{
115			key:   key,
116			value: value,
117		}
118		t.size = 1
119		return true
120	}
121
122	n := t.search(key)
123	newOffset, newBit, newCont := n.external.criticalBit(key)
124
125	// already exists in the tree
126	if newOffset == -1 {
127		if replace {
128			n.external.value = value
129			return true
130		}
131		return false
132	}
133
134	// allocate new node
135	newNode := &internal{
136		offset: newOffset,
137		bit:    newBit,
138		cont:   newCont,
139	}
140	direction := newNode.direction(key)
141	newNode.child[direction].external = &external{
142		key:   key,
143		value: value,
144	}
145
146	// insert new node
147	wherep := &t.root
148	for in := wherep.internal; in != nil; in = wherep.internal {
149		if in.offset > newOffset || (in.offset == newOffset && in.bit < newBit) {
150			break
151		}
152		wherep = &in.child[in.direction(key)]
153	}
154
155	if wherep.internal != nil {
156		newNode.child[1-direction].internal = wherep.internal
157	} else {
158		newNode.child[1-direction].external = wherep.external
159		wherep.external = nil
160	}
161	wherep.internal = newNode
162	t.size += 1
163	return true
164}
165
166// insert into the tree.
167// if `key` is alredy in Trie, return false.
168func (t *Trie) Insert(key []byte, value interface{}) bool {
169	return t.insert(key, value, false)
170}
171
172// set into the tree.
173func (t *Trie) Set(key []byte, value interface{}) {
174	t.insert(key, value, true)
175}
176
177// deleting elements.
178// if `key` is in Trie, `ok` is true.
179func (t *Trie) Delete(key []byte) (value interface{}, ok bool) {
180	// an empty tree
181	if t.size == 0 {
182		return
183	}
184
185	var direction int
186	var whereq *node // pointer to the grandparent
187	var wherep *node = &t.root
188
189	// finding the best candidate to delete
190	for in := wherep.internal; in != nil; in = wherep.internal {
191		direction = in.direction(key)
192		whereq = wherep
193		wherep = &in.child[direction]
194	}
195
196	// checking that we have the right element
197	if !bytes.Equal(wherep.external.key, key) {
198		return
199	}
200	value = wherep.external.value
201	ok = true
202
203	// removing the node
204	if whereq == nil {
205		wherep.external = nil
206	} else {
207		othern := whereq.internal.child[1-direction]
208		whereq.internal = othern.internal
209		whereq.external = othern.external
210	}
211	t.size -= 1
212	return
213}
214
215// clearing a tree.
216func (t *Trie) Clear() {
217	t.root.internal = nil
218	t.root.external = nil
219	t.size = 0
220}
221
222// return the number of key in a tree.
223func (t *Trie) Size() int {
224	return t.size
225}
226
227// fetching elements with a given prefix.
228// handle is called with arguments key and value (if handle returns `false`, the iteration is aborted)
229func (t *Trie) Allprefixed(prefix []byte, handle func(key []byte, value interface{}) bool) bool {
230	// an empty tree
231	if t.size == 0 {
232		return true
233	}
234
235	// walk tree, maintaining top pointer
236	p := &t.root
237	top := p
238	if len(prefix) > 0 {
239		for q := p.internal; q != nil; q = p.internal {
240			p = &q.child[q.direction(prefix)]
241			if q.offset < len(prefix) {
242				top = p
243			}
244		}
245
246		// check prefix
247		if !bytes.HasPrefix(p.external.key, prefix) {
248			return true
249		}
250	}
251
252	return allprefixed(top, handle)
253}
254
255func allprefixed(n *node, handle func([]byte, interface{}) bool) bool {
256	if n.internal != nil {
257		// dealing with an internal node while recursing
258		for i := 0; i < 2; i++ {
259			if !allprefixed(&n.internal.child[i], handle) {
260				return false
261			}
262		}
263	} else {
264		// dealing with an external node while recursing
265		return handle(n.external.key, n.external.value)
266	}
267	return true
268}
269
270// Search for the longest matching key from the beginning of the given key.
271// if `key` is in Trie, `ok` is true.
272func (t *Trie) LongestPrefix(given []byte) (key []byte, value interface{}, ok bool) {
273	// an empty tree
274	if t.size == 0 {
275		return
276	}
277	return longestPrefix(&t.root, given)
278}
279
280func longestPrefix(n *node, key []byte) ([]byte, interface{}, bool) {
281	if n.internal != nil {
282		direction := n.internal.direction(key)
283		if k, v, ok := longestPrefix(&n.internal.child[direction], key); ok {
284			return k, v, ok
285		}
286		if direction == 1 {
287			return longestPrefix(&n.internal.child[0], key)
288		}
289	} else {
290		if bytes.HasPrefix(key, n.external.key) {
291			return n.external.key, n.external.value, true
292		}
293	}
294	return nil, nil, false
295}
296
297// Iterating elements from a given start key.
298// handle is called with arguments key and value (if handle returns `false`, the iteration is aborted)
299func (t *Trie) Walk(start []byte, handle func(key []byte, value interface{}) bool) bool {
300	if t.size == 0 {
301		return true
302	}
303	var seek bool
304	if start != nil {
305		seek = true
306	}
307	return walk(&t.root, start, &seek, handle)
308}
309
310func walk(n *node, key []byte, seek *bool, handle func([]byte, interface{}) bool) bool {
311	if n.internal != nil {
312		var direction int
313		if *seek {
314			direction = n.internal.direction(key)
315		}
316		if !walk(&n.internal.child[direction], key, seek, handle) {
317			return false
318		}
319		if !(*seek) && direction == 0 {
320			// iteration another side
321			return walk(&n.internal.child[1], key, seek, handle)
322		}
323		return true
324	} else {
325		if *seek {
326			if bytes.Equal(n.external.key, key) {
327				// seek completed
328				*seek = false
329			} else {
330				// key is not in Trie
331				return false
332			}
333		}
334		return handle(n.external.key, n.external.value)
335	}
336}
337
338// dump tree. (for debugging)
339func (t *Trie) Dump(w io.Writer) {
340	if t.root.internal == nil && t.root.external == nil {
341		return
342	}
343	if w == nil {
344		w = os.Stdout
345	}
346	dump(w, &t.root, true, "")
347}
348
349func dump(w io.Writer, n *node, right bool, prefix string) {
350	var ownprefix string
351	if right {
352		ownprefix = prefix
353	} else {
354		ownprefix = prefix[:len(prefix)-1] + "`"
355	}
356
357	if in := n.internal; in != nil {
358		fmt.Fprintf(w, "%s-- off=%d, bit=%08b(%02x), cont=%v\n", ownprefix, in.offset, in.bit, in.bit, in.cont)
359		for i := 0; i < 2; i++ {
360			var nextprefix string
361			switch i {
362			case 0:
363				nextprefix = prefix + " |"
364				right = true
365			case 1:
366				nextprefix = prefix + "  "
367				right = false
368			}
369			dump(w, &in.child[i], right, nextprefix)
370		}
371	} else {
372		fmt.Fprintf(w, "%s-- key=%d (%s)\n", ownprefix, n.external.key, key2str(n.external.key))
373	}
374	return
375}
376
// key2str renders key for Dump output.
//
// A key that is valid UTF-8 consisting only of printable runes (per
// strconv.IsPrint) is returned as text. Any other key is returned
// hex-encoded.
func key2str(key []byte) string {
	// Decode as UTF-8 instead of testing each byte as if it were a rune. The
	// byte-wise check misjudged multi-byte text: "日本" contains the C1
	// control byte 0x97 and was hex-dumped. It also accepted lone high bytes
	// such as 0xE9, which then printed as invalid UTF-8.
	if !utf8.Valid(key) {
		return hex.EncodeToString(key)
	}
	for _, r := range string(key) {
		if !strconv.IsPrint(r) {
			return hex.EncodeToString(key)
		}
	}
	return string(key)
}
385
386// create a tree.
387func NewTrie() *Trie {
388	return &Trie{}
389}
390
391func init() {
392	buildMsbMatrix()
393}
394