"""
Pythagoras tree viewer for visualizing tree structures.

The pythagoras tree viewer widget is a widget that can be plugged into any
existing widget given a tree adapter instance. It is simply a canvas that takes
an input tree adapter and takes care of all the drawing.

Types
-----
Square : namedtuple (center, length, angle)
    Since Pythagoras trees deal only with squares (they also deal with
    rectangles in the generalized form, but are completely unreadable), this
    is what all the squares are stored as.
Point : namedtuple (x, y)
    Self explanatory.

"""
from abc import ABCMeta, abstractmethod
from collections import namedtuple, defaultdict, deque
from math import pi, sqrt, cos, sin, degrees

import numpy as np
from AnyQt.QtCore import Qt, QTimer, QRectF, QSizeF
from AnyQt.QtGui import QColor, QPen
from AnyQt.QtWidgets import (
    QSizePolicy, QGraphicsItem, QGraphicsRectItem, QGraphicsWidget, QStyle
)

from Orange.widgets.utils import to_html
from Orange.widgets.visualize.utils.tree.rules import Rule
from Orange.widgets.visualize.utils.tree.treeadapter import TreeAdapter

# z index range, increase if needed
Z_STEP = 5000000

Square = namedtuple('Square', ['center', 'length', 'angle'])
Point = namedtuple('Point', ['x', 'y'])


def _identity(value):
    """Return `value` unchanged; used as the fallback weight adjustment."""
    return value


