1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2import operator
3
# Only the tree engine is part of the star-import API; the sentinel and
# node helper classes below must be imported by name.
__all__ = ["BST"]
5
6
class MaxValue:
    '''
    Sentinel that behaves like positive infinity under ordering
    comparisons, e.g. as the last element of a key tuple used as an
    upper bound in tuple comparison.
    '''

    def __repr__(self):
        return "MAX"

    __str__ = __repr__

    # "Greater than", strict or not, always holds.
    def __gt__(self, other):
        return True

    __ge__ = __gt__

    # "Less than", strict or not, never holds.
    def __lt__(self, other):
        return False

    __le__ = __lt__
29
30
class MinValue:
    '''
    Sentinel that behaves like negative infinity under ordering
    comparisons: it compares as less than whatever it is compared
    against (the mirror image of ``MaxValue``).
    '''

    def __repr__(self):
        return "MIN"

    __str__ = __repr__

    # "Less than", strict or not, always holds.
    def __lt__(self, other):
        return True

    __le__ = __lt__

    # "Greater than", strict or not, never holds.
    def __gt__(self, other):
        return False

    __ge__ = __gt__
53
54
class Epsilon:
    '''
    Represents the "next largest" version of a given value,
    so that for all valid comparisons we have
    x < y < Epsilon(y) < z whenever x < y < z and x, z are
    not Epsilon objects.

    An Epsilon never compares equal to anything, so its non-strict
    comparisons (``<=``, ``>=``) coincide with the strict ones.

    Parameters
    ----------
    val : object
        Original value
    '''
    __slots__ = ('val',)

    def __init__(self, val):
        self.val = val

    def __lt__(self, other):
        # Epsilon(y) sits just above y, so it is never below y itself.
        if self.val == other:
            return False
        return self.val < other

    def __gt__(self, other):
        # Epsilon(y) sits just above y, so it is always above y itself.
        if self.val == other:
            return True
        return self.val > other

    # Without these, ``Epsilon(y) <= z`` (and tuple <=/>= comparisons that
    # reach an Epsilon element) raise TypeError. Because __eq__ is always
    # False, "less or equal" is exactly "less", and likewise for ">=".
    def __le__(self, other):
        return self.__lt__(other)

    def __ge__(self, other):
        return self.__gt__(other)

    def __eq__(self, other):
        return False

    def __repr__(self):
        return repr(self.val) + " + epsilon"
87
88
class Node:
    '''
    An element in a binary search tree, containing
    a key, data, and references to its left and right children.

    Ordering and equality comparisons between nodes compare their keys
    only.

    Parameters
    ----------
    key : tuple
        Node key
    data : list or int
        Node data; a non-list value is wrapped in a one-element list
    '''
    __slots__ = ('key', 'data', 'left', 'right')

    def __lt__(self, other):
        return self.key < other.key

    def __le__(self, other):
        return self.key <= other.key

    def __eq__(self, other):
        return self.key == other.key

    def __ge__(self, other):
        return self.key >= other.key

    def __gt__(self, other):
        return self.key > other.key

    def __ne__(self, other):
        return self.key != other.key

    # each node has a key and data list
    def __init__(self, key, data):
        self.key = key
        self.data = data if isinstance(data, list) else [data]
        self.left = None
        self.right = None

    def replace(self, child, new_child):
        '''
        Replace this node's child with a new child.

        The child is located by identity, not by key equality, so a
        different node that merely has an equal key is not mistaken for
        one of this node's children.

        Raises
        ------
        ValueError
            If ``child`` is not this node's left or right child.
        '''
        if child is not None:
            if self.left is child:
                self.left = new_child
                return
            if self.right is child:
                self.right = new_child
                return
        raise ValueError("Cannot call replace() on non-child")

    def remove(self, child):
        '''
        Remove the given child.
        '''
        self.replace(child, None)

    def set(self, other):
        '''
        Copy the key and (a shallow copy of) the data of the given node.
        Children are not copied.
        '''
        self.key = other.key
        self.data = other.data[:]

    def __str__(self):
        return str((self.key, self.data))

    def __repr__(self):
        return str(self)
