1# Test methods with long descriptive names can omit docstrings
2# pylint: disable=missing-docstring, protected-access
3import unittest
4from unittest.mock import Mock
5
6import numpy as np
7
8from Orange import distance
9from Orange.data import Table, Domain, ContinuousVariable
10from Orange.misc import DistMatrix
11from Orange.widgets.unsupervised.owdistances import OWDistances, METRICS, \
12    DistanceRunner
13from Orange.widgets.tests.base import WidgetTest
14
15
class TestDistanceRunner(unittest.TestCase):
    """Compare ``DistanceRunner.run`` with calling the metrics directly."""

    @classmethod
    def setUpClass(cls):
        """Load subsampled iris and zoo tables and plant NaNs in both."""
        super().setUpClass()
        cls.iris = Table("iris")[::5]
        cls.zoo = Table("zoo")[::5]
        # The same three cells are blanked out in each table.
        for table in (cls.iris, cls.zoo):
            for row, col in ((0, 2), (1, 3), (2, 1)):
                table.X[row, col] = np.nan

    def test_run(self):
        """Runner output equals the direct metric call for every metric,
        axis and normalization setting."""
        state = Mock()
        state.is_interruption_requested = Mock(return_value=False)
        # (runner axis, metric axis): the runner's 0 is compared with the
        # metric's axis=1 (between rows), its 1 with axis=0 (between columns).
        axis_pairs = ((0, 1), (1, 0))

        for name, metric in METRICS:
            if not metric.supports_missing or name == "Bhattacharyya":
                data = distance.impute(self.iris)
            elif name == "Jaccard":
                data = self.zoo
            else:
                data = self.iris

            for runner_axis, metric_axis in axis_pairs:
                # normalized first, then not normalized
                for normalize in (True, False):
                    computed = DistanceRunner.run(
                        data, metric, normalize, runner_axis, state)
                    expected = metric(
                        data, axis=metric_axis, impute=True,
                        normalize=normalize)
                    self.assertDistMatrixEqual(computed, expected)

    def assertDistMatrixEqual(self, dist1, dist2):
        """Assert that both arguments are ``DistMatrix`` instances with equal
        ``axis``, ``row_items``, ``col_items`` and almost equal values."""
        for dist in (dist1, dist2):
            self.assertIsInstance(dist, DistMatrix)
        for attr in ("axis", "row_items", "col_items"):
            self.assertEqual(getattr(dist1, attr), getattr(dist2, attr))
        np.testing.assert_array_almost_equal(dist1, dist2)