class PythagorasTreeViewer(QGraphicsWidget):
    """Pythagoras tree viewer graphics widget.

    Examples
    --------
    >>> from Orange.widgets.visualize.utils.tree.treeadapter import (
    ...     TreeAdapter
    ... )
    Pass tree through constructor.
    >>> tree_view = PythagorasTreeViewer(parent=scene, adapter=tree_adapter)

    Pass tree later through method.
    >>> tree_adapter = TreeAdapter()
    >>> scene = QGraphicsScene()
    This is where the magic happens
    >>> tree_view = PythagorasTreeViewer(parent=scene)
    >>> tree_view.set_tree(tree_adapter)

    Both these examples set the appropriate tree and add all the squares to the
    widget instance.

    Parameters
    ----------
    parent : QGraphicsItem, optional
        The parent object that the graphics widget belongs to. Should be a
        scene.
    adapter : TreeAdapter, optional
        Any valid tree adapter instance.
    depth_limit : int, optional
        The depth limit applied after a tree passed through the constructor
        has been drawn. Default is 0.
    padding : int, optional
        Padding added around the children bounding rectangle. Default is 0.
    interactive : bool, optional
        Specify whether the widget should have an interactive display. This
        means special hover effects, selectable boxes. Default is true.
    target_class_index : int, optional
        Forwarded to `set_tree` when an adapter is given. Default is 0.
    weight_adjustment : callable, optional
        Forwarded to `set_tree` when an adapter is given. Defaults to the
        identity function.

    Notes
    -----
    .. note:: The class contains two clear methods: `clear` and `clear_tree`.
        Each has their own use.
        `clear_tree` will clear out the tree and remove any graphics items.
        `clear` will, on the other hand, clear everything, all settings
        (tooltip and color calculation functions).

        This is useful because when we want to change the size calculation of
        the Pythagoras tree, we just want to clear the scene and it would be
        inconvenient to have to set color and tooltip functions again.
        On the other hand, when we want to draw a brand new tree, it is best
        to clear all settings to avoid any strange bugs - we start with a blank
        slate.

    """

    def __init__(self, parent=None, adapter=None, depth_limit=0, padding=0,
                 **kwargs):
        super().__init__()
        self.parent = parent

        # In case a tree was passed, it will be handled at the end of init
        self.tree_adapter = None
        self.root = None
        # Always defined so that code reading it before `set_tree` is called
        # does not fail with an AttributeError.
        self.weight_adjustment = _identity

        self._depth_limit = depth_limit
        self._interactive = kwargs.get('interactive', True)
        self._padding = padding

        # Maps node label -> graphics item for every square created so far
        self._square_objects = {}
        # (depth, node) pairs in breadth-first order, see `_draw_tree`
        self._drawn_nodes = deque()
        self._frontier = deque()

        self._target_class_index = 0

        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

        # If a tree was passed in the constructor, set and draw the tree
        if adapter is not None:
            self.set_tree(
                adapter,
                target_class_index=kwargs.get('target_class_index', 0),
                # May be None when not supplied; `set_tree` substitutes the
                # identity function in that case.
                weight_adjustment=kwargs.get('weight_adjustment'),
            )
            # Since `set_tree` needs to draw the entire tree to be visualized
            # properly, it overrides the `depth_limit` to max. If we specified
            # the depth limit, however, apply that afterwards.
            self.set_depth_limit(depth_limit)

    def set_tree(self, tree_adapter, weight_adjustment=lambda x: x,
                 target_class_index=0):
        """Pass in a new tree adapter instance and perform updates to canvas.

        Parameters
        ----------
        tree_adapter : TreeAdapter
            The new tree adapter that is to be used. If None, the current
            tree is only cleared.
        weight_adjustment : callable
            A weight adjustment function with signature `x -> x`. None is
            treated as the identity function.
        target_class_index : int

        Returns
        -------

        """
        self.clear_tree()
        self.tree_adapter = tree_adapter
        if weight_adjustment is None:
            weight_adjustment = _identity
        self.weight_adjustment = weight_adjustment

        if self.tree_adapter is not None:
            self.root = self._calculate_tree(
                self.tree_adapter, self.weight_adjustment)
            # Drawing to max depth creates a graphics item for every node
            self.set_depth_limit(tree_adapter.max_depth)
            self.target_class_changed(target_class_index)
            self._draw_tree(self.root)

    def set_size_calc(self, weight_adjustment):
        """Set the weight adjustment on the tree. Redraws the whole tree."""
        # Since we have to redraw the whole tree anyways, just call `set_tree`
        self.weight_adjustment = weight_adjustment
        self.set_tree(self.tree_adapter, self.weight_adjustment,
                      self._target_class_index)

    def set_depth_limit(self, depth):
        """Update the drawing depth limit.

        The drawing stops when the depth is GT the limit. This means that at
        depth 0, the root node will be drawn.

        Parameters
        ----------
        depth : int
            The maximum depth at which the nodes can still be drawn.

        Returns
        -------

        """
        self._depth_limit = depth
        self._draw_tree(self.root)

    def target_class_changed(self, target_class_index=0):
        """When the target class has changed, perform appropriate updates.

        The index is stored on the viewer and assigned to every node of the
        current tree. When no tree is set only the index is stored.
        """
        self._target_class_index = target_class_index
        if self.root is None:
            return

        # Pre-order walk with an explicit stack instead of recursion
        pending = [self.root]
        while pending:
            node = pending.pop()
            node.target_class_index = target_class_index
            pending.extend(reversed(node.children))

    def tooltip_changed(self, tooltip_enabled):
        """Set the tooltip to the appropriate value on each square."""
        for square in self._squares():
            if tooltip_enabled:
                square.setToolTip(square.tree_node.tooltip)
            else:
                square.setToolTip(None)

    def clear(self):
        """Clear the entire widget state."""
        self.clear_tree()
        self._target_class_index = 0

    def clear_tree(self):
        """Clear only the tree, keeping tooltip and color functions."""
        self.tree_adapter = None
        self.root = None
        self._clear_scene()

    def _calculate_tree(self, tree_adapter, weight_adjustment):
        """Actually calculate the tree squares"""
        tree_builder = PythagorasTree(weight_adjustment=weight_adjustment)
        return tree_builder.pythagoras_tree(
            tree_adapter, tree_adapter.root, Square(Point(0, 0), 200, -pi / 2)
        )

    def _draw_tree(self, root):
        """Efficiently draw the tree with regards to the depth.

        If we used a recursive approach, the tree would have to be redrawn
        every time the depth changed, which is very impractical for larger
        trees, since drawing can take a long time.

        Using an iterative approach, we use two queues to represent the tree
        frontier and the nodes that have already been drawn. We also store the
        current depth. This way, when the max depth is increased, we do not
        redraw the entire tree but only iterate through the frontier and draw
        those nodes, and update the frontier accordingly.
        When decreasing the max depth, we reverse the process, we clear the
        frontier, and remove nodes from the drawn nodes, and append those with
        depth max_depth + 1 to the frontier, so the frontier doesn't get
        cluttered.

        Parameters
        ----------
        root : TreeNode
            The root tree node.

        Returns
        -------

        """
        if self.root is None:
            return
        # if this is the first time drawing the tree begin with root; do not
        # queue the root again if it is already waiting in the frontier
        if not self._drawn_nodes and not self._frontier:
            self._frontier.appendleft((0, root))
        # if the depth was decreased, we can clear the frontier, otherwise
        # frontier gets cluttered with non-frontier nodes
        if self._depth_was_decreased():
            self._frontier.clear()
        # remove nodes from drawn and add to frontier if limit is decreased
        while self._drawn_nodes:
            depth, node = self._drawn_nodes.pop()
            # check if the node is in the allowed limit
            if depth <= self._depth_limit:
                self._drawn_nodes.append((depth, node))
                break
            if depth == self._depth_limit + 1:
                self._frontier.appendleft((depth, node))

            if node.label in self._square_objects:
                self._square_objects[node.label].hide()

        # add nodes to drawn and remove from frontier if limit is increased
        while self._frontier:
            depth, node = self._frontier.popleft()
            # check if the depth of the node is outside the allowed limit
            if depth > self._depth_limit:
                self._frontier.appendleft((depth, node))
                break
            self._drawn_nodes.append((depth, node))
            self._frontier.extend((depth + 1, c) for c in node.children)

            node.target_class_index = self._target_class_index
            if node.label in self._square_objects:
                # the square was created earlier and only hidden
                self._square_objects[node.label].show()
            else:
                item_class = InteractiveSquareGraphicsItem \
                    if self._interactive else SquareGraphicsItem
                self._square_objects[node.label] = item_class(
                    node, parent=self, zvalue=depth)

    def _depth_was_decreased(self):
        """Return True if drawn nodes exist that lie below the depth limit."""
        if not self._drawn_nodes:
            return False
        # drawn nodes are stored in breadth-first order, so the right most
        # entry is the deepest drawn node; if it exceeds the limit, the limit
        # must have been decreased
        depth, _ = self._drawn_nodes[-1]
        return depth > self._depth_limit

    def _squares(self):
        """Return the graphics items of all the currently drawn nodes."""
        return [node.graphics_item for _, node in self._drawn_nodes]

    def _clear_scene(self):
        """Remove every created square and reset the drawing bookkeeping."""
        scene = self.scene()
        # Go through all created squares, not just the drawn ones, so that
        # squares hidden by a depth limit decrease are removed as well.
        for square in self._square_objects.values():
            if scene is not None:
                scene.removeItem(square)
            else:
                # not in a scene (yet); just detach the square from us
                square.setParentItem(None)
        self._frontier.clear()
        self._drawn_nodes.clear()
        self._square_objects.clear()

    def boundingRect(self):
        """Return the children bounding rectangle grown by the padding."""
        return self.childrenBoundingRect().adjusted(
            -self._padding, -self._padding, self._padding, self._padding)

    def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs):
        """Return the bounding rectangle size plus the padding."""
        return self.boundingRect().size() + QSizeF(self._padding, self._padding)


