1"""
2Pythagoras tree viewer for visualizing tree structures.
3
The Pythagoras tree viewer widget is a widget that can be plugged into any
existing widget given a tree adapter instance. It is simply a canvas that takes
an input tree adapter and takes care of all the drawing.
7
8Types
9-----
10Square : namedtuple (center, length, angle)
11    Since Pythagoras trees deal only with squares (they also deal with
12    rectangles in the generalized form, but are completely unreadable), this
13    is what all the squares are stored as.
14Point : namedtuple (x, y)
    Self-explanatory.
16
17"""
18from abc import ABCMeta, abstractmethod
19from collections import namedtuple, defaultdict, deque
20from math import pi, sqrt, cos, sin, degrees
21
22import numpy as np
23from AnyQt.QtCore import Qt, QTimer, QRectF, QSizeF
24from AnyQt.QtGui import QColor, QPen
25from AnyQt.QtWidgets import (
26    QSizePolicy, QGraphicsItem, QGraphicsRectItem, QGraphicsWidget, QStyle
27)
28
29from Orange.widgets.utils import to_html
30from Orange.widgets.visualize.utils.tree.rules import Rule
31from Orange.widgets.visualize.utils.tree.treeadapter import TreeAdapter
32
# Size of the z-value range handed out to squares. Every tree level splits
# its parent's range evenly among the children, so increase it if needed.
Z_STEP = 5000000