66
67
class TestOWDistances(WidgetTest):
    """Widget-level tests for ``OWDistances``: outputs, messages, limits."""

    @classmethod
    def setUpClass(cls):
        """Load the subsampled iris and titanic tables shared by the tests."""
        super().setUpClass()
        cls.iris = Table("iris")[::5]
        cls.titanic = Table("titanic")[::10]

    def setUp(self):
        """Create a fresh widget for every test."""
        self.widget = self.create_widget(OWDistances)

    def test_distance_combo(self):
        """Check distances when the metric changes"""
        self.assertEqual(self.widget.metrics_combo.count(), len(METRICS))
        self.send_signal(self.widget.Inputs.data, self.iris)
        for i, (_, metric) in enumerate(METRICS):
            self.widget.metrics_combo.activated.emit(i)
            self.widget.metrics_combo.setCurrentIndex(i)
            self.wait_until_stop_blocking()
            if metric.supports_normalization:
                expected = metric(
                    self.iris, normalize=self.widget.normalized_dist)
            else:
                expected = metric(self.iris)

            # The Jaccard output is not compared; test_jaccard_messages
            # expects the widget to report an error for iris with Jaccard.
            if metric is not distance.Jaccard:
                np.testing.assert_array_almost_equal(
                    expected, self.get_output(self.widget.Outputs.distances))

    def test_error_message(self):
        """Check if error message appears and then disappears when
        data is removed from input"""
        # NOTE(review): which metric sits at index 2 of METRICS is not
        # visible from this file -- confirm it is the intended one.
        self.widget.metric_idx = 2
        self.send_signal(self.widget.Inputs.data, self.iris)
        self.wait_until_stop_blocking()
        self.assertFalse(self.widget.Error.no_continuous_features.is_shown())
        self.send_signal(self.widget.Inputs.data, self.titanic)
        self.wait_until_stop_blocking()
        self.assertTrue(self.widget.Error.no_continuous_features.is_shown())
        self.send_signal(self.widget.Inputs.data, None)
        self.assertFalse(self.widget.Error.no_continuous_features.is_shown())

    def test_jaccard_messages(self):
        """Check the binary-feature error and non-binary warning shown for
        Jaccard as different data sets are sent and removed."""
        widget = self.widget

        for idx, (name, _) in enumerate(METRICS):
            if name == "Jaccard":
                widget.metric_idx = idx
                break
        else:
            # Without this the test would silently run on the last metric.
            self.fail("Jaccard is not listed in METRICS")

        def send_and_check(data, no_binary, ignoring_nonbinary):
            """Send `data` and compare the two message flags with the
            expected booleans."""
            self.send_signal(widget.Inputs.data, data)
            self.wait_until_stop_blocking()
            self.assertEqual(
                bool(widget.Error.no_binary_features.is_shown()),
                no_binary, "Error.no_binary_features")
            self.assertEqual(
                bool(widget.Warning.ignoring_nonbinary.is_shown()),
                ignoring_nonbinary, "Warning.ignoring_nonbinary")

        send_and_check(self.iris, True, False)
        send_and_check(None, False, False)
        send_and_check(self.titanic, False, True)
        send_and_check(None, False, False)
        send_and_check(self.titanic, False, True)

        # Titanic without its first attribute triggers neither message
        # (presumably the first attribute is the only non-binary one).
        dom = self.titanic.domain
        dom = Domain(dom.attributes[1:], dom.class_var)
        send_and_check(self.titanic.transform(dom), False, False)

        self.send_signal(widget.Inputs.data, Table("heart_disease"))
        self.wait_until_stop_blocking()
        self.assertFalse(widget.Error.no_binary_features.is_shown())
        # NOTE(review): this step checks `ignoring_discrete`, whereas every
        # step above checks `ignoring_nonbinary` -- confirm this is intended.
        self.assertFalse(widget.Warning.ignoring_discrete.is_shown())

    def test_too_big_array(self):
        """
        The user sees an error message when the computation raises
        ValueError or MemoryError, and Orange does not crash.
        GH-2315
        """
        self.assertEqual(len(self.widget.Error.active), 0)
        self.send_signal(self.widget.Inputs.data, self.iris)

        mock = Mock(side_effect=ValueError)
        self.widget.compute_distances(mock, self.iris)
        self.wait_until_finished()
        self.assertTrue(self.widget.Error.distances_value_error.is_shown())

        mock = Mock(side_effect=MemoryError)
        self.widget.compute_distances(mock, self.iris)
        self.wait_until_finished()
        # Only the memory error may remain active at this point.
        self.assertEqual(len(self.widget.Error.active), 1)
        self.assertTrue(self.widget.Error.distances_memory_error.is_shown())

    def test_migrates_normalized_dist(self):
        """A widget created from stored settings that contain only
        ``metric_idx`` ends up with ``normalized_dist`` set to False."""
        w = self.create_widget(OWDistances, stored_settings={"metric_idx": 0})
        self.assertFalse(w.normalized_dist)

    def test_negative_values_bhattacharyya(self):
        """A negative value in the data makes Bhattacharyya report
        ``distances_value_error``."""
        for idx, (_, metric) in enumerate(METRICS):
            if metric == distance.Bhattacharyya:
                self.widget.metric_idx = idx
                break
        else:
            self.fail("Bhattacharyya is not listed in METRICS")

        # `self.iris` is shared by all tests in the class, so the sign flip
        # must be undone even when an assertion below fails.
        self.iris.X[0, 0] *= -1
        try:
            self.send_signal(self.widget.Inputs.data, self.iris)
            self.wait_until_finished()
            self.assertTrue(
                self.widget.Error.distances_value_error.is_shown())
        finally:
            self.iris.X[0, 0] *= -1

    def test_limit_mahalanobis(self):
        """The Mahalanobis size-limit error appears and disappears as the
        data size and the selected axis change."""
        def assert_error_shown():
            self.assertTrue(
                self.widget.Error.data_too_large_for_mahalanobis.is_shown())

        def assert_no_error():
            self.assertFalse(
                self.widget.Error.data_too_large_for_mahalanobis.is_shown())

        widget = self.widget
        axis_buttons = widget.controls.axis.buttons

        self.assertEqual(widget.metrics_combo.count(), len(METRICS))
        for i, (_, metric) in enumerate(METRICS):
            if metric == distance.Mahalanobis:
                widget.metrics_combo.setCurrentIndex(i)
                widget.metrics_combo.activated.emit(i)
                break
        else:
            self.fail("Mahalanobis is not listed in METRICS")

        # Only the shape matters here; the values are arbitrary.
        n_big = 1010
        X = np.random.random((n_big, 4))
        bigrows = Table.from_numpy(Domain(self.iris.domain.attributes), X)
        bigcols = Table.from_numpy(
            Domain([ContinuousVariable(f"{i}") for i in range(n_big)]), X.T)

        self.send_signal(widget.Inputs.data, self.iris)
        assert_no_error()

        axis_buttons[0].click()
        assert_no_error()
        axis_buttons[1].click()
        assert_no_error()

        # by columns -- cannot handle too many rows
        self.send_signal(self.widget.Inputs.data, bigrows)
        assert_error_shown()
        axis_buttons[0].click()
        assert_no_error()
        axis_buttons[1].click()
        assert_error_shown()

        self.send_signal(self.widget.Inputs.data, bigcols)
        assert_no_error()
        axis_buttons[0].click()
        assert_error_shown()

        self.send_signal(widget.Inputs.data, self.iris)
        assert_no_error()
230
231
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
234