class SquareGraphicsItem(QGraphicsRectItem):
    """Square Graphics Item.

    Square component to draw as components for the non-interactive Pythagoras
    tree.

    Parameters
    ----------
    tree_node : TreeNode
        The tree node the square represents.
    parent : QGraphicsItem
    brush : QBrush or QColor, optional
        The initial brush; a green color is used when not given.
    zvalue : number, optional
        The z value used when the node has no parent node. Default is 0.

    """

    def __init__(self, tree_node, parent=None, **kwargs):
        self.tree_node = tree_node
        super().__init__(self._get_rect_attributes(), parent)
        # Assigning the graphics item triggers `update` through the node
        self.tree_node.graphics_item = self

        self.setTransformOriginPoint(self.boundingRect().center())
        self.setRotation(degrees(self.tree_node.square.angle))

        self.setBrush(kwargs.get('brush', QColor('#297A1F')))
        # The border should be invariant to scaling
        pen = QPen(QColor(Qt.black))
        pen.setWidthF(0.75)
        pen.setCosmetic(True)
        self.setPen(pen)

        self.setAcceptHoverEvents(True)
        self.setZValue(kwargs.get('zvalue', 0))
        self.z_step = Z_STEP

        # calculate the correct z values based on the parent
        if self.tree_node.parent != TreeAdapter.ROOT_PARENT:
            p = self.tree_node.parent
            # override root z step
            num_children = len(p.children)
            # position of this node among its siblings
            own_index = [c.label for c in p.children].index(
                self.tree_node.label)

            # each child gets an equal share of the parent's z range
            self.z_step = int(p.graphics_item.z_step / num_children)
            base_z = p.graphics_item.zValue()

            self.setZValue(base_z + own_index * self.z_step)

    def update(self):
        """Refresh the brush from the node color and schedule a repaint."""
        self.setBrush(self.tree_node.color)
        return super().update()

    def _get_rect_attributes(self):
        """Get the rectangle attributes required to draw item.

        Compute the QRectF that a QGraphicsRect needs to be rendered with the
        data passed down in the constructor.

        """
        center, length, _ = self.tree_node.square
        x = center[0] - length / 2
        y = center[1] - length / 2
        return QRectF(x, y, length, length)