# Every square is fully described by its center point, its side length and
# its rotation angle (radians).
Square = namedtuple('Square', 'center length angle')
# A plain 2D point.
Point = namedtuple('Point', 'x y')
38
39
class PythagorasTreeViewer(QGraphicsWidget):
    """Pythagoras tree viewer graphics widget.

    Examples
    --------
    >>> from Orange.widgets.visualize.utils.tree.treeadapter import (
    ...     TreeAdapter
    ... )
    >>> tree_adapter = TreeAdapter()
    >>> scene = QGraphicsScene()

    Pass the tree through the constructor.

    >>> tree_view = PythagorasTreeViewer(parent=scene, adapter=tree_adapter)

    Or pass the tree later through a method.

    >>> tree_view = PythagorasTreeViewer(parent=scene)
    >>> tree_view.set_tree(tree_adapter)

    Both examples set the tree and add all of its squares to the widget
    instance.

    Parameters
    ----------
    parent : object, optional
        The object that owns this viewer. It is stored as the ``parent``
        attribute (shadowing ``QObject.parent()``) and is NOT passed on to Qt.
        Interactive squares call ``parent._update_main_area()`` after a
        double click, so in practice this is the owning widget.
    adapter : TreeAdapter, optional
        Any valid tree adapter instance. If given, the tree is drawn
        immediately.
    depth_limit : int, optional
        The maximum depth that is drawn. When `adapter` is given, the limit
        is applied after the whole tree has been drawn once.
    padding : int, optional
        Padding added around the bounding rectangle of the squares.
    interactive : bool, optional
        Whether the widget should have an interactive display. This means
        special hover effects and selectable boxes. Default is True.
    target_class_index : int, optional
        Initial target class index. Only used when `adapter` is given.
    weight_adjustment : callable, optional
        Initial weight adjustment function. Only used when `adapter` is given.
        Defaults to the identity.

    Notes
    -----
    .. note:: The class contains two clear methods: `clear` and `clear_tree`.
        Each has its own use.
        `clear_tree` clears out the tree and removes all graphics items.
        `clear` does the same and additionally resets the remaining settings
        (the target class index).

        This is useful because when we want to change the size calculation of
        the Pythagoras tree, we just want to clear the scene. It would be
        inconvenient to have to set the remaining settings again.
        On the other hand, when we want to draw a brand new tree, it is best
        to clear all settings to avoid any strange bugs. That way we start
        with a blank slate.

    """

    def __init__(self, parent=None, adapter=None, depth_limit=0, padding=0,
                 **kwargs):
        super().__init__()
        # NOTE: this attribute shadows QObject.parent(); squares reach the
        # owning widget through it.
        self.parent = parent

        # In case a tree was passed, it will be handled at the end of init
        self.tree_adapter = None
        self.root = None
        # Identity until a tree is set; `set_tree` replaces it
        self.weight_adjustment = lambda x: x

        self._depth_limit = depth_limit
        self._interactive = kwargs.get('interactive', True)
        self._padding = padding

        # node label -> square item, for every square created for the current
        # tree (both visible and hidden ones)
        self._square_objects = {}
        # (depth, node) pairs that are currently drawn, in BFS order
        self._drawn_nodes = deque()
        # (depth, node) pairs that are next in line to be drawn, in BFS order
        self._frontier = deque()

        self._target_class_index = 0

        self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

        # If a tree was passed in the constructor, set and draw the tree
        if adapter is not None:
            self.set_tree(
                adapter,
                target_class_index=kwargs.get('target_class_index', 0),
                # may be None; `set_tree` then falls back to the identity
                weight_adjustment=kwargs.get('weight_adjustment'),
            )
            # Since `set_tree` needs to draw the entire tree to be visualized
            # properly, it overrides the `depth_limit` to max. Apply the
            # requested depth limit afterwards.
            self.set_depth_limit(depth_limit)

    def set_tree(self, tree_adapter, weight_adjustment=lambda x: x,
                 target_class_index=0):
        """Pass in a new tree adapter instance and perform updates to canvas.

        Parameters
        ----------
        tree_adapter : TreeAdapter or None
            The new tree adapter that is to be used. ``None`` only clears the
            current tree.
        weight_adjustment : callable or None
            Function applied to every child weight before the weights are
            normalized. ``None`` means the identity.
        target_class_index : int
            The target class index applied to every node.

        """
        self.clear_tree()
        self.tree_adapter = tree_adapter
        if weight_adjustment is None:
            weight_adjustment = lambda x: x  # noqa: E731
        self.weight_adjustment = weight_adjustment

        if self.tree_adapter is not None:
            self.root = self._calculate_tree(
                self.tree_adapter, self.weight_adjustment)
            # draw everything once so every node gets a graphics item
            self.set_depth_limit(tree_adapter.max_depth)
            self.target_class_changed(target_class_index)
            self._draw_tree(self.root)

    def set_size_calc(self, weight_adjustment):
        """Set the weight adjustment on the tree. Redraws the whole tree."""
        # Since we have to redraw the whole tree anyways, just call `set_tree`
        self.weight_adjustment = weight_adjustment
        self.set_tree(self.tree_adapter, self.weight_adjustment,
                      self._target_class_index)

    def set_depth_limit(self, depth):
        """Update the drawing depth limit.

        Drawing stops when the depth is greater than the limit. This means
        that at depth 0, only the root node is drawn.

        Parameters
        ----------
        depth : int
            The maximum depth at which the nodes can still be drawn.

        """
        self._depth_limit = depth
        self._draw_tree(self.root)

    def target_class_changed(self, target_class_index=0):
        """Apply a new target class index to every node of the tree.

        Each node forwards the change to its graphics item (if it has one).
        When no tree is set, the index is only remembered.
        """
        self._target_class_index = target_class_index
        if self.root is None:
            return
        stack = [self.root]
        while stack:
            node = stack.pop()
            node.target_class_index = target_class_index
            stack.extend(node.children)

    def tooltip_changed(self, tooltip_enabled):
        """Enable or disable the tooltip on every square of the tree.

        Squares hidden by the depth limit are updated too. Otherwise they
        would show a stale tooltip state once they become visible again.
        """
        for square in self._square_objects.values():
            square.setToolTip(
                square.tree_node.tooltip if tooltip_enabled else None)

    def clear(self):
        """Clear the entire widget state."""
        self.clear_tree()
        self._target_class_index = 0

    def clear_tree(self):
        """Clear only the tree, keeping the remaining settings."""
        self.tree_adapter = None
        self.root = None
        self._clear_scene()

    def _calculate_tree(self, tree_adapter, weight_adjustment):
        """Actually calculate the tree squares"""
        tree_builder = PythagorasTree(weight_adjustment=weight_adjustment)
        return tree_builder.pythagoras_tree(
            tree_adapter, tree_adapter.root, Square(Point(0, 0), 200, -pi / 2)
        )

    def _draw_tree(self, root):
        """Efficiently draw the tree with regards to the depth.

        If we used a recursive approach, the tree would have to be redrawn
        every time the depth changed. That is very impractical for larger
        trees, since drawing can take a long time.

        The iterative approach uses two queues to represent the tree frontier
        and the nodes that have already been drawn. When the max depth is
        increased, we do not redraw the entire tree. We only iterate through
        the frontier, draw those nodes and update the frontier accordingly.
        When decreasing the max depth, we reverse the process. We clear the
        frontier and remove (hide) nodes from the drawn nodes. Those with
        depth max_depth + 1 are put back onto the frontier, so the frontier
        doesn't get cluttered.

        Parameters
        ----------
        root : TreeNode
            The root tree node.

        """
        if self.root is None:
            return
        # Seed the frontier with the root only when nothing is queued yet, so
        # the root can never be enqueued twice
        if not self._drawn_nodes and not self._frontier:
            self._frontier.appendleft((0, root))
        # if the depth was decreased, we can clear the frontier, otherwise
        # frontier gets cluttered with non-frontier nodes
        if self._depth_was_decreased():
            self._frontier.clear()
        # remove nodes from drawn and add to frontier if limit is decreased
        while self._drawn_nodes:
            depth, node = self._drawn_nodes.pop()
            # check if the node is in the allowed limit
            if depth <= self._depth_limit:
                self._drawn_nodes.append((depth, node))
                break
            if depth == self._depth_limit + 1:
                self._frontier.appendleft((depth, node))

            if node.label in self._square_objects:
                self._square_objects[node.label].hide()

        # add nodes to drawn and remove from frontier if limit is increased
        while self._frontier:
            depth, node = self._frontier.popleft()
            # check if the depth of the node is outside the allowed limit
            if depth > self._depth_limit:
                self._frontier.appendleft((depth, node))
                break
            self._drawn_nodes.append((depth, node))
            self._frontier.extend((depth + 1, c) for c in node.children)

            node.target_class_index = self._target_class_index
            if node.label in self._square_objects:
                self._square_objects[node.label].show()
            else:
                square_obj = InteractiveSquareGraphicsItem \
                    if self._interactive else SquareGraphicsItem
                self._square_objects[node.label] = square_obj(
                    node, parent=self, zvalue=depth)

    def _depth_was_decreased(self):
        """Return True if the last drawn node lies beyond the depth limit."""
        if not self._drawn_nodes:
            return False
        # nodes are drawn in BFS order, so the right-most one is the deepest
        depth, _ = self._drawn_nodes[-1]
        return depth > self._depth_limit

    def _squares(self):
        """Return the graphics items of the currently drawn nodes."""
        return [node.graphics_item for _, node in self._drawn_nodes]

    def _clear_scene(self):
        """Remove every square of the current tree and reset draw state."""
        scene = self.scene()
        # Walk all created squares, not only the drawn ones. Squares hidden by
        # a lower depth limit would otherwise be left behind in the scene.
        for square in self._square_objects.values():
            if scene is not None:
                scene.removeItem(square)
            else:
                square.setParentItem(None)
        self._frontier.clear()
        self._drawn_nodes.clear()
        self._square_objects.clear()

    def boundingRect(self):
        return self.childrenBoundingRect().adjusted(
            -self._padding, -self._padding, self._padding, self._padding)

    def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs):
        return self.boundingRect().size() + QSizeF(self._padding, self._padding)
