from collections import OrderedDict, deque

from AnyQt.QtGui import (
    QBrush, QPen, QColor, QPainter, QPainterPath, QTransform
)
from AnyQt.QtWidgets import (
    QGraphicsItem, QGraphicsEllipseItem, QGraphicsTextItem,
    QGraphicsLineItem, QGraphicsScene, QGraphicsView, QStyle, QSizePolicy,
    QFormLayout
)
from AnyQt.QtCore import (
    Qt, QRectF, QSize, QPointF, QLineF, QTimer,
    pyqtSignal, pyqtProperty
)

from Orange.widgets import gui
from Orange.widgets.widget import OWWidget
from Orange.widgets.settings import Setting
19
# Default dark-gray brush for expand/collapse droplets. Nothing in this
# module reads it (GraphicsDroplet sets its own gray brush); presumably kept
# for code that imports it from here -- TODO(review) confirm before removing.
DefDroppletBrush = QBrush(Qt.darkGray)
21
22
class GraphNode:
    """Graph node that remembers its incident edges in insertion order.

    Edges are kept as the keys of an ordered mapping (the values are unused),
    which gives cheap membership plus a deterministic order -- needed so that
    trees are laid out the same way every time.
    """

    def __init__(self, *_, **kwargs):
        # Callers may hand in their own mapping through the ``edges`` keyword.
        if "edges" in kwargs:
            edge_map = kwargs["edges"]
        else:
            edge_map = OrderedDict()
        self.__edges = edge_map

    def graph_edges(self):
        """Return a live view of the edges attached to this node."""
        return self.__edges.keys()

    def graph_add_edge(self, edge):
        """Attach *edge* to this node (re-adding keeps its original slot)."""
        edge_map = self.__edges
        edge_map[edge] = 0

    def __iter__(self):
        """Yield ``node2`` of every attached edge, in insertion order."""
        yield from (edge.node2 for edge in self.__edges.keys())

    def graph_nodes(self, atype=1):
        """Placeholder hook; always returns None."""
        return None
43
44
class GraphEdge:
    """Edge from ``node1`` to ``node2`` with a ``type`` tag (default 1).

    Constructing an edge registers it with each endpoint that is not None,
    ``node1`` first and then ``node2``.
    """

    def __init__(self, node1=None, node2=None, atype=1):
        self.node1, self.node2, self.type = node1, node2, atype
        for endpoint in (node1, node2):
            if endpoint is not None:
                endpoint.graph_add_edge(self)
54
55
class GraphicsDroplet(QGraphicsEllipseItem):
    """Small ellipse under a node that toggles the node's open state on click."""

    def __init__(self, *args):
        super().__init__(*args)
        self.setAcceptHoverEvents(True)
        self.setAcceptedMouseButtons(Qt.LeftButton)
        self.setBrush(QBrush(Qt.gray))
        self.setPen(Qt.white)

    def hoverEnterEvent(self, event):
        super().hoverEnterEvent(event)
        # Darker fill while the cursor is over the droplet.
        hover_fill = QColor(100, 100, 100)
        self.setBrush(QBrush(hover_fill))
        self.update()

    def hoverLeaveEvent(self, event):
        super().hoverLeaveEvent(event)
        # Lighter fill once the cursor leaves.
        idle_fill = QColor(200, 200, 200)
        self.setBrush(QBrush(idle_fill))
        self.update()

    def mousePressEvent(self, event):
        super().mousePressEvent(event)
        # The owning node flips between open and closed, then the scene
        # (if any) re-runs its layout.
        owner = self.parentItem()
        owner.set_open(not owner.isOpen)
        scene = self.scene()
        if scene:
            scene.fix_pos()