class InteractiveSquareGraphicsItem(SquareGraphicsItem):
    """Interactive square graphics items.

    This is different from the base square graphics item so that it is
    selectable, and it can handle and react to hover events (highlight and
    focus own branch).

    Parameters
    ----------
    tree_node : TreeNode
        The tree node the square represents.
    parent : QGraphicsItem

    """

    # A single timer shared by all the items; it delays the hover clean up
    timer = QTimer()
    # The clean up callback currently connected to the shared timer, if any
    _pending_restore = None

    MAX_OPACITY = 1.
    SELECTION_OPACITY = .5
    HOVER_OPACITY = .1

    def __init__(self, tree_node, parent=None, **kwargs):
        super().__init__(tree_node, parent, **kwargs)
        self.setFlag(QGraphicsItem.ItemIsSelectable, True)

        self.initial_zvalue = self.zValue()
        # The max z value changes if any item is selected
        self.any_selected = False

        self.timer.setSingleShot(True)

    def update(self):
        """Refresh the tooltip from the node, then do the base update."""
        self.setToolTip(self.tree_node.tooltip)
        return super().update()

    def _resting_opacity(self, graphics_item):
        """Opacity of an item that is not dimmed by a hover.

        Unselected items are shown with `SELECTION_OPACITY` while anything is
        selected (as recorded on this item), otherwise with `MAX_OPACITY`.
        """
        if self.any_selected and not graphics_item.isSelected():
            return self.SELECTION_OPACITY
        return self.MAX_OPACITY

    def hoverEnterEvent(self, event):
        """Raise and highlight the hovered branch, dim everything else."""
        # cancel a pending clean up of a previously hovered item
        self.timer.stop()

        def fnc(graphics_item):
            graphics_item.setZValue(Z_STEP)
            graphics_item.setOpacity(self._resting_opacity(graphics_item))

        def other_fnc(graphics_item):
            if graphics_item.isSelected():
                opacity = self.MAX_OPACITY
            else:
                opacity = self.HOVER_OPACITY
            graphics_item.setOpacity(opacity)
            # every item returns to its own z value, not the hovered item's
            graphics_item.setZValue(graphics_item.initial_zvalue)

        self._propagate_z_values(self, fnc, other_fnc)

    def hoverLeaveEvent(self, event):
        """Schedule the clean up of the hover effects after a short delay."""

        def fnc(graphics_item):
            # No need to set opacity in this branch since it was just selected
            # and had the max value
            graphics_item.setZValue(graphics_item.initial_zvalue)

        def other_fnc(graphics_item):
            graphics_item.setOpacity(self._resting_opacity(graphics_item))

        def restore():
            self._propagate_z_values(self, fnc, other_fnc)

        # The timer is shared, so drop the callback of the previous hover;
        # otherwise callbacks would pile up and all run on every timeout.
        cls = InteractiveSquareGraphicsItem
        if cls._pending_restore is not None:
            try:
                self.timer.timeout.disconnect(cls._pending_restore)
            except (TypeError, RuntimeError):
                # the callback was not connected (anymore)
                pass
        cls._pending_restore = restore
        self.timer.timeout.connect(restore)

        self.timer.start(250)

    def _propagate_z_values(self, graphics_item, fnc, other_fnc):
        """Apply `fnc` to the branch of the item, `other_fnc` to the rest."""
        self._propagate_to_children(graphics_item, fnc)
        self._propagate_to_parents(graphics_item, fnc, other_fnc)

    def _propagate_to_children(self, graphics_item, fnc):
        # propagate function that handles graphics item to appropriate children
        fnc(graphics_item)
        for c in graphics_item.tree_node.children:
            self._propagate_to_children(c.graphics_item, fnc)

    def _propagate_to_parents(self, graphics_item, fnc, other_fnc):
        # propagate function that handles graphics item to appropriate parents
        if graphics_item.tree_node.parent != TreeAdapter.ROOT_PARENT:
            parent = graphics_item.tree_node.parent.graphics_item
            # handle the non relevant children nodes
            for c in parent.tree_node.children:
                if c != graphics_item.tree_node:
                    self._propagate_to_children(c.graphics_item, other_fnc)
            # handle the parent node
            fnc(parent)
            # propagate up the tree
            self._propagate_to_parents(parent, fnc, other_fnc)

    def mouseDoubleClickEvent(self, event):
        """Reverse the children of the node and rebuild the whole tree."""
        self.tree_node.tree.reverse_children(self.tree_node.label)
        p = self.parentWidget()  # PythagorasTreeViewer
        p.set_tree(p.tree_adapter, p.weight_adjustment,
                   self.tree_node.target_class_index)
        widget = p.parent  # OWPythagorasTree
        widget._update_main_area()

    def selection_changed(self):
        """Handle selection changed."""
        self.any_selected = len(self.scene().selectedItems()) > 0
        if self.any_selected:
            if self.isSelected():
                self.setOpacity(self.MAX_OPACITY)
            else:
                # leave items dimmed by a hover as they are
                if self.opacity() != self.HOVER_OPACITY:
                    self.setOpacity(self.SELECTION_OPACITY)
        else:
            self.setGraphicsEffect(None)
            self.setOpacity(self.MAX_OPACITY)

    def paint(self, painter, option, widget=None):
        """Paint the square; selected squares get a thick black border."""
        # Override the default selected appearance
        if self.isSelected():
            option.state ^= QStyle.State_Selected
            rect = self.rect()
            # this must render before overlay due to order in which it's drawn
            super().paint(painter, option, widget)
            painter.save()
            pen = QPen(QColor(Qt.black))
            pen.setWidthF(2)
            pen.setCosmetic(True)
            pen.setJoinStyle(Qt.MiterJoin)
            painter.setPen(pen)
            painter.drawRect(rect)
            painter.restore()
        else:
            super().paint(painter, option, widget)


