1package iradix
2
3import (
4	"bytes"
5	"sort"
6)
7
// WalkFn is the callback invoked for each entry visited while walking the
// tree. It receives the entry's key and value; returning true terminates
// the walk, returning false lets it continue.
type WalkFn func(k []byte, v interface{}) bool
12
// leafNode is used to represent a value stored in the tree.
//
// mutateCh is the channel GetWatch hands back when a lookup ends exactly on
// this leaf. key is the key returned to callers by LongestPrefix, Minimum,
// Maximum and the walk functions, and val is the value stored alongside it.
type leafNode struct {
	mutateCh chan struct{}
	key      []byte
	val      interface{}
}
19
// edge is used to represent an edge node: a link from a parent to one child.
//
// label is the byte the parent's edge list is searched by; lookups follow
// the edge whose label equals the first byte of the remaining search key.
// node is the child the edge leads to.
type edge struct {
	label byte
	node  *Node
}
25
// Node is an immutable node in the radix tree
type Node struct {
	// mutateCh is closed if this node is modified. GetWatch returns it as
	// the watch channel when a lookup stops at this node without finding
	// an exact match.
	mutateCh chan struct{}

	// leaf holds the key/value entry stored at this node, or nil when the
	// node stores none (see isLeaf).
	leaf *leafNode

	// prefix is the common prefix we ignore: after following an edge into
	// this node, lookups strip prefix from the front of the remaining
	// search key before continuing.
	prefix []byte

	// Edges should be stored in-order for iteration; the edge helpers keep
	// them sorted by label and locate entries with a binary search.
	// We avoid a fully materialized slice to save memory,
	// since in most cases we expect to be sparse
	edges edges
}
42
43func (n *Node) isLeaf() bool {
44	return n.leaf != nil
45}
46
47func (n *Node) addEdge(e edge) {
48	num := len(n.edges)
49	idx := sort.Search(num, func(i int) bool {
50		return n.edges[i].label >= e.label
51	})
52	n.edges = append(n.edges, e)
53	if idx != num {
54		copy(n.edges[idx+1:], n.edges[idx:num])
55		n.edges[idx] = e
56	}
57}
58
59func (n *Node) replaceEdge(e edge) {
60	num := len(n.edges)
61	idx := sort.Search(num, func(i int) bool {
62		return n.edges[i].label >= e.label
63	})
64	if idx < num && n.edges[idx].label == e.label {
65		n.edges[idx].node = e.node
66		return
67	}
68	panic("replacing missing edge")
69}
70
71func (n *Node) getEdge(label byte) (int, *Node) {
72	num := len(n.edges)
73	idx := sort.Search(num, func(i int) bool {
74		return n.edges[i].label >= label
75	})
76	if idx < num && n.edges[idx].label == label {
77		return idx, n.edges[idx].node
78	}
79	return -1, nil
80}
81
82func (n *Node) getLowerBoundEdge(label byte) (int, *Node) {
83	num := len(n.edges)
84	idx := sort.Search(num, func(i int) bool {
85		return n.edges[i].label >= label
86	})
87	// we want lower bound behavior so return even if it's not an exact match
88	if idx < num {
89		return idx, n.edges[idx].node
90	}
91	return -1, nil
92}
93
94func (n *Node) delEdge(label byte) {
95	num := len(n.edges)
96	idx := sort.Search(num, func(i int) bool {
97		return n.edges[i].label >= label
98	})
99	if idx < num && n.edges[idx].label == label {
100		copy(n.edges[idx:], n.edges[idx+1:])
101		n.edges[len(n.edges)-1] = edge{}
102		n.edges = n.edges[:len(n.edges)-1]
103	}
104}
105
106func (n *Node) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) {
107	search := k
108	watch := n.mutateCh
109	for {
110		// Check for key exhaustion
111		if len(search) == 0 {
112			if n.isLeaf() {
113				return n.leaf.mutateCh, n.leaf.val, true
114			}
115			break
116		}
117
118		// Look for an edge
119		_, n = n.getEdge(search[0])
120		if n == nil {
121			break
122		}
123
124		// Update to the finest granularity as the search makes progress
125		watch = n.mutateCh
126
127		// Consume the search prefix
128		if bytes.HasPrefix(search, n.prefix) {
129			search = search[len(n.prefix):]
130		} else {
131			break
132		}
133	}
134	return watch, nil, false
135}
136
137func (n *Node) Get(k []byte) (interface{}, bool) {
138	_, val, ok := n.GetWatch(k)
139	return val, ok
140}
141
142// LongestPrefix is like Get, but instead of an
143// exact match, it will return the longest prefix match.
144func (n *Node) LongestPrefix(k []byte) ([]byte, interface{}, bool) {
145	var last *leafNode
146	search := k
147	for {
148		// Look for a leaf node
149		if n.isLeaf() {
150			last = n.leaf
151		}
152
153		// Check for key exhaution
154		if len(search) == 0 {
155			break
156		}
157
158		// Look for an edge
159		_, n = n.getEdge(search[0])
160		if n == nil {
161			break
162		}
163
164		// Consume the search prefix
165		if bytes.HasPrefix(search, n.prefix) {
166			search = search[len(n.prefix):]
167		} else {
168			break
169		}
170	}
171	if last != nil {
172		return last.key, last.val, true
173	}
174	return nil, nil, false
175}
176
177// Minimum is used to return the minimum value in the tree
178func (n *Node) Minimum() ([]byte, interface{}, bool) {
179	for {
180		if n.isLeaf() {
181			return n.leaf.key, n.leaf.val, true
182		}
183		if len(n.edges) > 0 {
184			n = n.edges[0].node
185		} else {
186			break
187		}
188	}
189	return nil, nil, false
190}
191
192// Maximum is used to return the maximum value in the tree
193func (n *Node) Maximum() ([]byte, interface{}, bool) {
194	for {
195		if num := len(n.edges); num > 0 {
196			n = n.edges[num-1].node
197			continue
198		}
199		if n.isLeaf() {
200			return n.leaf.key, n.leaf.val, true
201		} else {
202			break
203		}
204	}
205	return nil, nil, false
206}
207
208// Iterator is used to return an iterator at
209// the given node to walk the tree
210func (n *Node) Iterator() *Iterator {
211	return &Iterator{node: n}
212}
213
214// rawIterator is used to return a raw iterator at the given node to walk the
215// tree.
216func (n *Node) rawIterator() *rawIterator {
217	iter := &rawIterator{node: n}
218	iter.Next()
219	return iter
220}
221
222// Walk is used to walk the tree
223func (n *Node) Walk(fn WalkFn) {
224	recursiveWalk(n, fn)
225}
226
227// WalkPrefix is used to walk the tree under a prefix
228func (n *Node) WalkPrefix(prefix []byte, fn WalkFn) {
229	search := prefix
230	for {
231		// Check for key exhaution
232		if len(search) == 0 {
233			recursiveWalk(n, fn)
234			return
235		}
236
237		// Look for an edge
238		_, n = n.getEdge(search[0])
239		if n == nil {
240			break
241		}
242
243		// Consume the search prefix
244		if bytes.HasPrefix(search, n.prefix) {
245			search = search[len(n.prefix):]
246
247		} else if bytes.HasPrefix(n.prefix, search) {
248			// Child may be under our search prefix
249			recursiveWalk(n, fn)
250			return
251		} else {
252			break
253		}
254	}
255}
256
257// WalkPath is used to walk the tree, but only visiting nodes
258// from the root down to a given leaf. Where WalkPrefix walks
259// all the entries *under* the given prefix, this walks the
260// entries *above* the given prefix.
261func (n *Node) WalkPath(path []byte, fn WalkFn) {
262	search := path
263	for {
264		// Visit the leaf values if any
265		if n.leaf != nil && fn(n.leaf.key, n.leaf.val) {
266			return
267		}
268
269		// Check for key exhaution
270		if len(search) == 0 {
271			return
272		}
273
274		// Look for an edge
275		_, n = n.getEdge(search[0])
276		if n == nil {
277			return
278		}
279
280		// Consume the search prefix
281		if bytes.HasPrefix(search, n.prefix) {
282			search = search[len(n.prefix):]
283		} else {
284			break
285		}
286	}
287}
288
289// recursiveWalk is used to do a pre-order walk of a node
290// recursively. Returns true if the walk should be aborted
291func recursiveWalk(n *Node, fn WalkFn) bool {
292	// Visit the leaf values if any
293	if n.leaf != nil && fn(n.leaf.key, n.leaf.val) {
294		return true
295	}
296
297	// Recurse on the children
298	for _, e := range n.edges {
299		if recursiveWalk(e.node, fn) {
300			return true
301		}
302	}
303	return false
304}
305