79
80
81# noinspection PyPep8Naming
class TextTreeNode(QGraphicsTextItem, GraphNode):
    """Rich-text tree node drawn as a rounded box with an expand droplet.

    The node's box is either an explicitly set rect (see :meth:`set_rect`)
    or, when none is set, the size of its text document.
    """

    def setBackgroundBrush(self, brush):
        """Set the fill brush and pick black or white text for contrast.

        *brush* may be anything ``QBrush`` accepts (QBrush, QColor, ...).
        """
        if self._background_brush != brush:
            self._background_brush = QBrush(brush)
            # Take the colour from the normalised QBrush so that plain
            # colours (which have no ``.color()`` method) also work.
            color = self._background_brush.color()
            r, g, b, _ = color.getRgb()
            # Luminance with BT.709 weights decides the text colour.
            lum = 0.2126 * r + 0.7152 * g + 0.0722 * b
            if lum > 100:
                self.setDefaultTextColor(Qt.black)
            else:
                self.setDefaultTextColor(Qt.white)
            self.update()

    def backgroundBrush(self):
        """Return the node's fill brush, falling back to the scene default."""
        brush = getattr(self, "_background_brush", None)
        if brush is None:
            # The scene may provide a shared default; otherwise no fill.
            brush = getattr(self.scene(), "defaultItemBrush", Qt.NoBrush)
        return QBrush(brush)

    backgroundBrush = pyqtProperty(
        "QBrush", fget=backgroundBrush, fset=setBackgroundBrush,
        doc="Background brush")

    def __init__(self, parent, *args, **kwargs):
        """Create a node.

        *parent* is stored as the tree parent (``self.parent``); positional
        ``args`` go to QGraphicsTextItem and keyword ``kwargs`` to GraphNode.
        """
        QGraphicsTextItem.__init__(self, *args)
        GraphNode.__init__(self, **kwargs)
        self._background_brush = None
        # None means "no explicit rect"; rect()/boundingRect() handle it.
        self._rect = None

        self.parent = parent
        font = self.font()
        font.setPointSize(10)
        self.setFont(font)
        self.droplet = GraphicsDroplet(-5, 0, 10, 10, self)
        self.droplet.setPos(self.rect().center().x(), self.rect().height())
        self.document().contentsChanged.connect(self.update_contents)
        self.isOpen = True
        self.setFlag(QGraphicsItem.ItemIsSelectable, True)

    def setHtml(self, html):
        """Set the node text from an HTML fragment (wrapped in <body>)."""
        return super().setHtml("<body>" + html + "</body>")

    def update_contents(self):
        """Reflow the text to its ideal width and reposition the droplet."""
        # Reset first so idealWidth() is measured without a width limit.
        self.setTextWidth(-1)
        self.setTextWidth(self.document().idealWidth())
        self.droplet.setPos(self.rect().center().x(), self.rect().height())
        # Only nodes that have outgoing branches show the droplet.
        self.droplet.setVisible(bool(self.branches))

    def set_rect(self, rect):
        """Fix the node's box to *rect*; None restores the text-sized box."""
        self.prepareGeometryChange()
        rect = QRectF() if rect is None else rect
        self._rect = rect
        self.update_contents()
        self.update()

    def shape(self):
        path = QPainterPath()
        path.addRect(self.boundingRect())
        return path

    def _valid_rect(self):
        """Return the explicitly set rect if it is a valid QRectF, else None.

        Uses getattr because geometry may presumably be queried from the
        base-class constructor, before ``_rect`` is assigned.
        """
        rect = getattr(self, "_rect", None)
        if rect is not None and rect.isValid():
            return rect
        return None

    def rect(self):
        """Return the node's box in item coordinates.

        Either the explicit rect from :meth:`set_rect` or, if none is valid,
        the text document's rect united with the stored (invalid) rect.
        """
        custom = self._valid_rect()
        if custom is not None:
            return custom
        doc_rect = QRectF(QPointF(0, 0), self.document().size())
        extra = getattr(self, "_rect", QRectF(0, 0, 1, 1))
        if extra is None:
            # No explicit rect has been set yet (the constructor stores None).
            extra = QRectF()
        return doc_rect | extra

    def boundingRect(self):
        custom = self._valid_rect()
        return custom if custom is not None else super().boundingRect()

    @property
    def branches(self):
        """Nodes at the far end of this node's outgoing edges (node1 is self)."""
        return [edge.node2 for edge in self.graph_edges() if edge.node1 is self]

    def paint(self, painter, option, widget=0):
        # Draw the rounded background box, then the text clipped to it.
        painter.save()
        painter.setBrush(self.backgroundBrush)
        painter.setPen(QPen(Qt.gray))
        rect = self.rect()
        painter.drawRoundedRect(rect, 4, 4)
        painter.restore()
        painter.setClipRect(rect)
        return QGraphicsTextItem.paint(self, painter, option, widget)