class TreeNode(metaclass=ABCMeta):
    """A tree node meant to be used in conjunction with graphics items.

    The tree node contains methods that are very general to any tree
    visualisation, containing methods for the node color and tooltip.

    This is an abstract class and not meant to be used by itself. There are two
    subclasses - `DiscreteTreeNode` and `ContinuousTreeNode`, which need no
    explanation. If you don't wish to deal with figuring out which node to use,
    the `from_tree` method is provided.

    Parameters
    ----------
    label : int
        The label of the tree node, can be looked up in the original tree.
    square : Square
        The square that represents the tree node.
    tree : TreeAdapter
        The tree model that the node belongs to.
    children : tuple of TreeNode, optional, default is empty tuple
        All the children that belong to this node.

    """

    def __init__(self, label, square, tree, children=()):
        self.label = label
        self.square = square
        self.tree = tree
        self.children = children
        self.parent = None
        # Properties that should update the associated graphics item
        self.__graphics_item = None
        self.__target_class_index = None

    @property
    def graphics_item(self):
        return self.__graphics_item

    @graphics_item.setter
    def graphics_item(self, graphics_item):
        self.__graphics_item = graphics_item
        self._update_graphics_item()

    @property
    def target_class_index(self):
        return self.__target_class_index

    @target_class_index.setter
    def target_class_index(self, target_class_index):
        self.__target_class_index = target_class_index
        self._update_graphics_item()

    def _update_graphics_item(self):
        """Call `update` on the associated graphics item, if there is one."""
        if self.__graphics_item is not None:
            self.__graphics_item.update()

    @classmethod
    def from_tree(cls, label, square, tree, children=()):
        """Construct the appropriate type of node from the given tree."""
        if tree.domain.has_discrete_class:
            node = DiscreteTreeNode
        else:
            node = ContinuousTreeNode
        return node(label, square, tree, children)

    @property
    @abstractmethod
    def color(self):
        """Get the color of the node.

        Returns
        -------
        QColor

        """

    @property
    @abstractmethod
    def tooltip(self):
        """get the tooltip for the node.

        Returns
        -------
        str

        """

    @property
    def color_palette(self):
        return self.tree.domain.class_var.palette

    def _rules_str(self):
        """Return the rules leading to this node as an HTML fragment."""
        rules = self.tree.rules(self.label)
        if not rules:
            return ''
        if isinstance(rules[0], Rule):
            # all but the last rule are sorted by attribute name; the last
            # rule is appended in bold
            sorted_rules = sorted(rules[:-1], key=lambda rule: rule.attr_name)
            return '<br>'.join(str(rule) for rule in sorted_rules) + \
                '<br><b>%s</b>' % rules[-1]
        return '<br>'.join(to_html(rule) for rule in rules)

    def _split_and_rules_html(self):
        """Return the tooltip part shared by all the node types.

        For leaves this is just the rules; for inner nodes the name of the
        splitting attribute comes first, separated from the rules (if any) by
        an empty line.
        """
        rules_str = self._rules_str()
        splitting_attr = self.tree.attribute(self.label)
        if self.tree.is_leaf(self.label):
            return rules_str
        separator = '<br><br>' if rules_str else ''
        return 'Split by ' + splitting_attr.name + separator + rules_str