304
305
class SquareGraphicsItem(QGraphicsRectItem):
    """Square Graphics Item.

    Square component used to draw the non-interactive Pythagoras tree.

    Parameters
    ----------
    tree_node : TreeNode
        The tree node the square represents. The item registers itself as
        the node's ``graphics_item``.
    parent : QGraphicsItem
        The parent graphics item.
    brush : optional
        If given, it is used as the fill instead of the node's color.
    zvalue : float, optional
        Initial z value. It is only kept for squares whose node has no parent;
        every other square derives its z value from its parent's.

    """

    def __init__(self, tree_node, parent=None, **kwargs):
        self.tree_node = tree_node
        super().__init__(self._get_rect_attributes(), parent)

        # Rotate the square around its own center
        self.setTransformOriginPoint(self.boundingRect().center())
        self.setRotation(degrees(self.tree_node.square.angle))

        # The border should be invariant to scaling
        pen = QPen(QColor(Qt.black))
        pen.setWidthF(0.75)
        pen.setCosmetic(True)
        self.setPen(pen)

        self.setAcceptHoverEvents(True)
        self.setZValue(kwargs.get('zvalue', 0))
        self.z_step = Z_STEP

        # calculate the correct z values based on the parent: each child gets
        # an equal slice of its parent's z range, in the order of the children
        parent_node = self.tree_node.parent
        if parent_node != TreeAdapter.ROOT_PARENT:
            siblings = parent_node.children
            own_index = next(
                i for i, c in enumerate(siblings)
                if c.label == self.tree_node.label
            )
            parent_item = parent_node.graphics_item
            self.z_step = parent_item.z_step // len(siblings)
            self.setZValue(parent_item.zValue() + own_index * self.z_step)

        # Registering with the node triggers `update()`, which fills the square
        # with the node's color. It is done last so that no default brush
        # overwrites that color. An explicitly passed brush still wins.
        self.tree_node.graphics_item = self
        if 'brush' in kwargs:
            self.setBrush(kwargs['brush'])

    def update(self):
        """Refresh the fill from the node's color and schedule a repaint."""
        self.setBrush(self.tree_node.color)
        return super().update()

    def _get_rect_attributes(self):
        """Get the rectangle attributes required to draw the item.

        Compute the (unrotated) QRectF the QGraphicsRectItem is rendered with,
        centered on the node square's center with its side length.

        """
        center, length, _ = self.tree_node.square
        x = center[0] - length / 2
        y = center[1] - length / 2
        return QRectF(x, y, length, length)