146
147
class BST:
    '''
    A basic binary search tree in pure Python, used
    as an engine for indexing.

    Keys that compare equal share a single node whose ``data`` list holds
    all of their rows (kept sorted). The tree is not self-balancing: if
    rows are inserted in sorted key order (as the ``data`` description
    below suggests), every new key lands to the right of the previous one
    and the tree degenerates into a chain whose height equals the number
    of distinct keys. All searches and traversals are therefore written
    with explicit loops/stacks rather than recursion, so large tables do
    not exceed Python's recursion limit.

    Parameters
    ----------
    data : Table
        Sorted columns of the original table
    row_index : Column object
        Row numbers corresponding to data columns
    unique : bool
        Whether the values of the index must be unique.
        Defaults to False.
    '''
    NodeClass = Node

    def __init__(self, data, row_index, unique=False):
        self.root = None
        # Incremented by each successful add(), decremented whenever
        # remove() deletes a whole node.
        self.size = 0
        self.unique = unique
        for key, row in zip(data, row_index):
            self.add(tuple(key), row)

    def add(self, key, data=None):
        '''
        Add a key, data pair.

        Parameters
        ----------
        key : tuple
            Key to insert.
        data : int, list or None
            Row(s) associated with ``key``; defaults to ``key`` itself.

        Raises
        ------
        ValueError
            If the tree is unique and ``key`` is already present. The tree,
            including ``size``, is left unchanged in that case.
        '''
        if data is None:
            data = key

        node = self.NodeClass(key, data)
        curr_node = self.root
        if curr_node is None:
            self.root = node
            self.size += 1
            return
        while True:
            if node < curr_node:
                if curr_node.left is None:
                    curr_node.left = node
                    break
                curr_node = curr_node.left
            elif node > curr_node:
                if curr_node.right is None:
                    curr_node.right = node
                    break
                curr_node = curr_node.right
            elif self.unique:
                raise ValueError("Cannot insert non-unique value")
            else:  # equal key: merge the rows into the existing node
                curr_node.data.extend(node.data)
                curr_node.data = sorted(curr_node.data)
                break
        # Only count the insertion once it has actually succeeded.
        self.size += 1

    def find(self, key):
        '''
        Return all data values corresponding to a given key.

        Parameters
        ----------
        key : tuple
            Input key

        Returns
        -------
        data_vals : list
            List of rows corresponding to the input key
        '''
        node, _ = self.find_node(key)
        return node.data if node is not None else []

    def find_node(self, key):
        '''
        Find the node associated with the given key.

        Returns
        -------
        (node, parent) : tuple
            The matching node and its parent (``None`` for the root), or
            ``(None, None)`` if the key is absent or not comparable.
        '''
        if self.root is None:
            return (None, None)
        return self._find_recursive(key, self.root, None)

    def shift_left(self, row):
        '''
        Decrement all rows larger than the given row.
        '''
        for node in self.traverse():
            node.data = [x - 1 if x > row else x for x in node.data]

    def shift_right(self, row):
        '''
        Increment all rows greater than or equal to the given row.
        '''
        for node in self.traverse():
            node.data = [x + 1 if x >= row else x for x in node.data]

    def _find_recursive(self, key, node, parent):
        # Despite its historical name this walks down the tree in a loop,
        # so a deep (degenerate) tree cannot hit the recursion limit.
        try:
            while True:
                if key == node.key:
                    return (node, parent)
                parent = node
                node = node.right if key > node.key else node.left
                if node is None:
                    return (None, None)
        except TypeError:  # wrong key type
            return (None, None)

    def traverse(self, order='inorder'):
        '''
        Return nodes of the BST in the given order.

        Parameters
        ----------
        order : str
            The order in which to search the BST.
            Possible values are:
            "preorder": current node, left subtree, right subtree
            "inorder": left subtree, current node, right subtree
            "postorder": left subtree, right subtree, current node

        Raises
        ------
        ValueError
            If ``order`` is not one of the values above.
        '''
        if order == 'preorder':
            return self._preorder(self.root, [])
        elif order == 'inorder':
            return self._inorder(self.root, [])
        elif order == 'postorder':
            return self._postorder(self.root, [])
        raise ValueError(f"Invalid traversal method: \"{order}\"")

    def items(self):
        '''
        Return BST items in order as (key, data) pairs.

        The ``data`` lists are the nodes' own lists, not copies.
        '''
        return [(x.key, x.data) for x in self.traverse()]

    def sort(self):
        '''
        Make row order align with key order.

        Rows are renumbered consecutively from 0 following an in-order
        traversal; each node keeps the same number of rows.
        '''
        i = 0
        for node in self.traverse():
            num_rows = len(node.data)
            node.data = list(range(i, i + num_rows))
            i += num_rows

    def sorted_data(self):
        '''
        Return BST rows sorted by key values.
        '''
        return [x for node in self.traverse() for x in node.data]

    def _preorder(self, node, lst):
        # Iterative preorder; the right child is pushed first so the left
        # subtree is visited before it.
        stack = [node]
        while stack:
            curr = stack.pop()
            if curr is None:
                continue
            lst.append(curr)
            stack.append(curr.right)
            stack.append(curr.left)
        return lst

    def _inorder(self, node, lst):
        # Iterative inorder: descend left as far as possible, visit, then
        # continue with the right subtree.
        stack = []
        curr = node
        while stack or curr is not None:
            while curr is not None:
                stack.append(curr)
                curr = curr.left
            curr = stack.pop()
            lst.append(curr)
            curr = curr.right
        return lst

    def _postorder(self, node, lst):
        # Visiting (node, right, left) and reversing the result yields
        # (left, right, node), i.e. postorder, without recursion.
        reverse_order = []
        stack = [node]
        while stack:
            curr = stack.pop()
            if curr is None:
                continue
            reverse_order.append(curr)
            stack.append(curr.left)
            stack.append(curr.right)
        lst.extend(reversed(reverse_order))
        return lst

    def _substitute(self, node, parent, new_node):
        # Put new_node where node currently hangs (possibly the root).
        if node is self.root:
            self.root = new_node
        else:
            parent.replace(node, new_node)

    def remove(self, key, data=None):
        '''
        Remove data corresponding to the given key.

        Parameters
        ----------
        key : tuple
            The key to remove
        data : int or None
            If None, remove the node corresponding to the given key.
            If not None, remove only the given data value from the node
            (the node itself is removed if that was its only value).

        Returns
        -------
        successful : bool
            True if removal was successful, false otherwise

        Raises
        ------
        ValueError
            If ``data`` is given but is not stored under ``key``.
        '''
        node, parent = self.find_node(key)
        if node is None:
            return False
        if data is not None:
            if data not in node.data:
                raise ValueError("Data does not belong to correct node")
            elif len(node.data) > 1:
                node.data.remove(data)
                return True
        if node.left is None and node.right is None:
            self._substitute(node, parent, None)
        elif node.left is None and node.right is not None:
            self._substitute(node, parent, node.right)
        elif node.right is None and node.left is not None:
            self._substitute(node, parent, node.left)
        else:
            # find largest element of left subtree, unlink it, and move
            # its key/data into the node being removed
            curr_node = node.left
            parent = node
            while curr_node.right is not None:
                parent = curr_node
                curr_node = curr_node.right
            self._substitute(curr_node, parent, curr_node.left)
            node.set(curr_node)
        self.size -= 1
        return True

    def is_valid(self):
        '''
        Returns whether this is a valid BST.

        A binary tree satisfies the search-tree ordering (every key in a
        node's left subtree <= the node's key <= every key in its right
        subtree) exactly when an in-order traversal visits the keys in
        non-decreasing order, which is what is checked here.
        '''
        return self._is_valid(self.root)

    def _is_valid(self, node):
        # Check the whole subtree, not just each node against its
        # immediate children (which would accept some invalid trees).
        nodes = self._inorder(node, [])
        return all(prev <= curr for prev, curr in zip(nodes, nodes[1:]))

    def range(self, lower, upper, bounds=(True, True)):
        '''
        Return all nodes with keys in the given range.

        Parameters
        ----------
        lower : tuple
            Lower bound
        upper : tuple
            Upper bound
        bounds : (2,) tuple of bool
            Indicates whether the search should be inclusive or
            exclusive with respect to the endpoints. The first
            argument corresponds to an inclusive lower bound,
            and the second argument to an inclusive upper bound.

        Returns
        -------
        rows : list
            The data values of all matching nodes.
        '''
        nodes = self.range_nodes(lower, upper, bounds)
        return [x for node in nodes for x in node.data]

    def range_nodes(self, lower, upper, bounds=(True, True)):
        '''
        Return nodes in the given range.
        '''
        if self.root is None:
            return []
        # op1 is <= or <, op2 is >= or >
        op1 = operator.le if bounds[0] else operator.lt
        op2 = operator.ge if bounds[1] else operator.gt
        return self._range(lower, upper, op1, op2, self.root, [])

    def same_prefix(self, val):
        '''
        Assuming the given value has smaller length than keys, return
        the data of nodes whose keys have this value as a prefix.
        '''
        if self.root is None:
            return []
        nodes = self._same_prefix(val, self.root, [])
        return [x for node in nodes for x in node.data]

    def _range(self, lower, upper, op1, op2, node, lst):
        # Explicit-stack walk. The left child is pushed before the right
        # one so each node is followed by its whole right subtree and then
        # its left subtree, matching the former recursive visiting order.
        stack = [node]
        while stack:
            curr = stack.pop()
            if op1(lower, curr.key) and op2(upper, curr.key):
                lst.append(curr)
            go_right = upper > curr.key and curr.right is not None
            go_left = lower < curr.key and curr.left is not None
            if go_left:
                stack.append(curr.left)
            if go_right:
                stack.append(curr.right)
        return lst

    def _same_prefix(self, val, node, lst):
        # Same traversal shape as _range, pruning on the key prefix.
        stack = [node]
        while stack:
            curr = stack.pop()
            prefix = curr.key[:len(val)]
            if prefix == val:
                lst.append(curr)
            go_right = prefix <= val and curr.right is not None
            go_left = prefix >= val and curr.left is not None
            if go_left:
                stack.append(curr.left)
            if go_right:
                stack.append(curr.right)
        return lst

    def __repr__(self):
        return f'<{self.__class__.__name__}>'

    def _print(self, node, level):
        # Preorder dump, one node per line, indented by depth with tabs.
        pieces = []
        stack = [(node, level)]
        while stack:
            curr, depth = stack.pop()
            pieces.append('\t' * depth + str(curr) + '\n')
            if curr.right is not None:
                stack.append((curr.right, depth + 1))
            if curr.left is not None:
                stack.append((curr.left, depth + 1))
        return ''.join(pieces)

    @property
    def height(self):
        '''
        Return the BST height (-1 for an empty tree, 0 for a lone root).
        '''
        return self._height(self.root)

    def _height(self, node):
        # Count levels with a breadth-first sweep instead of recursion.
        height = -1
        level = [node] if node is not None else []
        while level:
            height += 1
            level = [child for curr in level
                     for child in (curr.left, curr.right)
                     if child is not None]
        return height

    def replace_rows(self, row_map):
        '''
        Replace all rows with the values they map to in the
        given dictionary. Any rows not present as keys in
        the dictionary are dropped, and nodes left without
        any rows are deleted.

        Parameters
        ----------
        row_map : dict
            Mapping of row numbers to new row numbers
        '''
        emptied = []
        for key, data in self.items():
            data[:] = [row_map[x] for x in data if x in row_map]
            if not data:
                emptied.append(key)
        # Delete after the traversal so the tree is not restructured
        # while it is being walked.
        for key in emptied:
            self.remove(key)
483