class DiscreteTreeNode(TreeNode):
    """Discrete tree node containing methods for tree visualisations.

    Colors are defined by the data domain, and possible colorings are different
    target classes.

    """

    @property
    def color(self):
        distribution = self.tree.get_distribution(self.label)[0]
        # guard against division by zero for an empty distribution
        total = np.sum(distribution) or 1

        if self.target_class_index:
            # index 0 means "no target class", so classes are offset by one
            p = distribution[self.target_class_index - 1] / total
            color = self.color_palette[self.target_class_index - 1]
            # `QColor.lighter` takes an integer factor
            color = color.lighter(int(200 - 100 * p))
        else:
            modus = np.argmax(distribution)
            p = distribution[modus] / total
            color = self.color_palette[int(modus)]
            color = color.lighter(int(400 - 300 * p))
        return color

    @property
    def tooltip(self):
        distribution = self.tree.get_distribution(self.label)[0]
        distribution_sum = np.sum(distribution)
        total = int(distribution_sum)
        if self.target_class_index:
            samples = distribution[self.target_class_index - 1]
            text = ''
        else:
            modus = np.argmax(distribution)
            samples = distribution[modus]
            text = self.tree.domain.class_vars[0].values[modus] + \
                '<br>'
        # guard against division by zero for an empty distribution
        ratio = samples / (distribution_sum or 1)

        return '<p>' \
            + text \
            + '{}/{} samples ({:2.3f}%)'.format(
                int(samples), total, ratio * 100) \
            + '<hr>' \
            + self._split_and_rules_html() \
            + '</p>'


