# Licensed under a 3-clause BSD style license - see LICENSE.rst
import operator

__all__ = ['BST']


class MaxValue:
    '''
    Represents an infinite value for purposes
    of tuple comparison.
    '''

    def __gt__(self, other):
        return True

    def __ge__(self, other):
        return True

    def __lt__(self, other):
        return False

    def __le__(self, other):
        return False

    def __repr__(self):
        return "MAX"

    __str__ = __repr__


class MinValue:
    '''
    The opposite of MaxValue, i.e. a representation of
    negative infinity.
    '''

    def __lt__(self, other):
        return True

    def __le__(self, other):
        return True

    def __gt__(self, other):
        return False

    def __ge__(self, other):
        return False

    def __repr__(self):
        return "MIN"

    __str__ = __repr__


class Epsilon:
    '''
    Represents the "next largest" version of a given value,
    so that for all valid comparisons we have
    x < y < Epsilon(y) < z whenever x < y < z and x, z are
    not Epsilon objects.

    An Epsilon never compares equal to anything, so ``<=`` and ``>=``
    behave exactly like ``<`` and ``>``.

    Parameters
    ----------
    val : object
        Original value
    '''
    __slots__ = ('val',)

    def __init__(self, val):
        self.val = val

    def __lt__(self, other):
        if self.val == other:
            return False
        return self.val < other

    def __gt__(self, other):
        if self.val == other:
            return True
        return self.val > other

    # Since an Epsilon is never equal to anything, the non-strict
    # comparisons reduce to the strict ones. Defining them lets an Epsilon
    # sit inside tuples compared with <= / >= (e.g. via operator.le/ge),
    # which would otherwise raise TypeError.
    __le__ = __lt__
    __ge__ = __gt__

    def __eq__(self, other):
        return False

    def __repr__(self):
        return repr(self.val) + " + epsilon"


class Node:
    '''
    An element in a binary search tree, containing a key, data,
    and references to its left and right children.

    Nodes are ordered and compared for equality by key only.

    Parameters
    ----------
    key : tuple
        Node key
    data : list or int
        Node data; a non-list value is wrapped in a one-element list
    '''
    __slots__ = ('key', 'data', 'left', 'right')

    def __lt__(self, other):
        return self.key < other.key

    def __le__(self, other):
        return self.key <= other.key

    def __eq__(self, other):
        return self.key == other.key

    def __ge__(self, other):
        return self.key >= other.key

    def __gt__(self, other):
        return self.key > other.key

    def __ne__(self, other):
        return self.key != other.key

    # each node has a key and data list
    def __init__(self, key, data):
        self.key = key
        self.data = data if isinstance(data, list) else [data]
        self.left = None
        self.right = None

    def replace(self, child, new_child):
        '''
        Replace this node's child with a new child.

        ``child`` is matched by identity, not by key equality, so only an
        actual child object of this node can be replaced.

        Raises
        ------
        ValueError
            If ``child`` is not one of this node's children.
        '''
        if child is not None and self.left is child:
            self.left = new_child
        elif child is not None and self.right is child:
            self.right = new_child
        else:
            raise ValueError("Cannot call replace() on non-child")

    def remove(self, child):
        '''
        Remove the given child.
        '''
        self.replace(child, None)

    def set(self, other):
        '''
        Copy the key and (a copy of) the data of the given node.
        '''
        self.key = other.key
        self.data = other.data[:]

    def __str__(self):
        return str((self.key, self.data))

    def __repr__(self):
        return str(self)