166
167
class GraphicsNode(TextTreeNode):
    """Text tree node that can be collapsed and keeps its edges up to date."""

    def graph_traverse_bf(self):
        """Yield nodes reachable from this node, breadth first.

        The walk starts from the nodes produced by ``iter(self)`` (``node2``
        of every stored edge, which may include this node itself) and only
        descends into nodes whose ``isOpen`` is true. Each node is yielded
        at most once.
        """
        visited = set()
        # deque gives O(1) pops from the front (list.pop(0) is O(n)).
        queue = deque(self)
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            yield node
            visited.add(node)
            if node.isOpen:
                queue.extend(node)

    def set_open(self, do_open):
        """Expand (*do_open* true) or collapse this node's subtree.

        Stores ``isOpen`` and sets the visibility of every node reached by
        :meth:`graph_traverse_bf`, except this node itself.
        """
        self.isOpen = do_open
        for node in self.graph_traverse_bf():
            if node is not self:
                node.setVisible(do_open)

    def itemChange(self, change, value):
        # Keep edges in sync whenever this node moves or is shown/hidden.
        if change in (QGraphicsItem.ItemPositionHasChanged,
                      QGraphicsItem.ItemVisibleHasChanged):
            self.update_edge()
        return super().itemChange(change, value)

    # noinspection PyCallByClass,PyTypeChecker
    def update_edge(self):
        """Re-route outgoing edges and mirror visibility onto incoming ones."""
        for edge in self.graph_edges():
            if edge.node1 is self:
                # Deferred to the event loop; GraphicsEdge.update_ends
                # tolerates the edge having been deleted meanwhile.
                QTimer.singleShot(0, edge.update_ends)
            elif edge.node2 is self:
                edge.setVisible(self.isVisible())

    def edge_in_point(self, edge):
        """Top-centre of this node, in *edge*'s coordinates."""
        return edge.mapFromItem(
            self, QPointF(self.rect().center().x(), self.rect().y()))

    def edge_out_point(self, edge):
        """Centre of this node's droplet, in *edge*'s coordinates."""
        return edge.mapFromItem(self.droplet, self.droplet.rect().center())

    def paint(self, painter, option, widget=0):
        if self.isSelected():
            # Toggle the selected flag off so the base class does not draw
            # its own selection outline; a highlight is painted instead.
            option.state ^= QStyle.State_Selected
            rect = self.rect()
            painter.save()
            painter.setBrush(QBrush(QColor(125, 162, 206, 192)))
            painter.drawRoundedRect(rect.adjusted(-4, -4, 4, 4), 10, 10)
            painter.restore()
        super().paint(painter, option, widget)

    def boundingRect(self):
        # Leave room for the selection highlight drawn outside rect().
        return super().boundingRect().adjusted(-5, -5, 5, 5)
220
221
class GraphicsEdge(QGraphicsLineItem, GraphEdge):
    """Line item connecting two graphics nodes, stacked behind them."""

    def __init__(self, *args, **kwargs):
        # Positional arguments configure the Qt line item; keyword arguments
        # (node1, node2, atype) go to GraphEdge, which registers the edge.
        QGraphicsLineItem.__init__(self, *args)
        GraphEdge.__init__(self, **kwargs)
        # Negative z puts the edge below items at the default z of 0.
        self.setZValue(-30)

    def update_ends(self):
        """Re-route the line from node1's out point to node2's in point."""
        try:
            self.prepareGeometryChange()
            start = self.node1.edge_out_point(self)
            end = self.node2.edge_in_point(self)
            self.setLine(QLineF(start, end))
        except RuntimeError:
            # Scheduled through QTimer.singleShot, so Qt may already have
            # deleted the underlying object by the time this runs.
            pass
236
237
class TreeGraphicsView(QGraphicsView):
    """Antialiased view of the tree scene that reports its own resizes.

    ``resized`` is emitted with the widget's new size after every resize.
    """
    resized = pyqtSignal(QSize, name="resized")

    def __init__(self, scene, *args):
        super().__init__(scene, *args)
        self.viewport().setMouseTracking(True)
        self.setFocusPolicy(Qt.WheelFocus)
        for hint in (QPainter.Antialiasing,
                     QPainter.TextAntialiasing,
                     QPainter.HighQualityAntialiasing):
            self.setRenderHint(hint)

    def resizeEvent(self, event):
        super().resizeEvent(event)
        new_size = self.size()
        self.resized.emit(new_size)