class ContinuousTreeNode(TreeNode):
    """Continuous tree node containing methods for tree visualisations.

    There are three modes of coloring:
     - None, which is a solid color
     - Mean, which colors nodes w.r.t. the mean value of all the
       instances that belong to a given node.
     - Standard deviation, which colors nodes w.r.t the standard deviation of
       all the instances that belong to a given node.

    """

    COLOR_NONE, COLOR_MEAN, COLOR_STD = range(3)
    COLOR_METHODS = {
        'None': COLOR_NONE,
        'Mean': COLOR_MEAN,
        'Standard deviation': COLOR_STD,
    }

    @property
    def color(self):
        # compare by value; identity checks on integers are unreliable
        if self.target_class_index == self.COLOR_MEAN:
            return self._color_mean()
        elif self.target_class_index == self.COLOR_STD:
            return self._color_var()
        else:
            return QColor(255, 255, 255)

    def _color_mean(self):
        """Color the nodes with respect to the mean of instances inside."""
        min_mean = np.min(self.tree.instances.Y)
        max_mean = np.max(self.tree.instances.Y)
        instances = self.tree.get_instances_in_nodes(self.label)
        mean = np.mean(instances.Y)
        return self.color_palette.value_to_qcolor(
            mean, low=min_mean, high=max_mean)

    def _color_var(self):
        """Color the nodes with respect to the variance of instances inside."""
        min_std, max_std = 0, np.std(self.tree.instances.Y)
        instances = self.tree.get_instances_in_nodes(self.label)
        std = np.std(instances.Y)
        return self.color_palette.value_to_qcolor(
            std, low=min_std, high=max_std)

    @property
    def tooltip(self):
        num_samples = self.tree.num_samples(self.label)

        instances = self.tree.get_instances_in_nodes(self.label)
        mean = np.mean(instances.Y)
        std = np.std(instances.Y)

        return '<p>Mean: {:2.3f}'.format(mean) \
            + '<br>Standard deviation: {:2.3f}'.format(std) \
            + '<br>{} samples'.format(num_samples) \
            + '<hr>' \
            + self._split_and_rules_html() \
            + '</p>'