367
368
class InteractiveSquareGraphicsItem(SquareGraphicsItem):
    """Interactive square graphics items.

    This is different from the base square graphics item so that it is
    selectable, and it can handle and react to hover events (highlight and
    focus own branch).

    Parameters
    ----------
    tree_node : TreeNode
        The tree node the square represents.
    parent : QGraphicsItem

    """

    # A single timer shared by all interactive squares. Entering any square
    # stops it, cancelling the delayed hover-leave restoration that another
    # square may have scheduled.
    timer = QTimer()

    MAX_OPACITY = 1.
    SELECTION_OPACITY = .5
    HOVER_OPACITY = .1

    def __init__(self, tree_node, parent=None, **kwargs):
        super().__init__(tree_node, parent, **kwargs)
        self.setFlag(QGraphicsItem.ItemIsSelectable, True)

        # Remember the computed z value so hover effects can restore it
        self.initial_zvalue = self.zValue()
        # The max z value changes if any item is selected
        self.any_selected = False

        self.timer.setSingleShot(True)

    def update(self):
        """Refresh the tooltip from the node, then the fill color."""
        self.setToolTip(self.tree_node.tooltip)
        return super().update()

    def hoverEnterEvent(self, event):
        """Bring this square's branch to the front and fade the others."""
        # cancel a pending hover-leave restoration (possibly of another item)
        self.timer.stop()

        def fnc(graphics_item):
            # lift the hovered branch (descendants and ancestors) to the top
            graphics_item.setZValue(Z_STEP)
            if self.any_selected:
                if graphics_item.isSelected():
                    opacity = self.MAX_OPACITY
                else:
                    opacity = self.SELECTION_OPACITY
            else:
                opacity = self.MAX_OPACITY
            graphics_item.setOpacity(opacity)

        def other_fnc(graphics_item):
            if graphics_item.isSelected():
                opacity = self.MAX_OPACITY
            else:
                opacity = self.HOVER_OPACITY
            graphics_item.setOpacity(opacity)
            # every other square goes back to its own stacking position
            graphics_item.setZValue(graphics_item.initial_zvalue)

        self._propagate_z_values(self, fnc, other_fnc)

    def hoverLeaveEvent(self, event):
        """Schedule restoring z values and opacities after a short delay."""

        def fnc(graphics_item):
            # Opacity in this branch was already set on hover enter, so only
            # restore each square's own stacking position
            graphics_item.setZValue(graphics_item.initial_zvalue)

        def other_fnc(graphics_item):
            if self.any_selected:
                if graphics_item.isSelected():
                    opacity = self.MAX_OPACITY
                else:
                    opacity = self.SELECTION_OPACITY
            else:
                opacity = self.MAX_OPACITY
            graphics_item.setOpacity(opacity)

        # The timer is shared, so drop whatever an earlier hover-leave
        # connected before connecting ours. Otherwise every hover-leave would
        # add one more slot and all of them would fire on each timeout.
        try:
            self.timer.timeout.disconnect()
        except (TypeError, RuntimeError):
            pass  # nothing was connected yet
        self.timer.timeout.connect(
            lambda: self._propagate_z_values(self, fnc, other_fnc))

        self.timer.start(250)

    def _propagate_z_values(self, graphics_item, fnc, other_fnc):
        """Apply `fnc` to the item's branch and `other_fnc` to the rest."""
        self._propagate_to_children(graphics_item, fnc)
        self._propagate_to_parents(graphics_item, fnc, other_fnc)

    def _propagate_to_children(self, graphics_item, fnc):
        # propagate function that handles graphics item to appropriate children
        fnc(graphics_item)
        for c in graphics_item.tree_node.children:
            self._propagate_to_children(c.graphics_item, fnc)

    def _propagate_to_parents(self, graphics_item, fnc, other_fnc):
        # propagate function that handles graphics item to appropriate parents
        if graphics_item.tree_node.parent != TreeAdapter.ROOT_PARENT:
            parent = graphics_item.tree_node.parent.graphics_item
            # handle the non relevant children nodes
            for c in parent.tree_node.children:
                if c != graphics_item.tree_node:
                    self._propagate_to_children(c.graphics_item, other_fnc)
            # handle the parent node
            fnc(parent)
            # propagate up the tree
            self._propagate_to_parents(parent, fnc, other_fnc)

    def mouseDoubleClickEvent(self, event):
        """Call ``reverse_children`` on this node and rebuild the tree."""
        self.tree_node.tree.reverse_children(self.tree_node.label)
        p = self.parentWidget()  # PythagorasTreeViewer
        p.set_tree(p.tree_adapter, p.weight_adjustment,
                   self.tree_node.target_class_index)
        widget = p.parent  # OWPythagorasTree
        widget._update_main_area()

    def selection_changed(self):
        """Handle selection changed."""
        self.any_selected = len(self.scene().selectedItems()) > 0
        if self.any_selected:
            if self.isSelected():
                self.setOpacity(self.MAX_OPACITY)
            else:
                if self.opacity() != self.HOVER_OPACITY:
                    self.setOpacity(self.SELECTION_OPACITY)
        else:
            self.setGraphicsEffect(None)
            self.setOpacity(self.MAX_OPACITY)

    def paint(self, painter, option, widget=None):
        """Paint the square, drawing selection as a thick black border."""
        # Override the default selected appearance
        if self.isSelected():
            option.state ^= QStyle.State_Selected
            rect = self.rect()
            # this must render before overlay due to order in which it's drawn
            super().paint(painter, option, widget)
            painter.save()
            pen = QPen(QColor(Qt.black))
            pen.setWidthF(2)
            pen.setCosmetic(True)
            pen.setJoinStyle(Qt.MiterJoin)
            painter.setPen(pen)
            painter.drawRect(rect)
            painter.restore()
        else:
            super().paint(painter, option, widget)
