from scipy.sparse import issparse
import bottleneck as bn

from AnyQt.QtCore import Qt

import Orange.data
import Orange.misc
from Orange import distance
from Orange.widgets import gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import OWWidget, Msg, Input, Output


#: Metrics offered in the combo box, as ``(label, metric)`` pairs.
#: The ``metric_idx`` setting stores a position in this list, so reordering
#: entries requires a settings migration (see
#: ``OWDistances.migrate_settings``).
METRICS = [
    ("Euclidean", distance.Euclidean),
    ("Manhattan", distance.Manhattan),
    ("Cosine", distance.Cosine),
    ("Jaccard", distance.Jaccard),
    ("Spearman", distance.SpearmanR),
    ("Absolute Spearman", distance.SpearmanRAbsolute),
    ("Pearson", distance.PearsonR),
    ("Absolute Pearson", distance.PearsonRAbsolute),
    ("Hamming", distance.Hamming),
    ("Mahalanobis", distance.Mahalanobis),
    ('Bhattacharyya', distance.Bhattacharyya)
]


class InterruptException(Exception):
    """Raised from the progress callback when interruption is requested."""


class DistanceRunner:
    """Background task that computes the distance matrix."""

    @staticmethod
    def run(data: Orange.data.Table, metric: distance, normalized_dist: bool,
            axis: int, state: TaskState) -> Orange.misc.DistMatrix:
        """Compute distances on `data` with `metric`.

        Args:
            data: input table; when None, nothing is computed.
            metric: one of the metrics listed in ``METRICS``.
            normalized_dist: request normalization; only forwarded when
                ``metric.supports_normalization`` is true.
            axis: the widget's axis setting (0 = "Rows", 1 = "Columns");
                passed to the metric as ``axis=1 - axis``.
            state: task state used for status, progress and interruption.

        Returns:
            The distance matrix, or None when `data` is None.

        Raises:
            InterruptException: if interruption was requested while the
                metric was reporting progress.
        """
        if data is None:
            return None

        def callback(i: float) -> None:
            # Report progress and abort the computation by raising, since
            # the metric has no other way to be stopped from outside.
            state.set_progress_value(i)
            if state.is_interruption_requested():
                raise InterruptException

        state.set_status("Calculating...")
        kwargs = {"axis": 1 - axis, "impute": True, "callback": callback}
        if metric.supports_normalization and normalized_dist:
            kwargs["normalize"] = True
        return metric(data, **kwargs)


