# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring, protected-access
import unittest
from unittest.mock import Mock

import numpy as np

from Orange import distance
from Orange.data import Table, Domain, ContinuousVariable
from Orange.misc import DistMatrix
from Orange.widgets.unsupervised.owdistances import OWDistances, METRICS, \
    DistanceRunner
from Orange.widgets.tests.base import WidgetTest


class TestDistanceRunner(unittest.TestCase):
    """Check that DistanceRunner.run matches calling each metric directly."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Subsampled tables with a few NaNs injected, so that every metric
        # is also exercised on data with missing values.
        cls.iris = Table("iris")[::5]
        cls.iris.X[0, 2] = np.nan
        cls.iris.X[1, 3] = np.nan
        cls.iris.X[2, 1] = np.nan
        cls.zoo = Table("zoo")[::5]
        cls.zoo.X[0, 2] = np.nan
        cls.zoo.X[1, 3] = np.nan
        cls.zoo.X[2, 1] = np.nan

    def test_run(self):
        state = Mock()
        state.is_interruption_requested = Mock(return_value=False)
        for name, metric in METRICS:
            # subTest identifies the failing metric and keeps checking the
            # remaining ones instead of aborting at the first mismatch.
            with self.subTest(metric=name):
                data = self.iris
                if not metric.supports_missing or name == "Bhattacharyya":
                    data = distance.impute(data)
                elif name == "Jaccard":
                    data = self.zoo

                # The runner's axis argument 0 ("between rows") corresponds
                # to the metric's axis=1, and 1 ("between columns") to axis=0.

                # between rows, normalized
                dist1 = DistanceRunner.run(data, metric, True, 0, state)
                dist2 = metric(data, axis=1, impute=True, normalize=True)
                self.assertDistMatrixEqual(dist1, dist2)

                # between rows, not normalized
                dist1 = DistanceRunner.run(data, metric, False, 0, state)
                dist2 = metric(data, axis=1, impute=True, normalize=False)
                self.assertDistMatrixEqual(dist1, dist2)

                # between columns, normalized
                dist1 = DistanceRunner.run(data, metric, True, 1, state)
                dist2 = metric(data, axis=0, impute=True, normalize=True)
                self.assertDistMatrixEqual(dist1, dist2)

                # between columns, not normalized
                dist1 = DistanceRunner.run(data, metric, False, 1, state)
                dist2 = metric(data, axis=0, impute=True, normalize=False)
                self.assertDistMatrixEqual(dist1, dist2)

    def assertDistMatrixEqual(self, dist1, dist2):
        """Assert both are DistMatrix objects with equal axis, row/column
        items and (almost) equal values."""
        self.assertIsInstance(dist1, DistMatrix)
        self.assertIsInstance(dist2, DistMatrix)
        self.assertEqual(dist1.axis, dist2.axis)
        self.assertEqual(dist1.row_items, dist2.row_items)
        self.assertEqual(dist1.col_items, dist2.col_items)
        np.testing.assert_array_almost_equal(dist1, dist2)


class TestOWDistances(WidgetTest):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Shared, class-level fixtures: tests must not leave them modified.
        cls.iris = Table("iris")[::5]
        cls.titanic = Table("titanic")[::10]

    def setUp(self):
        self.widget = self.create_widget(OWDistances)

    def test_distance_combo(self):
        """Check distances when the metric changes"""
        self.assertEqual(self.widget.metrics_combo.count(), len(METRICS))
        self.send_signal(self.widget.Inputs.data, self.iris)
        for i, (name, metric) in enumerate(METRICS):
            with self.subTest(metric=name):
                # Same order as a real user interaction (and as in
                # test_limit_mahalanobis): set the index, then emit activated.
                self.widget.metrics_combo.setCurrentIndex(i)
                self.widget.metrics_combo.activated.emit(i)
                self.wait_until_stop_blocking()
                if metric.supports_normalization:
                    expected = metric(
                        self.iris, normalize=self.widget.normalized_dist)
                else:
                    expected = metric(self.iris)

                # Jaccard output is not compared on iris.
                if metric is not distance.Jaccard:
                    np.testing.assert_array_almost_equal(
                        expected,
                        self.get_output(self.widget.Outputs.distances))

    def test_error_message(self):
        """Check if error message appears and then disappears when
        data is removed from input"""
        # NOTE(review): relies on METRICS[2] being a metric that needs
        # numeric features -- confirm if METRICS is reordered.
        self.widget.metric_idx = 2
        self.send_signal(self.widget.Inputs.data, self.iris)
        self.wait_until_stop_blocking()
        self.assertFalse(self.widget.Error.no_continuous_features.is_shown())
        self.send_signal(self.widget.Inputs.data, self.titanic)
        self.wait_until_stop_blocking()
        self.assertTrue(self.widget.Error.no_continuous_features.is_shown())
        self.send_signal(self.widget.Inputs.data, None)
        self.assertFalse(self.widget.Error.no_continuous_features.is_shown())

    def test_jaccard_messages(self):
        # Leaves widget.metric_idx at the position of "Jaccard" in METRICS.
        for self.widget.metric_idx, (name, _) in enumerate(METRICS):
            if name == "Jaccard":
                break

        self.send_signal(self.widget.Inputs.data, self.iris)
        self.wait_until_stop_blocking()
        self.assertTrue(self.widget.Error.no_binary_features.is_shown())
        self.assertFalse(self.widget.Warning.ignoring_nonbinary.is_shown())

        self.send_signal(self.widget.Inputs.data, None)
        self.wait_until_stop_blocking()
        self.assertFalse(self.widget.Error.no_binary_features.is_shown())
        self.assertFalse(self.widget.Warning.ignoring_nonbinary.is_shown())

        self.send_signal(self.widget.Inputs.data, self.titanic)
        self.wait_until_stop_blocking()
        self.assertFalse(self.widget.Error.no_binary_features.is_shown())
        self.assertTrue(self.widget.Warning.ignoring_nonbinary.is_shown())

        self.send_signal(self.widget.Inputs.data, None)
        self.wait_until_stop_blocking()
        self.assertFalse(self.widget.Error.no_binary_features.is_shown())
        self.assertFalse(self.widget.Warning.ignoring_nonbinary.is_shown())

        self.send_signal(self.widget.Inputs.data, self.titanic)
        self.wait_until_stop_blocking()
        self.assertFalse(self.widget.Error.no_binary_features.is_shown())
        self.assertTrue(self.widget.Warning.ignoring_nonbinary.is_shown())

        # Dropping titanic's first attribute must clear the non-binary
        # warning.
        dom = self.titanic.domain
        dom = Domain(dom.attributes[1:], dom.class_var)
        self.send_signal(self.widget.Inputs.data, self.titanic.transform(dom))
        self.wait_until_stop_blocking()
        self.assertFalse(self.widget.Error.no_binary_features.is_shown())
        self.assertFalse(self.widget.Warning.ignoring_nonbinary.is_shown())

        self.send_signal(self.widget.Inputs.data, Table("heart_disease"))
        self.wait_until_stop_blocking()
        self.assertFalse(self.widget.Error.no_binary_features.is_shown())
        self.assertFalse(self.widget.Warning.ignoring_discrete.is_shown())

    def test_too_big_array(self):
        """
        Users sees an error message when calculating too large arrays and Orange
        does not crash.
        GH-2315
        """
        self.assertEqual(len(self.widget.Error.active), 0)
        self.send_signal(self.widget.Inputs.data, self.iris)

        mock = Mock(side_effect=ValueError)
        self.widget.compute_distances(mock, self.iris)
        self.wait_until_finished()
        self.assertTrue(self.widget.Error.distances_value_error.is_shown())

        mock = Mock(side_effect=MemoryError)
        self.widget.compute_distances(mock, self.iris)
        self.wait_until_finished()
        # The memory error replaces (rather than adds to) the value error.
        self.assertEqual(len(self.widget.Error.active), 1)
        self.assertTrue(self.widget.Error.distances_memory_error.is_shown())

    def test_migrates_normalized_dist(self):
        w = self.create_widget(OWDistances, stored_settings={"metric_idx": 0})
        self.assertFalse(w.normalized_dist)

    def test_negative_values_bhattacharyya(self):
        # Work on a copy: mutating the shared class fixture and restoring it
        # afterwards would leave it corrupted for later tests whenever the
        # assertion fails before the restore line runs.
        data = self.iris.copy()
        data.X[0, 0] *= -1
        for self.widget.metric_idx, (_, metric) in enumerate(METRICS):
            if metric == distance.Bhattacharyya:
                break
        self.send_signal(self.widget.Inputs.data, data)
        self.wait_until_finished()
        self.assertTrue(self.widget.Error.distances_value_error.is_shown())

    def test_limit_mahalanobis(self):
        def assert_error_shown():
            self.assertTrue(
                self.widget.Error.data_too_large_for_mahalanobis.is_shown())

        def assert_no_error():
            self.assertFalse(
                self.widget.Error.data_too_large_for_mahalanobis.is_shown())

        widget = self.widget
        axis_buttons = widget.controls.axis.buttons

        self.assertEqual(widget.metrics_combo.count(), len(METRICS))
        for i, (_, metric) in enumerate(METRICS):
            if metric == distance.Mahalanobis:
                widget.metrics_combo.setCurrentIndex(i)
                widget.metrics_combo.activated.emit(i)
                break

        # 1010 rows (or columns, after transposing) is meant to exceed the
        # widget's Mahalanobis size limit. Seeded for reproducible runs.
        X = np.random.RandomState(0).random_sample((1010, 4))
        bigrows = Table.from_numpy(Domain(self.iris.domain.attributes), X)
        bigcols = Table.from_numpy(
            Domain([ContinuousVariable(f"{i}") for i in range(1010)]), X.T)

        self.send_signal(widget.Inputs.data, self.iris)
        assert_no_error()

        axis_buttons[0].click()
        assert_no_error()
        axis_buttons[1].click()
        assert_no_error()

        # by columns -- cannot handle too many rows
        self.send_signal(self.widget.Inputs.data, bigrows)
        assert_error_shown()
        axis_buttons[0].click()
        assert_no_error()
        axis_buttons[1].click()
        assert_error_shown()

        self.send_signal(self.widget.Inputs.data, bigcols)
        assert_no_error()
        axis_buttons[0].click()
        assert_error_shown()

        self.send_signal(widget.Inputs.data, self.iris)
        assert_no_error()


if __name__ == "__main__":
    unittest.main()