510
511
class TreeNode(metaclass=ABCMeta):
    """A tree node meant to be used in conjunction with graphics items.

    The tree node holds what every tree visualisation needs: the node's
    square, its place in the tree, and the abstract color and tooltip.

    This is an abstract class and not meant to be used by itself. It has two
    concrete subclasses, `DiscreteTreeNode` and `ContinuousTreeNode`. The
    `from_tree` factory picks the right one for a given tree.

    Parameters
    ----------
    label : int
        The label of the tree node, can be looked up in the original tree.
    square : Square
        The square that represents the tree node.
    tree : TreeAdapter
        The tree model that the node belongs to.
    children : tuple of TreeNode, optional, default is empty tuple
        All the children that belong to this node.

    """

    def __init__(self, label, square, tree, children=()):
        self.label, self.square, self.tree = label, square, tree
        self.children = children
        self.parent = None
        # Backing fields of the two properties below. Assigning either
        # property refreshes the attached graphics item.
        self.__graphics_item = None
        self.__target_class_index = None

    @property
    def graphics_item(self):
        """The graphics item drawing this node, or None."""
        return self.__graphics_item

    @graphics_item.setter
    def graphics_item(self, item):
        self.__graphics_item = item
        self._update_graphics_item()

    @property
    def target_class_index(self):
        """The target class index the node is colored by."""
        return self.__target_class_index

    @target_class_index.setter
    def target_class_index(self, index):
        self.__target_class_index = index
        self._update_graphics_item()

    def _update_graphics_item(self):
        """Ask the attached graphics item (if any) to refresh itself."""
        item = self.__graphics_item
        if item is None:
            return
        item.update()

    @classmethod
    def from_tree(cls, label, square, tree, children=()):
        """Construct the appropriate type of node from the given tree."""
        node_cls = (DiscreteTreeNode if tree.domain.has_discrete_class
                    else ContinuousTreeNode)
        return node_cls(label, square, tree, children)

    @property
    @abstractmethod
    def color(self):
        """Get the color of the node.

        Returns
        -------
        QColor

        """

    @property
    @abstractmethod
    def tooltip(self):
        """Get the tooltip for the node.

        Returns
        -------
        str

        """

    @property
    def color_palette(self):
        """The palette of the tree domain's class variable."""
        return self.tree.domain.class_var.palette

    def _rules_str(self):
        """Render the node's rules as HTML lines ('' when there are none).

        For `Rule` objects, all but the last rule are sorted by attribute
        name and the last one is shown in bold.
        """
        rules = self.tree.rules(self.label)
        if not rules:
            return ''
        if not isinstance(rules[0], Rule):
            return '<br>'.join(map(to_html, rules))
        *conditions, last = rules
        ordered = sorted(conditions, key=lambda r: r.attr_name)
        return '<br>'.join(map(str, ordered)) + '<br><b>%s</b>' % last