class OWDistances(OWWidget, ConcurrentWidgetMixin):
    """Widget that computes pairwise distances between rows or columns."""

    name = "Distances"
    description = "Compute a matrix of pairwise distances."
    icon = "icons/Distance.svg"
    keywords = []

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        distances = Output("Distances", Orange.misc.DistMatrix, dynamic=False)

    settings_version = 3

    axis = Setting(0)  # type: int
    metric_idx = Setting(0)  # type: int

    #: Use normalized distances if the metric supports it.
    #: The default is `True`, except when restoring from old pre v2 settings
    #: (see `migrate_settings`).
    normalized_dist = Setting(True)  # type: bool
    autocommit = Setting(True)  # type: bool

    want_main_area = False
    resizing_enabled = False

    class Error(OWWidget.Error):
        no_continuous_features = Msg("No numeric features")
        no_binary_features = Msg("No binary features")
        dense_metric_sparse_data = Msg("{} requires dense data.")
        distances_memory_error = Msg("Not enough memory")
        distances_value_error = Msg("Problem in calculation:\n{}")
        data_too_large_for_mahalanobis = Msg(
            "Mahalanobis handles up to 1000 {}.")

    class Warning(OWWidget.Warning):
        ignoring_discrete = Msg("Ignoring categorical features")
        ignoring_nonbinary = Msg("Ignoring non-binary features")
        imputing_data = Msg("Missing values were imputed")

    def __init__(self):
        OWWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)

        self.data = None

        gui.radioButtons(
            self.controlArea, self, "axis", ["Rows", "Columns"],
            box="Distances between", callback=self._invalidate
        )
        box = gui.widgetBox(self.controlArea, "Distance Metric")
        self.metrics_combo = gui.comboBox(
            box, self, "metric_idx",
            items=[m[0] for m in METRICS],
            callback=self._metric_changed
        )
        self.normalization_check = gui.checkBox(
            box, self, "normalized_dist", "Normalized",
            callback=self._invalidate,
            # The trailing space in the first literal is required: adjacent
            # string literals are concatenated without a separator.
            tooltip=("All dimensions are (implicitly) scaled to a common "
                     "scale to normalize the influence across the domain."),
            stateWhenDisabled=False,
            attribute=Qt.WA_LayoutUsesWidgetRect
        )
        _, metric = METRICS[self.metric_idx]
        self.normalization_check.setEnabled(metric.supports_normalization)

        gui.auto_apply(self.buttonsArea, self, "autocommit")

    @Inputs.data
    @check_sql_input
    def set_data(self, data):
        """Handle new input data.

        Cancels any running computation, updates which metrics are
        selectable and recomputes unconditionally.
        """
        self.cancel()
        self.data = data
        self.refresh_metrics()
        self.unconditional_commit()

    def refresh_metrics(self):
        """Disable combo items for metrics that cannot handle sparse input."""
        sparse = self.data is not None and issparse(self.data.X)
        for i, metric in enumerate(METRICS):
            item = self.metrics_combo.model().item(i)
            item.setEnabled(not sparse or metric[1].supports_sparse)

    def commit(self):
        """Compute distances with the currently selected metric."""
        # pylint: disable=invalid-sequence-index
        metric = METRICS[self.metric_idx][1]
        self.compute_distances(metric, self.data)

    def compute_distances(self, metric, data):
        """Validate and preprocess `data` for `metric`, then start the task.

        The checks below run in order. Each either passes, possibly
        replacing `data` with a preprocessed copy and showing a warning, or
        shows an error and fails. On failure `data` becomes None, so the
        task outputs None.
        """
        def _check_sparse():
            # pylint: disable=invalid-sequence-index
            if issparse(data.X) and not metric.supports_sparse:
                self.Error.dense_metric_sparse_data(METRICS[self.metric_idx][0])
                return False
            return True

        def _fix_discrete():
            # Drop categorical features when the metric cannot use them
            # (Jaccard is exempt; it is handled by _fix_nonbinary).
            nonlocal data
            if data.domain.has_discrete_attributes() \
                    and metric is not distance.Jaccard \
                    and (issparse(data.X) and getattr(metric, "fallback", None)
                         or not metric.supports_discrete
                         or self.axis == 1):
                if not data.domain.has_continuous_attributes():
                    self.Error.no_continuous_features()
                    return False
                self.Warning.ignoring_discrete()
                data = distance.remove_discrete_features(data)
            return True

        def _fix_nonbinary():
            # Jaccard on dense data only uses binary categorical features.
            nonlocal data
            if metric is distance.Jaccard and not issparse(data.X):
                nbinary = sum(a.is_discrete and len(a.values) == 2
                              for a in data.domain.attributes)
                if not nbinary:
                    self.Error.no_binary_features()
                    return False
                elif nbinary < len(data.domain.attributes):
                    self.Warning.ignoring_nonbinary()
                    data = distance.remove_nonbinary_features(data)
            return True

        def _fix_missing():
            nonlocal data
            if not metric.supports_missing and bn.anynan(data.X):
                self.Warning.imputing_data()
                data = distance.impute(data)
            return True

        def _check_tractability():
            if metric is distance.Mahalanobis:
                if self.axis == 1:
                    # distances between columns: limit the number of rows
                    if len(data) > 1000:
                        self.Error.data_too_large_for_mahalanobis("rows")
                        return False
                else:
                    # distances between rows: limit the number of columns
                    if len(data.domain.attributes) > 1000:
                        self.Error.data_too_large_for_mahalanobis("columns")
                        return False
            return True

        self.clear_messages()
        if data is not None:
            for check in (_check_sparse, _check_tractability,
                          _fix_discrete, _fix_missing, _fix_nonbinary):
                if not check():
                    data = None
                    break

        self.start(DistanceRunner.run, data, metric,
                   self.normalized_dist, self.axis)

    def on_partial_result(self, _):
        """Partial results are not used by this widget."""

    def on_done(self, result: Orange.misc.DistMatrix):
        """Send the finished distance matrix (or None) to the output."""
        assert isinstance(result, Orange.misc.DistMatrix) or result is None
        self.Outputs.distances.send(result)

    def on_exception(self, ex):
        """Show ValueError/MemoryError as errors and ignore interruptions.

        Any other exception is re-raised.
        """
        if isinstance(ex, ValueError):
            self.Error.distances_value_error(ex)
        elif isinstance(ex, MemoryError):
            self.Error.distances_memory_error()
        elif isinstance(ex, InterruptException):
            pass
        else:
            raise ex

    def onDeleteWidget(self):
        """Shut down the concurrent runner before the widget is deleted."""
        self.shutdown()
        super().onDeleteWidget()

    def _invalidate(self):
        """Recompute after a change of the settings."""
        self.commit()

    def _metric_changed(self):
        """Enable normalization only when supported, then recompute."""
        metric = METRICS[self.metric_idx][1]
        self.normalization_check.setEnabled(metric.supports_normalization)
        self._invalidate()

    def send_report(self):
        """Report the axis and the metric that were used."""
        # pylint: disable=invalid-sequence-index
        self.report_items((
            ("Distances Between", ["Rows", "Columns"][self.axis]),
            ("Metric", METRICS[self.metric_idx][0])
        ))

    @classmethod
    def migrate_settings(cls, settings, version):
        """Upgrade a stored `settings` dict from `version` in place."""
        if version is None or (version < 2
                               and "normalized_dist" not in settings):
            # normalized_dist is set to False when restoring settings from
            # an older version to preserve old semantics.
            settings["normalized_dist"] = False
        # Without a stored metric_idx the default (0, Euclidean in every
        # version) applies, so there is nothing to remap.
        if (version is None or version < 3) and "metric_idx" in settings:
            # Mahalanobis was moved from idx = 2 to idx = 9; the metrics
            # previously at 3..9 moved one position earlier.
            metric_idx = settings["metric_idx"]
            if metric_idx == 2:
                settings["metric_idx"] = 9
            elif 2 < metric_idx <= 9:
                settings["metric_idx"] -= 1


if __name__ == "__main__":  # pragma: no cover
    WidgetPreview(OWDistances).run(Orange.data.Table("iris"))