1# pylint: disable=missing-docstring,protected-access
2
3from unittest.mock import Mock
4
5from AnyQt.QtCore import Qt, QItemSelection, QItemSelectionModel
6
7from Orange.classification.random_forest import RandomForestLearner
8from Orange.data import Table
9from Orange.regression.random_forest import RandomForestRegressionLearner
10from Orange.widgets.tests.base import WidgetTest
11from Orange.widgets.tests.utils import simulate
12from Orange.widgets.visualize.owpythagoreanforest import OWPythagoreanForest
13from Orange.widgets.visualize.pythagorastreeviewer import PythagorasTreeViewer
14
15
class TestOWPythagoreanForest(WidgetTest):
    """Tests for the Pythagorean Forest visualization widget."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Small forests (3 trees each) on small data subsets keep tests fast.
        # The training data is attached to each model as `instances`.
        titanic_data = Table('titanic')[::50]
        cls.titanic = RandomForestLearner(n_estimators=3)(titanic_data)
        cls.titanic.instances = titanic_data

        housing_data = Table('housing')[:10]
        cls.housing = RandomForestRegressionLearner(
            n_estimators=3)(housing_data)
        cls.housing.instances = housing_data

    def setUp(self):
        self.widget = self.create_widget(OWPythagoreanForest)  # type: OWPythagoreanForest

    def test_migrate_version_1_settings(self):
        """Version-1 zoom values end up inside the version-2 zoom range."""
        # Version 1 zoom ranged from 20-150, version 2 zoom ranges from 100-400.
        # Both extremes of the old range must fall into the new range. The
        # stored settings are marked as version 1 (not 2) to match the
        # scenario under test.
        # NOTE(review): confirm that 'version' is the key the settings handler
        # reads when deciding whether to call `migrate_settings`.
        for v1_zoom in (20, 150):
            with self.subTest(zoom=v1_zoom):
                widget = self.create_widget(
                    OWPythagoreanForest,
                    stored_settings={'zoom': v1_zoom, 'version': 1},
                )  # type: OWPythagoreanForest
                self.assertGreaterEqual(widget.zoom, 100)
                self.assertLessEqual(widget.zoom, 400)

    def get_tree_widgets(self):
        """Return the `PythagorasTreeViewer` drawn in each grid cell.

        Each scene in the forest model is expected to hold exactly one tree
        viewer; the single-element unpacking raises otherwise.
        """
        model = self.widget.forest_model
        trees = []
        for idx in range(len(model)):
            scene = model.data(model.index(idx), Qt.DisplayRole)
            tree, = [item for item in scene.items()
                     if isinstance(item, PythagorasTreeViewer)]
            trees.append(tree)
        return trees

    def test_sending_rf_draws_trees(self):
        """One tree is drawn per estimator; clearing the input removes them."""
        w = self.widget
        # No trees by default
        self.assertEqual(len(self.get_tree_widgets()), 0,
                         'No trees should be drawn when no forest on input')

        # Draw trees for classification rf
        self.send_signal(w.Inputs.random_forest, self.titanic)
        self.assertEqual(len(self.get_tree_widgets()), 3,
                         'Incorrect number of trees when forest on input')

        # Clear trees when None
        self.send_signal(w.Inputs.random_forest, None)
        self.assertEqual(len(self.get_tree_widgets()), 0,
                         'Trees are cleared when forest is disconnected')

        # Draw trees for regression rf
        self.send_signal(w.Inputs.random_forest, self.housing)
        self.assertEqual(len(self.get_tree_widgets()), 3,
                         'Incorrect number of trees when forest on input')

    def test_info_label(self):
        """The info label shows tree info only while a forest is on input."""
        w = self.widget
        regex = r'Trees:(.+)'
        # If no forest on input, display a message saying that
        self.assertNotRegex(w.ui_info.text(), regex,
                            'Initial info should not contain info on trees')

        self.send_signal(w.Inputs.random_forest, self.titanic)
        self.assertRegex(w.ui_info.text(), regex, 'Valid RF does not update info')

        self.send_signal(w.Inputs.random_forest, None)
        self.assertNotRegex(w.ui_info.text(), regex, 'Removing RF does not clear info box')

    def test_depth_slider(self):
        """Moving the depth slider passes the new limit to every tree."""
        w = self.widget
        self.send_signal(w.Inputs.random_forest, self.titanic)

        trees = self.get_tree_widgets()
        # Guard against a vacuous pass when no trees were drawn
        self.assertTrue(trees, 'No trees drawn for the forest on input')
        for tree in trees:
            tree.set_depth_limit = Mock()

        w.ui_depth_slider.setValue(0)
        for tree in trees:
            tree.set_depth_limit.assert_called_once_with(0)

    def _get_first_tree(self):
        """Return the first tree on the grid.

        Returns
        -------
        PythagorasTreeViewer

        """
        widgets = self.get_tree_widgets()
        self.assertTrue(widgets, 'Empty list of tree widgets')
        return widgets[0]

    def _get_visible_squares(self, tree):
        """Return the square graphics items of `tree` that are visible."""
        return [sq for sq in tree._square_objects.values() if sq.isVisible()]

    def _check_all_same(self, items):
        """Return True if all `items` compare equal (True for no items)."""
        iter_items = iter(items)
        try:
            first = next(iter_items)
        except StopIteration:
            return True
        return all(first == curr for curr in iter_items)

    def test_changing_target_class_changes_coloring(self):
        """Changing the `Target class` combo box should update colors."""
        w = self.widget

        def _test(data_type):
            colors, tree = [], self._get_first_tree()

            def _callback():
                colors.append([sq.brush().color()
                               for sq in self._get_visible_squares(tree)])

            simulate.combobox_run_through_all(
                w.ui_target_class_combo, callback=_callback)

            # For each square, check whether its color stayed the same across
            # all target class options
            squares_same = [self._check_all_same(x) for x in zip(*colors)]
            # At least one square must have changed color
            self.assertFalse(all(squares_same),
                             'Colors did not change for %s data' % data_type)

        self.send_signal(w.Inputs.random_forest, self.titanic)
        _test('classification')
        self.send_signal(w.Inputs.random_forest, self.housing)
        _test('regression')

    def test_changing_size_adjustment_changes_sizes(self):
        """Changing the size calculation should move/resize some squares."""
        w = self.widget
        self.send_signal(w.Inputs.random_forest, self.titanic)
        squares = []
        tree = self._get_first_tree()

        def _callback():
            squares.append([sq.rect() for sq in self._get_visible_squares(tree)])

        simulate.combobox_run_through_all(w.ui_size_calc_combo, callback=_callback)

        # For each square, check whether its rect stayed the same across all
        # size calculation options
        squares_same = [self._check_all_same(x) for x in zip(*squares)]
        # At least one square must have a different rect
        self.assertFalse(all(squares_same))

    def test_zoom(self):
        """Grid items are larger at maximum zoom than at minimum zoom."""
        w = self.widget
        self.send_signal(w.Inputs.random_forest, self.titanic)

        min_zoom = w.ui_zoom_slider.minimum()
        max_zoom = w.ui_zoom_slider.maximum()

        # Increase the size of grid item
        w.ui_zoom_slider.setValue(max_zoom)
        item_size = w.forest_model.data(w.forest_model.index(0), Qt.SizeHintRole)
        max_w, max_h = item_size.width(), item_size.height()

        # Decrease the size of grid item
        w.ui_zoom_slider.setValue(min_zoom)
        item_size = w.forest_model.data(w.forest_model.index(0), Qt.SizeHintRole)
        min_w, min_h = item_size.width(), item_size.height()

        self.assertLess(min_w, max_w)
        self.assertLess(min_h, max_h)

    def test_keep_colors_on_sizing_change(self):
        """The color should be the same after a full recompute of the tree."""
        w = self.widget
        self.send_signal(w.Inputs.random_forest, self.titanic)
        colors = []
        tree = self._get_first_tree()

        def _callback():
            colors.append([sq.brush().color() for sq in self._get_visible_squares(tree)])

        simulate.combobox_run_through_all(w.ui_size_calc_combo, callback=_callback)

        # Each square must keep the same color for every size option
        colors_same = [self._check_all_same(x) for x in zip(*colors)]
        self.assertTrue(all(colors_same))

    def select_tree(self, idx: int) -> None:
        """Select only the tree at row `idx` in the widget's list view."""
        list_view = self.widget.list_view
        index = list_view.model().index(idx)
        selection = QItemSelection(index, index)
        list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)

    def test_storing_selection(self):
        """A selected tree is restored when the same forest is re-sent."""
        # Select one of the trees
        idx = 1
        self.send_signal(self.widget.Inputs.random_forest, self.titanic)
        self.select_tree(idx)
        # Clear input
        self.send_signal(self.widget.Inputs.random_forest, None)
        # Restore previous data; context settings should be restored
        self.send_signal(self.widget.Inputs.random_forest, self.titanic)

        output = self.get_output(self.widget.Outputs.tree)
        self.assertIsNotNone(output)
        self.assertIs(output.skl_model, self.titanic.trees[idx].skl_model)

    def test_context(self):
        """Target class index is stored and restored per forest context."""
        iris = Table("iris")
        iris_tree = RandomForestLearner()(iris)
        iris_tree.instances = iris
        self.send_signal(self.widget.Inputs.random_forest, self.titanic)
        self.widget.target_class_index = 1

        self.send_signal(self.widget.Inputs.random_forest, iris_tree)
        self.assertEqual(0, self.widget.target_class_index)

        self.widget.target_class_index = 2
        self.send_signal(self.widget.Inputs.random_forest, self.titanic)
        self.assertEqual(1, self.widget.target_class_index)

        self.send_signal(self.widget.Inputs.random_forest, iris_tree)
        self.assertEqual(2, self.widget.target_class_index)
241