class BST:
    '''
    A basic binary search tree in pure Python, used
    as an engine for indexing.

    The tree is not self-balancing: inserting already-sorted keys (as the
    constructor receives them) produces a chain whose height equals the
    number of keys. All searches and traversals are therefore iterative,
    so large indexes cannot exceed Python's recursion limit.

    Parameters
    ----------
    data : Table
        Sorted columns of the original table
    row_index : Column object
        Row numbers corresponding to data columns
    unique : bool
        Whether the values of the index must be unique.
        Defaults to False.
    '''
    NodeClass = Node

    def __init__(self, data, row_index, unique=False):
        self.root = None
        # Number of rows (data values) stored in the tree; this exceeds
        # the number of nodes when keys are duplicated.
        self.size = 0
        self.unique = unique
        for key, row in zip(data, row_index):
            self.add(tuple(key), row)

    def add(self, key, data=None):
        '''
        Add a key, data pair.

        Parameters
        ----------
        key : tuple
            Key to insert
        data : int, list or None
            Row(s) to associate with ``key``; defaults to ``key`` itself

        Raises
        ------
        ValueError
            If the tree is unique and ``key`` is already present.
        '''
        if data is None:
            data = key

        node = self.NodeClass(key, data)
        if self.root is None:
            self.root = node
        else:
            curr_node = self.root
            while True:
                if node < curr_node:
                    if curr_node.left is None:
                        curr_node.left = node
                        break
                    curr_node = curr_node.left
                elif node > curr_node:
                    if curr_node.right is None:
                        curr_node.right = node
                        break
                    curr_node = curr_node.right
                elif self.unique:
                    raise ValueError("Cannot insert non-unique value")
                else:  # duplicate key: merge rows into the existing node
                    curr_node.data.extend(node.data)
                    curr_node.data.sort()
                    break
        # Count rows only after insertion succeeded, so a rejected
        # duplicate in a unique tree leaves ``size`` untouched.
        self.size += len(node.data)

    def find(self, key):
        '''
        Return all data values corresponding to a given key.

        Parameters
        ----------
        key : tuple
            Input key

        Returns
        -------
        data_vals : list
            List of rows corresponding to the input key
        '''
        node, parent = self.find_node(key)
        return node.data if node is not None else []

    def find_node(self, key):
        '''
        Find the node associated with the given key.

        Returns
        -------
        (node, parent) : tuple
            The matching node and its parent, or ``(None, None)``.
        '''
        if self.root is None:
            return (None, None)
        return self._find_recursive(key, self.root, None)

    def shift_left(self, row):
        '''
        Decrement all rows larger than the given row.
        '''
        for node in self.traverse():
            node.data = [x - 1 if x > row else x for x in node.data]

    def shift_right(self, row):
        '''
        Increment all rows greater than or equal to the given row.
        '''
        for node in self.traverse():
            node.data = [x + 1 if x >= row else x for x in node.data]

    def _find_recursive(self, key, node, parent):
        '''
        Search the subtree rooted at ``node`` for ``key``.

        Despite its name this walks the tree iteratively, so a chain-shaped
        tree cannot exceed the recursion limit.

        Returns
        -------
        (node, parent) : tuple
            The matching node and its parent, or ``(None, None)`` if the
            key is absent or cannot be compared with the stored keys.
        '''
        try:
            while node is not None:
                if key == node.key:
                    return (node, parent)
                parent = node
                node = node.right if key > node.key else node.left
        except TypeError:  # wrong key type
            pass
        return (None, None)

    def traverse(self, order='inorder'):
        '''
        Return nodes of the BST in the given order.

        Parameters
        ----------
        order : str
            The order in which to search the BST.
            Possible values are:
            "preorder": current node, left subtree, right subtree
            "inorder": left subtree, current node, right subtree
            "postorder": left subtree, right subtree, current node
        '''
        if order == 'preorder':
            return self._preorder(self.root, [])
        elif order == 'inorder':
            return self._inorder(self.root, [])
        elif order == 'postorder':
            return self._postorder(self.root, [])
        raise ValueError(f"Invalid traversal method: \"{order}\"")

    def items(self):
        '''
        Return BST items in order as (key, data) pairs.
        '''
        return [(x.key, x.data) for x in self.traverse()]

    def sort(self):
        '''
        Make row order align with key order, renumbering rows
        consecutively from 0 in key order.
        '''
        i = 0
        for node in self.traverse():
            num_rows = len(node.data)
            node.data = list(range(i, i + num_rows))
            i += num_rows

    def sorted_data(self):
        '''
        Return BST rows sorted by key values.
        '''
        return [x for node in self.traverse() for x in node.data]

    def _preorder(self, node, lst):
        # Append nodes of the subtree rooted at ``node`` to ``lst`` in
        # preorder (node, left, right) using an explicit stack.
        stack = [node] if node is not None else []
        while stack:
            curr = stack.pop()
            lst.append(curr)
            # Push right first so the left subtree is popped first.
            if curr.right is not None:
                stack.append(curr.right)
            if curr.left is not None:
                stack.append(curr.left)
        return lst

    def _inorder(self, node, lst):
        # Append nodes of the subtree rooted at ``node`` to ``lst`` in
        # key order (left, node, right) using an explicit stack.
        stack = []
        curr = node
        while stack or curr is not None:
            while curr is not None:
                stack.append(curr)
                curr = curr.left
            curr = stack.pop()
            lst.append(curr)
            curr = curr.right
        return lst

    def _postorder(self, node, lst):
        # Append nodes of the subtree rooted at ``node`` to ``lst`` in
        # postorder (left, right, node). A (node, right, left) preorder
        # walk, reversed, yields exactly that order.
        visited = []
        stack = [node] if node is not None else []
        while stack:
            curr = stack.pop()
            visited.append(curr)
            if curr.left is not None:
                stack.append(curr.left)
            if curr.right is not None:
                stack.append(curr.right)
        lst.extend(reversed(visited))
        return lst

    def _substitute(self, node, parent, new_node):
        # Put ``new_node`` where ``node`` hangs (under ``parent``, or at the root).
        if node is self.root:
            self.root = new_node
        else:
            parent.replace(node, new_node)

    def remove(self, key, data=None):
        '''
        Remove data corresponding to the given key.

        Parameters
        ----------
        key : tuple
            The key to remove
        data : int or None
            If None, remove the node corresponding to the given key.
            If not None, remove only the given data value from the node.

        Returns
        -------
        successful : bool
            True if removal was successful, false otherwise

        Raises
        ------
        ValueError
            If ``data`` is given but is not stored under ``key``.
        '''
        node, parent = self.find_node(key)
        if node is None:
            return False
        if data is not None:
            if data not in node.data:
                raise ValueError("Data does not belong to correct node")
            elif len(node.data) > 1:
                node.data.remove(data)
                self.size -= 1
                return True
        # The whole node goes away, taking all of its rows with it.
        removed_rows = len(node.data)
        if node.left is None:
            self._substitute(node, parent, node.right)
        elif node.right is None:
            self._substitute(node, parent, node.left)
        else:
            # find largest element of left subtree
            curr_node = node.left
            parent = node
            while curr_node.right is not None:
                parent = curr_node
                curr_node = curr_node.right
            self._substitute(curr_node, parent, curr_node.left)
            node.set(curr_node)
        self.size -= removed_rows
        return True

    def is_valid(self):
        '''
        Returns whether this is a valid BST.
        '''
        return self._is_valid(self.root)

    def _is_valid(self, node):
        # Check the ordering invariant for the whole subtree: every key in
        # a left subtree must be <= each ancestor it descends left from,
        # and every key in a right subtree >= each ancestor it descends
        # right from. Comparing only immediate children misses deeper
        # violations. Each stack entry carries inherited inclusive bounds
        # (None meaning unbounded).
        stack = [(node, None, None)] if node is not None else []
        while stack:
            curr, low, high = stack.pop()
            if low is not None and curr.key < low:
                return False
            if high is not None and curr.key > high:
                return False
            if curr.left is not None:
                stack.append((curr.left, low, curr.key))
            if curr.right is not None:
                stack.append((curr.right, curr.key, high))
        return True

    def range(self, lower, upper, bounds=(True, True)):
        '''
        Return all nodes with keys in the given range.

        Parameters
        ----------
        lower : tuple
            Lower bound
        upper : tuple
            Upper bound
        bounds : (2,) tuple of bool
            Indicates whether the search should be inclusive or
            exclusive with respect to the endpoints. The first
            argument corresponds to an inclusive lower bound,
            and the second argument to an inclusive upper bound.
        '''
        nodes = self.range_nodes(lower, upper, bounds)
        return [x for node in nodes for x in node.data]

    def range_nodes(self, lower, upper, bounds=(True, True)):
        '''
        Return nodes in the given range.
        '''
        if self.root is None:
            return []
        # op1 is <= or <, op2 is >= or >
        op1 = operator.le if bounds[0] else operator.lt
        op2 = operator.ge if bounds[1] else operator.gt
        return self._range(lower, upper, op1, op2, self.root, [])

    def same_prefix(self, val):
        '''
        Assuming the given value has smaller length than keys, return
        nodes whose keys have this value as a prefix.
        '''
        if self.root is None:
            return []
        nodes = self._same_prefix(val, self.root, [])
        return [x for node in nodes for x in node.data]

    def _range(self, lower, upper, op1, op2, node, lst):
        # Collect nodes whose keys satisfy the bounds, pruning subtrees
        # that cannot contain in-range keys. Each node is listed before
        # its right subtree, which is listed before its left subtree.
        stack = [node] if node is not None else []
        while stack:
            curr = stack.pop()
            if op1(lower, curr.key) and op2(upper, curr.key):
                lst.append(curr)
            # Push left first so the right subtree is popped first.
            if lower < curr.key and curr.left is not None:
                stack.append(curr.left)
            if upper > curr.key and curr.right is not None:
                stack.append(curr.right)
        return lst

    def _same_prefix(self, val, node, lst):
        # Collect nodes whose key starts with ``val``, pruning subtrees
        # whose keys cannot share that prefix; same visiting order as
        # ``_range`` (node, right subtree, left subtree).
        stack = [node] if node is not None else []
        while stack:
            curr = stack.pop()
            prefix = curr.key[:len(val)]
            if prefix == val:
                lst.append(curr)
            if prefix >= val and curr.left is not None:
                stack.append(curr.left)
            if prefix <= val and curr.right is not None:
                stack.append(curr.right)
        return lst

    def __repr__(self):
        return f'<{self.__class__.__name__}>'

    def _print(self, node, level):
        # Render the subtree as one line per node in preorder, indented
        # with one tab per depth level below ``level``.
        lines = []
        stack = [(node, level)]
        while stack:
            curr, depth = stack.pop()
            lines.append('\t' * depth + str(curr) + '\n')
            if curr.right is not None:
                stack.append((curr.right, depth + 1))
            if curr.left is not None:
                stack.append((curr.left, depth + 1))
        return ''.join(lines)

    @property
    def height(self):
        '''
        Return the BST height (-1 for an empty tree, 0 for a single node).
        '''
        return self._height(self.root)

    def _height(self, node):
        # Count levels breadth-first; the height is the number of levels
        # minus one.
        height = -1
        level = [node] if node is not None else []
        while level:
            height += 1
            level = [child for curr in level
                     for child in (curr.left, curr.right) if child is not None]
        return height

    def replace_rows(self, row_map):
        '''
        Replace all rows with the values they map to in the
        given dictionary. Any rows not present as keys in
        the dictionary are dropped, and nodes left without
        any rows are deleted.

        Parameters
        ----------
        row_map : dict
            Mapping of row numbers to new row numbers
        '''
        empty_keys = []
        dropped = 0
        for node in self.traverse():
            kept = [row_map[x] for x in node.data if x in row_map]
            dropped += len(node.data) - len(kept)
            # Mutate in place so lists previously handed out (e.g. by
            # ``find``/``items``) see the new rows.
            node.data[:] = kept
            if not kept:
                empty_keys.append(node.key)
        self.size -= dropped
        # Delete emptied nodes after the traversal; otherwise a stale node
        # would keep blocking re-insertion of its key in a unique tree.
        for key in empty_keys:
            self.remove(key)