# pylint: disable=missing-docstring,protected-access

from unittest.mock import Mock

from AnyQt.QtCore import Qt, QItemSelection, QItemSelectionModel

from Orange.classification.random_forest import RandomForestLearner
from Orange.data import Table
from Orange.regression.random_forest import RandomForestRegressionLearner
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import simulate
from Orange.widgets.visualize.owpythagoreanforest import OWPythagoreanForest
from Orange.widgets.visualize.pythagorastreeviewer import PythagorasTreeViewer


class TestOWPythagoreanForest(WidgetTest):
    """Tests for the Pythagorean Forest visualization widget."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Set up for widget tests: small 3-tree forests keep drawing cheap.
        # The training data is attached as `instances` on each model.
        titanic_data = Table('titanic')[::50]
        cls.titanic = RandomForestLearner(n_estimators=3)(titanic_data)
        cls.titanic.instances = titanic_data

        housing_data = Table('housing')[:10]
        cls.housing = RandomForestRegressionLearner(
            n_estimators=3)(housing_data)
        cls.housing.instances = housing_data

    def setUp(self):
        self.widget = self.create_widget(OWPythagoreanForest)  # type: OWPythagoreanForest

    def test_migrate_version_1_settings(self):
        """Zoom values from the version-1 range are migrated into the v2 range."""
        # Version 1 zoom ranged from 20-150, version 2 zoom ranges from 100-400.
        # The settings handler reads the stored settings version from the
        # `__version__` key, so mark the settings explicitly as version 1.
        # Check both the minimum and the maximum of the old range.
        for old_zoom in (20, 150):
            with self.subTest(zoom=old_zoom):
                widget = self.create_widget(
                    OWPythagoreanForest,
                    stored_settings={'zoom': old_zoom, '__version__': 1},
                )  # type: OWPythagoreanForest
                self.assertGreaterEqual(widget.zoom, 100)
                self.assertLessEqual(widget.zoom, 400)

    def get_tree_widgets(self):
        """Return the PythagorasTreeViewer found in each scene of the forest model."""
        model = self.widget.forest_model
        trees = []
        for idx in range(len(model)):
            scene = model.data(model.index(idx), Qt.DisplayRole)
            # Single-element unpacking: each scene must hold exactly one tree
            # viewer, otherwise this raises and the test fails loudly.
            tree, = [item for item in scene.items()
                     if isinstance(item, PythagorasTreeViewer)]
            trees.append(tree)
        return trees

    def test_sending_rf_draws_trees(self):
        w = self.widget
        # No trees by default
        self.assertEqual(len(self.get_tree_widgets()), 0,
                         'No trees should be drawn when no forest on input')

        # Draw trees for classification rf
        self.send_signal(w.Inputs.random_forest, self.titanic)
        self.assertEqual(len(self.get_tree_widgets()), 3,
                         'Incorrect number of trees when forest on input')

        # Clear trees when None
        self.send_signal(w.Inputs.random_forest, None)
        self.assertEqual(len(self.get_tree_widgets()), 0,
                         'Trees are cleared when forest is disconnected')

        # Draw trees for regression rf
        self.send_signal(w.Inputs.random_forest, self.housing)
        self.assertEqual(len(self.get_tree_widgets()), 3,
                         'Incorrect number of trees when forest on input')

    def test_info_label(self):
        w = self.widget
        regex = r'Trees:(.+)'
        # If no forest on input, the info box must not report any trees
        self.assertNotRegex(w.ui_info.text(), regex,
                            'Initial info should not contain info on trees')

        self.send_signal(w.Inputs.random_forest, self.titanic)
        self.assertRegex(w.ui_info.text(), regex, 'Valid RF does not update info')

        self.send_signal(w.Inputs.random_forest, None)
        self.assertNotRegex(w.ui_info.text(), regex, 'Removing RF does not clear info box')

    def test_depth_slider(self):
        w = self.widget
        self.send_signal(w.Inputs.random_forest, self.titanic)

        trees = self.get_tree_widgets()
        for tree in trees:
            tree.set_depth_limit = Mock()

        w.ui_depth_slider.setValue(0)
        for tree in trees:
            tree.set_depth_limit.assert_called_once_with(0)

    def _get_first_tree(self):
        """Return the first tree on the grid.

        Fails the test if no trees are currently drawn.

        Returns
        -------
        PythagorasTreeViewer

        """
        widgets = self.get_tree_widgets()
        self.assertTrue(widgets, 'Empty list of tree widgets')
        return widgets[0]

    def _get_visible_squares(self, tree):
        """Return the square items of `tree` that are currently visible."""
        return [square for square in tree._square_objects.values()
                if square.isVisible()]

    def _check_all_same(self, items):
        """Return True if all elements of `items` compare equal (or it is empty)."""
        iter_items = iter(items)
        try:
            first = next(iter_items)
        except StopIteration:
            return True
        return all(first == curr for curr in iter_items)

    def test_changing_target_class_changes_coloring(self):
        """Changing the `Target class` combo box should update colors."""
        w = self.widget

        def _test(data_type):
            colors, tree = [], self._get_first_tree()

            def _callback():
                colors.append([sq.brush().color() for sq in self._get_visible_squares(tree)])

            simulate.combobox_run_through_all(w.ui_target_class_combo, callback=_callback)

            # For each square, check whether its color stayed the same across
            # all target classes
            squares_same = [self._check_all_same(x) for x in zip(*colors)]
            # At least some of the squares must have changed color
            self.assertFalse(all(squares_same),
                             'Colors did not change for %s data' % data_type)

        self.send_signal(w.Inputs.random_forest, self.titanic)
        _test('classification')
        self.send_signal(w.Inputs.random_forest, self.housing)
        _test('regression')

    def test_changing_size_adjustment_changes_sizes(self):
        w = self.widget
        self.send_signal(w.Inputs.random_forest, self.titanic)
        squares = []
        tree = self._get_first_tree()

        def _callback():
            squares.append([sq.rect() for sq in self._get_visible_squares(tree)])

        simulate.combobox_run_through_all(w.ui_size_calc_combo, callback=_callback)

        # For each square, check whether its geometry stayed the same across
        # all size-calculation options
        squares_same = [self._check_all_same(x) for x in zip(*squares)]
        # At least some of the squares must have changed geometry
        self.assertFalse(all(squares_same))

    def test_zoom(self):
        w = self.widget
        self.send_signal(w.Inputs.random_forest, self.titanic)

        min_zoom = w.ui_zoom_slider.minimum()
        max_zoom = w.ui_zoom_slider.maximum()

        # Increase the size of grid item
        w.ui_zoom_slider.setValue(max_zoom)
        item_size = w.forest_model.data(w.forest_model.index(0), Qt.SizeHintRole)
        max_w, max_h = item_size.width(), item_size.height()

        # Decrease the size of grid item
        w.ui_zoom_slider.setValue(min_zoom)
        item_size = w.forest_model.data(w.forest_model.index(0), Qt.SizeHintRole)
        min_w, min_h = item_size.width(), item_size.height()

        # Separate assertions so a failure reports which dimension is wrong
        self.assertLess(min_w, max_w)
        self.assertLess(min_h, max_h)

    def test_keep_colors_on_sizing_change(self):
        """The color should be the same after a full recompute of the tree."""
        w = self.widget
        self.send_signal(w.Inputs.random_forest, self.titanic)
        colors = []
        tree = self._get_first_tree()

        def _callback():
            colors.append([sq.brush().color() for sq in self._get_visible_squares(tree)])

        simulate.combobox_run_through_all(w.ui_size_calc_combo, callback=_callback)

        # Every square must keep its color across all size-calculation options
        colors_same = [self._check_all_same(x) for x in zip(*colors)]
        self.assertTrue(all(colors_same))

    def select_tree(self, idx: int) -> None:
        """Select (only) the tree at row `idx` in the widget's list view."""
        list_view = self.widget.list_view
        index = list_view.model().index(idx)
        selection = QItemSelection(index, index)
        list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)

    def test_storing_selection(self):
        # Select one of the trees
        idx = 1
        self.send_signal(self.widget.Inputs.random_forest, self.titanic)
        self.select_tree(idx)
        # Clear input
        self.send_signal(self.widget.Inputs.random_forest, None)
        # Restore previous data; context settings should be restored
        self.send_signal(self.widget.Inputs.random_forest, self.titanic)

        output = self.get_output(self.widget.Outputs.tree)
        self.assertIsNotNone(output)
        self.assertIs(output.skl_model, self.titanic.trees[idx].skl_model)

    def test_context(self):
        iris = Table("iris")
        iris_forest = RandomForestLearner()(iris)
        iris_forest.instances = iris
        self.send_signal(self.widget.Inputs.random_forest, self.titanic)
        self.widget.target_class_index = 1

        self.send_signal(self.widget.Inputs.random_forest, iris_forest)
        self.assertEqual(0, self.widget.target_class_index)

        self.widget.target_class_index = 2
        self.send_signal(self.widget.Inputs.random_forest, self.titanic)
        self.assertEqual(1, self.widget.target_class_index)

        self.send_signal(self.widget.Inputs.random_forest, iris_forest)
        self.assertEqual(2, self.widget.target_class_index)