252
253
class TreeGraphicsScene(QGraphicsScene):
    """Scene that lays out a tree of GraphNode items top-down."""
    _HSPACING = 10
    _VSPACING = 10

    def __init__(self, master, *args):
        super().__init__(*args)
        self.master = master
        self.nodeList = []
        self.edgeList = []
        # Layout cursor: next free x for a leaf, and the deepest y used.
        self.gx = self.gy = 10

    def fix_pos(self, node=None, x=10, y=10):
        """Lay out the tree rooted at *node*, starting at (*x*, *y*).

        If *node* is not given, the first node without a tree parent is
        used; if there is no such node nothing is done. A zero/None
        coordinate makes both fall back to the spacing constants.
        """
        if not x or not y:
            x, y = self._HSPACING, self._VSPACING
        # Start the layout cursor at the (possibly defaulted) origin so the
        # leaf row and the root's y agree.
        self.gx, self.gy = x, y
        if not node:
            node = next((n for n in self.nodes() if not n.parent), None)
            if node is None:
                return
        self._fix_pos(node, x, y)
        self.setSceneRect(
            QRectF(0, 0, self.gx, self.gy).adjusted(-10, -10, 100, 100))
        self.update()

    def _fix_pos(self, node, x, y):
        """Fix the position of the tree stemming from the given node.

        Closed nodes and leaves are placed left to right at ``self.gx``; an
        open node with branches is centred above its first and last branch.
        Returns the updated ``(x, y)``.
        """
        def brect(item):
            """Get the bounding box of the item and all its child items."""
            return item.boundingRect() | item.childrenBoundingRect()

        branches = node.branches
        if branches and node.isOpen:
            for child in branches:
                x, _ = self._fix_pos(
                    child, x, y + self._VSPACING + brect(node).height())
            x = (branches[0].pos().x() + branches[-1].pos().x()) / 2
            node.setPos(x, y)
            for edge in node.graph_edges():
                edge.update_ends()
        else:
            node.setPos(self.gx, y)
            step = self._HSPACING + brect(node).width()
            self.gx += step
            x += step
            self.gy = max(y, self.gy)
        return x, y

    def mouseMoveEvent(self, event):
        return QGraphicsScene.mouseMoveEvent(self, event)

    def mousePressEvent(self, event):
        return QGraphicsScene.mousePressEvent(self, event)

    def edges(self):
        """All GraphEdge items currently in the scene."""
        return [item for item in self.items() if isinstance(item, GraphEdge)]

    def nodes(self):
        """All GraphNode items currently in the scene."""
        return [item for item in self.items() if isinstance(item, GraphNode)]
309
310
class TreeNavigator(QGraphicsView):
    """Overview of a master view's scene with a viewport indicator.

    The whole scene is fitted into this view, the master view's visible
    area is drawn as a translucent polygon, and pressing or dragging with the
    left button re-centres the master view on that point. *master_view* is
    expected to expose a ``resized`` signal (as TreeGraphicsView does).
    """

    def __init__(self, master_view, *_):
        super().__init__()
        self.master_view = master_view
        # Show the master's own scene so both views display the same items.
        self.setScene(self.master_view.scene())
        self.scene().sceneRectChanged.connect(self.updateSceneRect)
        self.master_view.resized.connect(self.update_view)
        self.setRenderHint(QPainter.Antialiasing)

    def _follow_mouse(self, event):
        """Centre the master view under *event* if the left button is down."""
        if event.buttons() & Qt.LeftButton:
            self.master_view.centerOn(self.mapToScene(event.pos()))
            self.update_view()

    def mousePressEvent(self, event):
        self._follow_mouse(event)
        return super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        self._follow_mouse(event)
        return super().mouseMoveEvent(event)

    def resizeEvent(self, event):
        QGraphicsView.resizeEvent(self, event)
        self.update_view()

    # noinspection PyPep8Naming
    def resizeView(self):
        self.update_view()

    def updateSceneRect(self, rect):
        super().updateSceneRect(rect)
        self.update_view()

    def update_view(self, *_):
        """Fit the entire scene into this view."""
        if self.scene():
            self.fitInView(self.scene().sceneRect())

    def paintEvent(self, event):
        super().paintEvent(event)
        painter = QPainter(self.viewport())
        try:
            painter.setBrush(QColor(100, 100, 100, 100))
            painter.setRenderHints(self.renderHints())
            painter.drawPolygon(self.viewPolygon())
        finally:
            # Release the viewport explicitly, even if drawing fails.
            painter.end()

    # noinspection PyPep8Naming
    def viewPolygon(self):
        """Master view's visible area, mapped into this view's coordinates."""
        return self.mapFromScene(
            self.master_view.mapToScene(self.master_view.viewport().rect()))
