1import numpy as np
2
3from AnyQt.QtCore import Qt
4
5from Orange.data import Table, Domain, ContinuousVariable
6from Orange.data.util import get_unique_names
7from Orange.preprocess import RemoveNaNColumns, Impute
8from Orange import distance
9from Orange.widgets import gui
10from Orange.widgets.settings import Setting
11from Orange.widgets.utils.signals import Input, Output
12from Orange.widgets.widget import OWWidget, Msg
13from Orange.widgets.utils.widgetpreview import WidgetPreview
14
# (label, distance class) pairs offered in the widget's "Distance metric"
# combo box. The widget keeps the selected position in its `distance_index`
# setting and resolves the metric with METRICS[distance_index][1], so the
# order of the entries is significant: reordering them changes which metric
# a given index refers to.
METRICS = [
    ("Euclidean", distance.Euclidean),
    ("Manhattan", distance.Manhattan),
    ("Mahalanobis", distance.Mahalanobis),
    ("Cosine", distance.Cosine),
    ("Jaccard", distance.Jaccard),
    ("Spearman", distance.SpearmanR),
    ("Absolute Spearman", distance.SpearmanRAbsolute),
    ("Pearson", distance.PearsonR),
    ("Absolute Pearson", distance.PearsonRAbsolute),
]
26
27
class OWNeighbors(OWWidget):
    """Widget that outputs the data instances nearest to a reference set.

    For every row of the "Data" input the widget stores the distance to the
    closest row of the "Reference" input (computed on attributes only) and
    sends the rows with the smallest distances to the "Neighbors" output,
    with that distance appended as an additional meta column. Data rows
    whose ids also occur in the reference are never part of the output.
    """
    name = "Neighbors"
    description = "Compute nearest neighbors in data according to reference."
    icon = "icons/Neighbors.svg"

    replaces = ["orangecontrib.prototypes.widgets.owneighbours.OWNeighbours"]

    class Inputs:
        data = Input("Data", Table)
        reference = Input("Reference", Table)

    class Outputs:
        data = Output("Neighbors", Table)

    # NOTE(review): this group is called ``Info`` but derives from
    # ``OWWidget.Warning``. Left untouched; confirm whether an information
    # group was intended instead.
    class Info(OWWidget.Warning):
        removed_references = \
            Msg("Input data includes reference instance(s).\n"
                "Reference instances are excluded from the output.")

    class Warning(OWWidget.Warning):
        all_data_as_reference = \
            Msg("Every data instance is same as some reference")

    class Error(OWWidget.Error):
        diff_domains = Msg("Data and reference have different features")

    n_neighbors: int
    distance_index: int

    # Upper bound on the number of output rows; applied only when
    # `limit_neighbors` is set.
    n_neighbors = Setting(10)
    limit_neighbors = Setting(True)
    # Position of the selected metric within METRICS.
    distance_index = Setting(0)
    auto_apply = Setting(True)

    want_main_area = False
    resizing_enabled = False

    def __init__(self):
        """Create the controls: metric combo, neighbor limit and apply button."""
        super().__init__()

        self.data = None
        self.reference = None
        # Per data row: distance to the nearest reference row, or None when
        # the distances cannot be computed (see `compute_distances`).
        self.distances = None

        box = gui.vBox(self.controlArea, box=True)
        # Changing the metric invalidates the cached distances, hence
        # `recompute` rather than `apply`.
        gui.comboBox(
            box, self, "distance_index", orientation=Qt.Horizontal,
            label="Distance metric: ", items=[d[0] for d in METRICS],
            callback=self.recompute)
        # The lambdas resolve `self.apply` when they are called, not when
        # they are created (presumably because `gui.auto_apply` below
        # installs its own `apply` on the widget -- confirm).
        gui.spin(
            box, self, "n_neighbors", label="Limit number of neighbors to:",
            step=1, spinType=int, minv=0, maxv=100, checked='limit_neighbors',
            # call apply by gui.auto_commit, pylint: disable=unnecessary-lambda
            checkCallback=lambda: self.apply(),
            callback=lambda: self.apply())

        self.apply_button = gui.auto_apply(self.buttonsArea, self, commit=self.apply)

    @Inputs.data
    def set_data(self, data):
        """Store the data table and adapt the upper bound of the spin box.

        The bound becomes the number of data rows; without data (or with an
        empty table) it falls back to 100.
        """
        self.controls.n_neighbors.setMaximum(len(data) if data else 100)
        self.data = data

    @Inputs.reference
    def set_ref(self, refs):
        """Store the reference table."""
        self.reference = refs

    def handleNewSignals(self):
        """Recompute the distances and send the output unconditionally.

        `unconditional_apply` is not defined in this class; presumably it is
        provided by `gui.auto_apply` -- confirm.
        """
        self.compute_distances()
        self.unconditional_apply()

    def recompute(self):
        """Recompute the distances and apply; used when the metric changes."""
        self.compute_distances()
        self.apply()

    def compute_distances(self):
        """Compute the distance of each data row to its nearest reference.

        The result is stored in `self.distances`. It is set to None when
        either input is missing or empty, or when the two inputs do not have
        the same set of attributes; the latter also shows the
        `diff_domains` error.
        """
        self.Error.diff_domains.clear()
        if not self.data or not self.reference:
            self.distances = None
            return
        if set(self.reference.domain.attributes) != \
                set(self.data.domain.attributes):
            self.Error.diff_domains()
            self.distances = None
            return

        metric = METRICS[self.distance_index][1]
        n_ref = len(self.reference)

        # comparing only attributes, no metas and class-vars
        new_domain = Domain(self.data.domain.attributes)
        reference = self.reference.transform(new_domain)
        data = self.data.transform(new_domain)

        # Reference and data are stacked so that both go through the very
        # same preprocessing, and are split again afterwards.
        all_data = Table.concatenate([reference, data], 0)
        pp_all_data = Impute()(RemoveNaNColumns()(all_data))
        pp_reference, pp_data = pp_all_data[:n_ref], pp_all_data[n_ref:]
        # Rows correspond to data, columns to references: keep, for each
        # data row, the distance to the closest reference.
        self.distances = metric(pp_data, pp_reference).min(axis=1)

    def apply(self):
        """Send the selected neighbors, or None when there is nothing to send."""
        indices = self._compute_indices()

        if indices is None:
            neighbors = None
        else:
            neighbors = self._data_with_similarity(indices)
        self.Outputs.data.send(neighbors)

    def _compute_indices(self):
        """Return row positions in `self.data` of the nearest neighbors.

        Rows whose ids appear in the reference are excluded. The number of
        returned positions is the number of remaining rows, further limited
        to `n_neighbors` when `limit_neighbors` is set. The positions are
        not sorted by distance.

        Returns:
            An integer array of row positions, or None when no distances
            are available or when every data row is a reference row.
        """
        self.Warning.all_data_as_reference.clear()
        self.Info.removed_references.clear()

        if self.distances is None:
            return None

        inrefs = np.isin(self.data.ids, self.reference.ids)
        if np.all(inrefs):
            self.Warning.all_data_as_reference()
            return None
        if np.any(inrefs):
            self.Info.removed_references()

        # Select among non-reference rows only. Overwriting the reference
        # distances with a "max + 1" sentinel does not reliably push them
        # behind the other rows when distances are NaN, infinite or so large
        # that adding 1 is lost to rounding; reference rows could then end
        # up in the output.
        candidates = np.flatnonzero(~inrefs)
        up_to = len(candidates)
        if self.limit_neighbors and self.n_neighbors < up_to:
            up_to = self.n_neighbors
        if up_to <= 0:
            # The spin box allows zero neighbors: empty selection.
            return candidates[:0]

        dist = np.asarray(self.distances)[candidates]
        nearest = np.argpartition(dist, up_to - 1)[:up_to]
        # Map positions within `candidates` back to positions in the data.
        return candidates[nearest]

    def _data_with_similarity(self, indices):
        """Build the output table for the rows at `indices`.

        The table has the attributes and class variables of the input data;
        its metas are the original metas followed by a new continuous
        variable holding the distances. The new variable is named
        "distance", adjusted by `get_unique_names` with respect to the
        data's domain.
        """
        data = self.data
        varname = get_unique_names(data.domain, "distance")
        metas = data.domain.metas + (ContinuousVariable(varname), )
        domain = Domain(data.domain.attributes, data.domain.class_vars, metas)
        # Distances as a single column, appended after any existing metas.
        data_metas = self.distances[indices].reshape((-1, 1))
        if data.domain.metas:
            data_metas = np.hstack((data.metas[indices], data_metas))
        neighbors = Table(domain, data.X[indices], data.Y[indices], data_metas)
        # Keep the ids of the selected rows and the table-level attributes
        # of the input data.
        neighbors.ids = data.ids[indices]
        neighbors.attributes = self.data.attributes
        return neighbors
169
170
if __name__ == "__main__":  # pragma: no cover
    # Manual preview: iris as the data, its first row as the reference.
    preview_table = Table("iris.tab")
    preview = WidgetPreview(OWNeighbors)
    preview.run(set_data=preview_table, set_ref=preview_table[:1])
176