614
615
class DiscreteTreeNode(TreeNode):
    """Discrete tree node containing methods for tree visualisations.

    Colors are defined by the data domain, and possible colorings are
    different target classes. A falsy `target_class_index` (0 or None) colors
    each node by its majority class. An index ``k > 0`` colors it by the
    proportion of class ``k - 1``.

    """

    @property
    def color(self):
        """The node color; lighter the smaller the represented proportion."""
        distribution = self.tree.get_distribution(self.label)[0]
        # Guard against nodes without samples (avoids 0/0 -> nan)
        total = np.sum(distribution) or 1

        if self.target_class_index:
            class_index = self.target_class_index - 1
            p = distribution[class_index] / total
            color = self.color_palette[class_index]
            # QColor.lighter takes an int factor: 200 (p=0) .. 100 (p=1)
            return color.lighter(int(200 - 100 * p))

        modus = np.argmax(distribution)
        p = distribution[modus] / total
        color = self.color_palette[int(modus)]
        # QColor.lighter takes an int factor: 400 (p=0) .. 100 (p=1)
        return color.lighter(int(400 - 300 * p))

    @property
    def tooltip(self):
        """HTML tooltip: class counts, the split attribute and the rules."""
        distribution = self.tree.get_distribution(self.label)[0]
        total = int(np.sum(distribution))
        if self.target_class_index:
            samples = distribution[self.target_class_index - 1]
            text = ''
        else:
            modus = np.argmax(distribution)
            samples = distribution[modus]
            text = self.tree.domain.class_vars[0].values[modus] + \
                '<br>'
        # Guard against nodes without samples (avoids 0/0 -> nan)
        ratio = samples / (np.sum(distribution) or 1)

        rules_str = self._rules_str()
        is_leaf = self.tree.is_leaf(self.label)
        split_str = '' if is_leaf \
            else 'Split by ' + self.tree.attribute(self.label).name

        return '<p>' \
            + text \
            + '{}/{} samples ({:2.3f}%)'.format(
                int(samples), total, ratio * 100) \
            + '<hr>' \
            + split_str \
            + ('<br><br>' if rules_str and not is_leaf else '') \
            + rules_str \
            + '</p>'