class PythagorasTree:
    """Pythagoras tree.

    Contains all the logic that converts a given tree adapter to a tree
    consisting of node classes.

    Parameters
    ----------
    weight_adjustment : callable
        The function to be used to adjust child weights

    """

    def __init__(self, weight_adjustment=lambda x: x):
        self.adjust_weight = weight_adjustment
        # store the previous angles of each square children so that slopes can
        # be computed
        self._slopes = defaultdict(list)

    def pythagoras_tree(self, tree, node, square):
        """Get the Pythagoras tree representation in a graph like view.

        Constructs a graph using TreeNode into a tree structure. Each node in
        graph contains the information required to plot the tree.

        Parameters
        ----------
        tree : TreeAdapter
            A tree adapter instance where the original tree is stored.
        node : int
            The node label, the root node is denoted with 0.
        square : Square
            The initial square which will represent the root of the tree.

        Returns
        -------
        TreeNode
            The root node which contains the rest of the tree.

        """
        # make sure to clear out any old slopes if we are drawing a new tree
        if node == tree.root:
            self._slopes.clear()

        # Calculate the adjusted child weights for the node children
        child_labels = list(tree.children(node))
        child_weights = [self.adjust_weight(tree.weight(c))
                         for c in child_labels]
        total_weight = sum(child_weights)
        if total_weight:
            normalized_child_weights = [
                cw / total_weight for cw in child_weights]
        else:
            # Either there are no children, or all the adjusted weights are
            # zero; in the latter case split the space equally instead of
            # dividing by zero.
            num_children = len(child_labels)
            normalized_child_weights = \
                [1 / num_children] * num_children if num_children else []

        children = tuple(
            self._compute_child(tree, square, child, cw)
            for child, cw in zip(child_labels, normalized_child_weights)
        )
        # make sure to pass a reference to parent to each child
        obj = TreeNode.from_tree(node, square, tree, children)
        # mutate the existing data stored in the created tree node
        for c in children:
            c.parent = obj
        return obj

    def _compute_child(self, tree, parent_square, node, weight):
        """Compute all the properties for a single child.

        Parameters
        ----------
        tree : TreeAdapter
            A tree adapter instance where the original tree is stored.
        parent_square : Square
            The parent square of the given child.
        node : int
            The node label of the child.
        weight : float
            The weight of the node relative to its parent e.g. two children in
            relation 3:1 should have weights .75 and .25, respectively.

        Returns
        -------
        TreeNode
            The tree node representation of the given child with the computed
            subtree.

        """
        # the angle of the child from its parent
        alpha = weight * pi
        # the child side length
        length = parent_square.length * sin(alpha / 2)
        # the sum of the previous angles
        prev_angles = sum(self._slopes[parent_square])

        center = self._compute_center(
            parent_square, length, alpha, prev_angles
        )
        # the angle of the square is dependent on the parent, the current
        # angle and the previous angles. Subtract PI/2 so it starts drawing at
        # 0rads.
        angle = parent_square.angle - pi / 2 + prev_angles + alpha / 2
        square = Square(center, length, angle)

        self._slopes[parent_square].append(alpha)

        return self.pythagoras_tree(tree, node, square)

    def _compute_center(self, initial_square, length, alpha, base_angle=0):
        """Compute the central point of a child square.

        Parameters
        ----------
        initial_square : Square
            The parent square representation where we will be drawing from.
        length : float
            The length of the side of the new square (the one we are computing
            the center for).
        alpha : float
            The angle that defines the size of our new square (in radians).
        base_angle : float, optional
            If the square we want to find the center for is not the first child
            i.e. its edges does not touch the base square, then we need the
            initial angle that will act as the starting point for the new
            square.

        Returns
        -------
        Point
            The central point to the new square.

        """
        parent_center, parent_length, parent_angle = initial_square
        # get the point on the square side that will be the rotation origin
        t0 = self._get_point_on_square_edge(
            parent_center, parent_length, parent_angle)
        # get the edge point that we will rotate around t0
        square_diagonal_length = sqrt(2 * parent_length ** 2)
        edge = self._get_point_on_square_edge(
            parent_center, square_diagonal_length, parent_angle - pi / 4)
        # if the new square is not the first child, we need to rotate the edge
        if base_angle != 0:
            edge = self._rotate_point(edge, t0, base_angle)

        # rotate the edge point to the correct spot
        t1 = self._rotate_point(edge, t0, alpha)

        # calculate the middle point between the rotated point and edge
        t2 = Point((t1.x + edge.x) / 2, (t1.y + edge.y) / 2)
        # calculate the slope of the new square
        slope = parent_angle - pi / 2 + alpha / 2
        # using this data, we can compute the square center
        return self._get_point_on_square_edge(t2, length, slope + base_angle)

    @staticmethod
    def _rotate_point(point, around, alpha):
        """Rotate a point around another point by some angle.

        Parameters
        ----------
        point : Point
            The point to rotate.
        around : Point
            The point to perform rotation around.
        alpha : float
            The angle to rotate by (in radians).

        Returns
        -------
        Point:
            The rotated point.

        """
        # translate so that `around` is the origin, rotate, translate back
        temp = Point(point.x - around.x, point.y - around.y)
        temp = Point(
            temp.x * cos(alpha) - temp.y * sin(alpha),
            temp.x * sin(alpha) + temp.y * cos(alpha)
        )
        return Point(temp.x + around.x, temp.y + around.y)

    @staticmethod
    def _get_point_on_square_edge(center, length, angle):
        """Calculate the central point on the drawing edge of the given square.

        Parameters
        ----------
        center : Point
            The square center point.
        length : float
            The square side length.
        angle : float
            The angle of the square.

        Returns
        -------
        Point
            A point on the center of the drawing edge of the given square.

        """
        return Point(
            center.x + length / 2 * cos(angle),
            center.y + length / 2 * sin(angle)
        )