from typing import Optional

from scipy.sparse import issparse
import bottleneck as bn

from AnyQt.QtCore import Qt

import Orange.data
import Orange.misc
from Orange import distance
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import OWWidget, Msg, Input, Output
15
16
# (label, metric) pairs offered by the widget. The labels populate the
# metric combo box; the position of an entry is what the `metric_idx`
# setting stores, so the order matters for saved settings
# (`OWDistances.migrate_settings` remaps indices saved by older versions).
METRICS = [
    ("Euclidean", distance.Euclidean),
    ("Manhattan", distance.Manhattan),
    ("Cosine", distance.Cosine),
    ("Jaccard", distance.Jaccard),
    ("Spearman", distance.SpearmanR),
    ("Absolute Spearman", distance.SpearmanRAbsolute),
    ("Pearson", distance.PearsonR),
    ("Absolute Pearson", distance.PearsonRAbsolute),
    ("Hamming", distance.Hamming),
    ("Mahalanobis", distance.Mahalanobis),
    ("Bhattacharyya", distance.Bhattacharyya),
]
30
31
class InterruptException(Exception):
    """Raised from the distance progress callback when the task is cancelled."""
34
35
class DistanceRunner:
    """Task body that computes a distance matrix for `OWDistances`."""

    @staticmethod
    def run(data: Optional[Orange.data.Table], metric, normalized_dist: bool,
            axis: int,
            state: TaskState) -> Optional[Orange.misc.DistMatrix]:
        """Compute distances between rows or columns of `data`.

        Args:
            data: input table; `None` short-circuits to `None`.
            metric: one of the metrics listed in `METRICS` (for example
                `distance.Euclidean`). It is called as
                ``metric(data, axis=..., impute=True, callback=...)`` and
                must expose a `supports_normalization` attribute.
            normalized_dist: request normalized distances; honoured only
                when `metric.supports_normalization` is true.
            axis: the widget's axis setting (0 = rows, 1 = columns); it is
                passed to the metric as ``1 - axis``.
            state: task state used to report status and progress and to
                poll for a cancellation request.

        Returns:
            The matrix returned by `metric`, or `None` when `data` is `None`.

        Raises:
            InterruptException: from the progress callback, once an
                interruption of the task has been requested.
        """
        if data is None:
            return None

        def callback(i: float) -> None:
            # Report progress; raising here is how a cancellation request
            # aborts the computation running inside `metric`.
            state.set_progress_value(i)
            if state.is_interruption_requested():
                raise InterruptException

        state.set_status("Calculating...")
        # The metric's `axis` keyword uses the opposite convention to the
        # widget's setting, hence the flip.
        kwargs = {"axis": 1 - axis, "impute": True, "callback": callback}
        if metric.supports_normalization and normalized_dist:
            kwargs["normalize"] = True
        return metric(data, **kwargs)