667
668
class ContinuousTreeNode(TreeNode):
    """Continuous tree node containing methods for tree visualisations.

    There are three modes of coloring, selected by `target_class_index`:
     - None, which is a solid (white) color
     - Mean, which colors nodes w.r.t. the mean value of all the
       instances that belong to a given node.
     - Standard deviation, which colors nodes w.r.t the standard deviation of
       all the instances that belong to a given node.

    """

    COLOR_NONE, COLOR_MEAN, COLOR_STD = range(3)
    COLOR_METHODS = {
        'None': COLOR_NONE,
        'Mean': COLOR_MEAN,
        'Standard deviation': COLOR_STD,
    }

    @property
    def color(self):
        """The node color according to the selected coloring mode."""
        # Compare by value: the index may be e.g. a numpy integer, for which
        # an identity (`is`) check against a Python int fails
        if self.target_class_index == self.COLOR_MEAN:
            return self._color_mean()
        if self.target_class_index == self.COLOR_STD:
            return self._color_var()
        return QColor(255, 255, 255)

    def _color_mean(self):
        """Color the nodes with respect to the mean of instances inside."""
        min_mean = np.min(self.tree.instances.Y)
        max_mean = np.max(self.tree.instances.Y)
        instances = self.tree.get_instances_in_nodes(self.label)
        mean = np.mean(instances.Y)
        return self.color_palette.value_to_qcolor(
            mean, low=min_mean, high=max_mean)

    def _color_var(self):
        """Color the nodes with respect to the standard deviation inside.

        The scale runs from 0 to the standard deviation of all instances.
        """
        min_std, max_std = 0, np.std(self.tree.instances.Y)
        instances = self.tree.get_instances_in_nodes(self.label)
        std = np.std(instances.Y)
        return self.color_palette.value_to_qcolor(
            std, low=min_std, high=max_std)

    @property
    def tooltip(self):
        """HTML tooltip: mean, std, sample count, split attribute, rules."""
        num_samples = self.tree.num_samples(self.label)

        instances = self.tree.get_instances_in_nodes(self.label)
        mean = np.mean(instances.Y)
        std = np.std(instances.Y)

        rules_str = self._rules_str()
        is_leaf = self.tree.is_leaf(self.label)
        split_str = '' if is_leaf \
            else 'Split by ' + self.tree.attribute(self.label).name

        return '<p>Mean: {:2.3f}'.format(mean) \
            + '<br>Standard deviation: {:2.3f}'.format(std) \
            + '<br>{} samples'.format(num_samples) \
            + '<hr>' \
            + split_str \
            + ('<br><br>' if rules_str and not is_leaf else '') \
            + rules_str \
            + '</p>'