359
360
class OWTreeViewer2D(OWWidget, openclass=True):
    """Base widget that shows a tree in a zoomable graphics scene.

    Provides the display controls (zoom, node width, depth, edge width) and
    the scene/view plumbing. Building the tree (``root_node``, ``model``) is
    presumably left to subclasses -- TODO(review) confirm.
    """
    zoom = Setting(5)
    line_width_method = Setting(2)
    max_tree_depth = Setting(0)
    max_node_width = Setting(150)

    _VSPACING = 5
    _HSPACING = 5
    _TOOLTIPS_ENABLED = True
    _DEF_NODE_WIDTH = 24
    _DEF_NODE_HEIGHT = 20

    graph_name = "scene"

    def __init__(self):
        super().__init__()
        self.selected_node = None
        self.root_node = None
        self.model = None

        box = gui.vBox(
            self.controlArea, 'Tree',
            sizePolicy=QSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed))
        self.infolabel = gui.widgetLabel(box, 'No tree.')

        layout = QFormLayout()
        layout.setFieldGrowthPolicy(layout.ExpandingFieldsGrow)
        box = self.display_box = gui.widgetBox(self.controlArea, "Display",
                                               orientation=layout)
        layout.addRow(
            "Zoom: ",
            gui.hSlider(box, self, 'zoom',
                        minValue=1, maxValue=10, step=1, ticks=False,
                        callback=self.toggle_zoom_slider,
                        createLabel=False, addToLayout=False))
        layout.addRow(
            "Width: ",
            gui.hSlider(box, self, 'max_node_width',
                        minValue=50, maxValue=200, step=1, ticks=False,
                        callback=self.toggle_node_size,
                        createLabel=False, addToLayout=False))
        policy = QSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
        layout.addRow(
            "Depth: ",
            gui.comboBox(box, self, 'max_tree_depth',
                         items=["Unlimited"] + [
                             "{} levels".format(x) for x in range(2, 10)],
                         addToLayout=False, sendSelectedValue=False,
                         callback=self.toggle_tree_depth, sizePolicy=policy))
        layout.addRow(
            "Edge width: ",
            gui.comboBox(box, self, 'line_width_method',
                         items=['Fixed', 'Relative to root',
                                'Relative to parent'],
                         addToLayout=False,
                         callback=self.toggle_line_width, sizePolicy=policy))
        gui.rubber(self.controlArea)

        self.scene = TreeGraphicsScene(self)
        self.scene_view = TreeGraphicsView(self.scene)
        self.scene_view.setViewportUpdateMode(QGraphicsView.FullViewportUpdate)
        self.mainArea.layout().addWidget(self.scene_view)
        self.toggle_zoom_slider()
        self.scene.selectionChanged.connect(self.update_selection)

    def send_report(self):
        """Render the scene into a 600 px wide SVG for the report."""
        from AnyQt.QtSvg import QSvgGenerator

        if not self.model:
            return
        self.reportSection("Tree")
        _, filefn = self.getUniqueImageName(ext=".svg")
        ssize = self.scene.sceneRect().size()
        w, h = ssize.width(), ssize.height()
        if w <= 0:
            # An empty scene cannot be scaled to a fixed width.
            return
        svg = QSvgGenerator()
        svg.setFileName(filefn)
        fact = 600 / w
        # QSize takes integers; newer PyQt/Python versions reject floats.
        svg.setSize(QSize(600, round(h * fact)))
        painter = QPainter()
        painter.begin(svg)
        try:
            self.scene.render(painter)
        finally:
            painter.end()

    def toggle_zoom_slider(self):
        """Apply the zoom setting as a uniform scale on the view."""
        # Quadratic mapping of the slider value (1..10) to a scale factor.
        k = 0.0028 * (self.zoom ** 2) + 0.2583 * self.zoom + 1.1389
        self.scene_view.setTransform(QTransform().scale(k / 2, k / 2))
        self.scene.update()

    def toggle_tree_depth(self):
        """Re-apply the depth limit and re-layout the tree."""
        self.walkupdate(self.root_node)
        self.scene.fix_pos(self.root_node, 10, 10)
        self.scene.update()

    def toggle_line_width(self):
        """Set edge pen widths according to ``line_width_method``.

        0: fixed width; 1: proportional to the child's share of the root's
        samples; 2: proportional to its share of the parent's samples.
        """
        if self.root_node is None:
            return

        tree_adapter = self.root_node.tree_adapter
        root_instances = tree_adapter.num_samples(self.root_node.node_inst)
        width = 3
        OFFSET = 0.20

        def relative_width(num, total):
            # A node without samples would divide by zero; use the minimum.
            if not total:
                return OFFSET
            return 8 * num / total + OFFSET

        for edge in self.scene.edges():
            num_inst = tree_adapter.num_samples(edge.node2.node_inst)
            if self.line_width_method == 1:
                width = relative_width(num_inst, root_instances)
            elif self.line_width_method == 2:
                width = relative_width(
                    num_inst, tree_adapter.num_samples(edge.node1.node_inst))
            edge.setPen(QPen(Qt.gray, width, Qt.SolidLine, Qt.RoundCap))
        self.scene.update()

    def toggle_node_size(self):
        self.set_node_info()
        self.scene.update()
        self.scene_view.repaint()

    def toggle_navigator(self):
        # NOTE(review): ``nav_widget`` is not created in this class;
        # presumably a subclass provides it.
        self.nav_widget.setHidden(not self.nav_widget.isHidden())

    def activate_loaded_settings(self):
        """Re-apply all display settings to the current tree, if any."""
        if not self.model:
            return
        self.rescale_tree()
        self.scene.fix_pos(self.root_node, 10, 10)
        self.scene.update()
        self.toggle_tree_depth()
        self.toggle_line_width()

    def clear_scene(self):
        self.scene.clear()
        self.scene.setSceneRect(QRectF())

    def setup_scene(self):
        """Lay out the tree, apply settings and centre on the root."""
        if self.root_node is not None:
            self.scene.fix_pos(self.root_node, self._HSPACING, self._VSPACING)
            self.activate_loaded_settings()
            self.scene_view.centerOn(self.root_node.x(), self.root_node.y())
            self.update_node_tooltips()
        self.scene.update()

    def walkupdate(self, node, level=0):
        """Open nodes above the depth limit and close nodes at the limit.

        ``max_tree_depth`` is a combo index: 0 means unlimited, and index
        ``i`` shows ``i + 1`` levels.
        """
        if not node:
            return
        if self.max_tree_depth and self.max_tree_depth < level + 1:
            node.set_open(False)
            return
        else:
            node.set_open(True)
        for n in node.branches:
            self.walkupdate(n, level + 1)

    def update_node_tooltips(self):
        for node in self.scene.nodes():
            node.setToolTip(self.node_tooltip(node) if self._TOOLTIPS_ENABLED
                            else "")

    def node_tooltip(self, tree):
        """Placeholder tooltip text; presumably overridden by subclasses."""
        return "tree node"

    def rescale_tree(self):
        """Resize nodes in ``scene.nodeList`` to the default size; re-layout."""
        node_height = self._DEF_NODE_HEIGHT
        node_width = self._DEF_NODE_WIDTH
        for r in self.scene.nodeList:
            current = r.rect()
            # set_rect takes a single QRectF, not four coordinates.
            r.set_rect(
                QRectF(current.x(), current.y(), node_width, node_height))
        self.scene.fix_pos()

    def update_selection(self):
        """Remember the first selected scene item (or None)."""
        self.selected_node = (self.scene.selectedItems() + [None])[0]
        # self.centerNodeButton.setDisabled(not self.selected_node)
        # self.send("Data", self.selectedNode.tree.examples if self.selectedNode
        # else None)
531