53
54
class OWDistances(OWWidget, ConcurrentWidgetMixin):
    """Widget that computes a matrix of pairwise distances.

    Distances are computed between rows or columns of the input table with
    the metric selected from `METRICS`. The input is validated and, where
    needed, adapted (categorical or non-binary features dropped, missing
    values imputed) before `DistanceRunner.run` is started as a background
    task. Its result is sent to the "Distances" output.
    """
    name = "Distances"
    description = "Compute a matrix of pairwise distances."
    icon = "icons/Distance.svg"
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        distances = Output("Distances", Orange.misc.DistMatrix, dynamic=False)

    settings_version = 3

    #: 0 = distances between rows, 1 = distances between columns.
    axis = Setting(0)        # type: int
    #: Index into `METRICS`.
    metric_idx = Setting(0)  # type: int

    #: Use normalized distances if the metric supports it.
    #: The default is `True`, except when restoring from old pre v2 settings
    #: (see `migrate_settings`).
    normalized_dist = Setting(True)  # type: bool
    autocommit = Setting(True)       # type: bool

    want_main_area = False
    resizing_enabled = False

    class Error(OWWidget.Error):
        no_continuous_features = Msg("No numeric features")
        no_binary_features = Msg("No binary features")
        dense_metric_sparse_data = Msg("{} requires dense data.")
        distances_memory_error = Msg("Not enough memory")
        distances_value_error = Msg("Problem in calculation:\n{}")
        data_too_large_for_mahalanobis = Msg(
            "Mahalanobis handles up to 1000 {}.")

    class Warning(OWWidget.Warning):
        ignoring_discrete = Msg("Ignoring categorical features")
        ignoring_nonbinary = Msg("Ignoring non-binary features")
        imputing_data = Msg("Missing values were imputed")

    def __init__(self):
        OWWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)

        self.data = None

        gui.radioButtons(
            self.controlArea, self, "axis", ["Rows", "Columns"],
            box="Distances between", callback=self._invalidate
        )
        box = gui.widgetBox(self.controlArea, "Distance Metric")
        self.metrics_combo = gui.comboBox(
            box, self, "metric_idx",
            items=[m[0] for m in METRICS],
            callback=self._metric_changed
        )
        self.normalization_check = gui.checkBox(
            box, self, "normalized_dist", "Normalized",
            callback=self._invalidate,
            # Adjacent literals are concatenated: the trailing space keeps
            # "common" and "scale" from running together.
            tooltip=("All dimensions are (implicitly) scaled to a common "
                     "scale to normalize the influence across the domain."),
            stateWhenDisabled=False, attribute=Qt.WA_LayoutUsesWidgetRect
        )
        _, metric = METRICS[self.metric_idx]
        self.normalization_check.setEnabled(metric.supports_normalization)

        gui.auto_apply(self.buttonsArea, self, "autocommit")

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Store new input data and recompute distances unconditionally.

        Any computation still running is cancelled first.
        """
        self.cancel()
        self.data = data
        self.refresh_metrics()
        self.unconditional_commit()

    def refresh_metrics(self):
        """Disable combo entries for metrics that cannot handle sparse data.

        All entries are enabled when there is no data or `data.X` is dense.
        """
        sparse = self.data is not None and issparse(self.data.X)
        for i, metric in enumerate(METRICS):
            item = self.metrics_combo.model().item(i)
            item.setEnabled(not sparse or metric[1].supports_sparse)

    def commit(self):
        """Start computing distances with the currently selected metric."""
        # pylint: disable=invalid-sequence-index
        metric = METRICS[self.metric_idx][1]
        self.compute_distances(metric, self.data)

    def compute_distances(self, metric, data):
        """Validate/adapt `data` for `metric` and start the background task.

        If any check fails, the task is started with `data = None`, which
        makes `DistanceRunner.run` return `None` (clearing the output).
        """
        def _check_sparse():
            # pylint: disable=invalid-sequence-index
            if issparse(data.X) and not metric.supports_sparse:
                self.Error.dense_metric_sparse_data(METRICS[self.metric_idx][0])
                return False
            return True

        def _fix_discrete():
            # Drop categorical features when the metric cannot use them
            # here; fail if nothing numeric would remain.
            nonlocal data
            if data.domain.has_discrete_attributes() \
                    and metric is not distance.Jaccard \
                    and (issparse(data.X) and getattr(metric, "fallback", None)
                         or not metric.supports_discrete
                         or self.axis == 1):
                if not data.domain.has_continuous_attributes():
                    self.Error.no_continuous_features()
                    return False
                self.Warning.ignoring_discrete()
                data = distance.remove_discrete_features(data)
            return True

        def _fix_nonbinary():
            # Jaccard on dense data uses only two-valued discrete features.
            nonlocal data
            if metric is distance.Jaccard and not issparse(data.X):
                nbinary = sum(a.is_discrete and len(a.values) == 2
                              for a in data.domain.attributes)
                if not nbinary:
                    self.Error.no_binary_features()
                    return False
                elif nbinary < len(data.domain.attributes):
                    self.Warning.ignoring_nonbinary()
                    data = distance.remove_nonbinary_features(data)
            return True

        def _fix_missing():
            nonlocal data
            if not metric.supports_missing and bn.anynan(data.X):
                self.Warning.imputing_data()
                data = distance.impute(data)
            return True

        def _check_tractability():
            if metric is distance.Mahalanobis:
                if self.axis == 1:
                    # Distances between columns: at most 1000 rows.
                    if len(data) > 1000:
                        self.Error.data_too_large_for_mahalanobis("rows")
                        return False
                else:
                    # Distances between rows: at most 1000 columns.
                    if len(data.domain.attributes) > 1000:
                        self.Error.data_too_large_for_mahalanobis("columns")
                        return False
            return True

        self.clear_messages()
        if data is not None:
            for check in (_check_sparse, _check_tractability,
                          _fix_discrete, _fix_missing, _fix_nonbinary):
                if not check():
                    data = None
                    break

        self.start(DistanceRunner.run, data, metric,
                   self.normalized_dist, self.axis)

    def on_partial_result(self, _):
        """Partial results are not used by this widget."""

    def on_done(self, result: Orange.misc.DistMatrix):
        """Send the finished distance matrix (or `None`) to the output."""
        assert isinstance(result, Orange.misc.DistMatrix) or result is None
        self.Outputs.distances.send(result)

    def on_exception(self, ex):
        """Report task failures.

        `ValueError` and `MemoryError` are shown as widget errors, a
        user-requested interruption is ignored, anything else is re-raised.
        """
        if isinstance(ex, ValueError):
            self.Error.distances_value_error(ex)
        elif isinstance(ex, MemoryError):
            self.Error.distances_memory_error()
        elif isinstance(ex, InterruptException):
            pass
        else:
            raise ex

    def onDeleteWidget(self):
        self.shutdown()
        super().onDeleteWidget()

    def _invalidate(self):
        self.commit()

    def _metric_changed(self):
        """Enable "Normalized" only for metrics that support it, then recompute."""
        metric = METRICS[self.metric_idx][1]
        self.normalization_check.setEnabled(metric.supports_normalization)
        self._invalidate()

    def send_report(self):
        # pylint: disable=invalid-sequence-index
        self.report_items((
            ("Distances Between", ["Rows", "Columns"][self.axis]),
            ("Metric", METRICS[self.metric_idx][0])
        ))

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade `settings` saved by an older widget version, in place."""
        # `and` binds tighter than `or`; the parentheses spell out the
        # existing rule: unversioned settings always get False, v1
        # settings only when the key is missing.
        if version is None or (version < 2
                               and "normalized_dist" not in settings):
            # normalized_dist is set to False when restoring settings from
            # an older version to preserve old semantics.
            settings["normalized_dist"] = False
        if (version is None or version < 3) and "metric_idx" in settings:
            # Mahalanobis was moved from idx = 2 to idx = 9, so the metrics
            # previously at indices 3..9 shift down by one.
            metric_idx = settings["metric_idx"]
            if metric_idx == 2:
                settings["metric_idx"] = 9
            elif 2 < metric_idx <= 9:
                settings["metric_idx"] -= 1
257
258
if __name__ == "__main__":  # pragma: no cover
    # Manual preview: open the widget fed with the bundled iris dataset.
    preview = WidgetPreview(OWDistances)
    preview.run(Orange.data.Table("iris"))
261