734
735
class PythagorasTree:
    """Pythagoras tree.

    Contains all the logic that converts a given tree adapter to a tree
    consisting of node classes.

    Parameters
    ----------
    weight_adjustment : callable
        The function to be used to adjust child weights

    """

    def __init__(self, weight_adjustment=lambda x: x):
        self.adjust_weight = weight_adjustment
        # store the previous angles of each square children so that slopes can
        # be computed
        self._slopes = defaultdict(list)

    def pythagoras_tree(self, tree, node, square):
        """Get the Pythagoras tree representation in a graph like view.

        Constructs a graph using TreeNode into a tree structure. Each node in
        the graph contains the information required to plot the tree.

        Parameters
        ----------
        tree : TreeAdapter
            A tree adapter instance where the original tree is stored.
        node : int
            The node label, the root node is denoted with 0.
        square : Square
            The initial square which will represent the root of the tree.

        Returns
        -------
        TreeNode
            The root node which contains the rest of the tree.

        """
        # make sure to clear out any old slopes if we are drawing a new tree
        if node == tree.root:
            self._slopes.clear()

        # Materialize once: the labels are iterated twice below
        child_labels = list(tree.children(node))
        # Calculate the adjusted child weights for the node children
        child_weights = [self.adjust_weight(tree.weight(c))
                         for c in child_labels]
        total_weight = sum(child_weights)
        # When no child carries any weight, give every child weight 0 (a
        # zero-sized square, like any single zero-weight child) instead of
        # dividing by zero
        normalized_child_weights = [
            cw / total_weight if total_weight else 0 for cw in child_weights
        ]

        children = tuple(
            self._compute_child(tree, square, child, cw)
            for child, cw in zip(child_labels, normalized_child_weights)
        )
        # make sure to pass a reference to parent to each child
        obj = TreeNode.from_tree(node, square, tree, children)
        # mutate the existing data stored in the created tree node
        for c in children:
            c.parent = obj
        return obj

    def _compute_child(self, tree, parent_square, node, weight):
        """Compute all the properties for a single child.

        Parameters
        ----------
        tree : TreeAdapter
            A tree adapter instance where the original tree is stored.
        parent_square : Square
            The parent square of the given child.
        node : int
            The node label of the child.
        weight : float
            The weight of the node relative to its parent e.g. two children in
            relation 3:1 should have weights .75 and .25, respectively.

        Returns
        -------
        TreeNode
            The tree node representation of the given child with the computed
            subtree.

        """
        # the angle of the child from its parent
        alpha = weight * pi
        # the child side length
        length = parent_square.length * sin(alpha / 2)
        # the sum of the angles of the previously placed siblings
        prev_angles = sum(self._slopes[parent_square])

        center = self._compute_center(
            parent_square, length, alpha, prev_angles
        )
        # the angle of the square is dependent on the parent, the current
        # angle and the previous angles. Subtract PI/2 so it starts drawing at
        # 0rads.
        angle = parent_square.angle - pi / 2 + prev_angles + alpha / 2
        square = Square(center, length, angle)

        self._slopes[parent_square].append(alpha)

        return self.pythagoras_tree(tree, node, square)

    def _compute_center(self, initial_square, length, alpha, base_angle=0):
        """Compute the central point of a child square.

        Parameters
        ----------
        initial_square : Square
            The parent square representation where we will be drawing from.
        length : float
            The length of the side of the new square (the one we are computing
            the center for).
        alpha : float
            The angle that defines the size of our new square (in radians).
        base_angle : float, optional
            If the square we want to find the center for is not the first child
            i.e. its edges does not touch the base square, then we need the
            initial angle that will act as the starting point for the new
            square.

        Returns
        -------
        Point
            The central point to the new square.

        """
        parent_center, parent_length, parent_angle = initial_square
        # get the point on the square side that will be the rotation origin
        t0 = self._get_point_on_square_edge(
            parent_center, parent_length, parent_angle)
        # get the edge point that we will rotate around t0
        square_diagonal_length = sqrt(2 * parent_length ** 2)
        edge = self._get_point_on_square_edge(
            parent_center, square_diagonal_length, parent_angle - pi / 4)
        # if the new square is not the first child, we need to rotate the edge
        if base_angle != 0:
            edge = self._rotate_point(edge, t0, base_angle)

        # rotate the edge point to the correct spot
        t1 = self._rotate_point(edge, t0, alpha)

        # calculate the middle point between the rotated point and edge
        t2 = Point((t1.x + edge.x) / 2, (t1.y + edge.y) / 2)
        # calculate the slope of the new square
        slope = parent_angle - pi / 2 + alpha / 2
        # using this data, we can compute the square center
        return self._get_point_on_square_edge(t2, length, slope + base_angle)

    @staticmethod
    def _rotate_point(point, around, alpha):
        """Rotate a point around another point by some angle.

        Parameters
        ----------
        point : Point
            The point to rotate.
        around : Point
            The point to perform rotation around.
        alpha : float
            The angle to rotate by (in radians).

        Returns
        -------
        Point:
            The rotated point.

        """
        temp = Point(point.x - around.x, point.y - around.y)
        temp = Point(
            temp.x * cos(alpha) - temp.y * sin(alpha),
            temp.x * sin(alpha) + temp.y * cos(alpha)
        )
        return Point(temp.x + around.x, temp.y + around.y)

    @staticmethod
    def _get_point_on_square_edge(center, length, angle):
        """Calculate the central point on the drawing edge of the given square.

        Parameters
        ----------
        center : Point
            The square center point.
        length : float
            The square side length.
        angle : float
            The angle of the square.

        Returns
        -------
        Point
            A point on the center of the drawing edge of the given square.

        """
        return Point(
            center.x + length / 2 * cos(angle),
            center.y + length / 2 * sin(angle)
        )
934