from collections import defaultdict

from Orange.widgets import widget, gui, settings
from Orange.widgets.utils.itemmodels import PyTableModel
from Orange.widgets.widget import Input, Output

from skfusion import fusion
from orangecontrib.datafusion.models import Relation, FusionGraph, RelationCompleter
from orangecontrib.datafusion.widgets.owfusiongraph import rel_shape, rel_cols

import numpy as np


class MeanBy:
    """Labels of the available averaging strategies."""
    ROWS = 'Rows'
    COLUMNS = 'Columns'
    VALUES = 'All values'
    # Order matters: ``MeanFuser``/``OWMeanFuser.mean_by`` store an index
    # into this tuple (it is also the combo-box item order).
    all = (COLUMNS, ROWS, VALUES)


class MeanFuser(RelationCompleter):
    """Completer that fills masked relation entries with means.

    Stands in for a fitted ``skfusion`` fuser: ``complete()`` returns the
    relation's data with every masked entry replaced by the mean of the whole
    matrix, of the entry's row, or of the entry's column.
    """

    def __init__(self, mean_by):
        """
        Parameters
        ----------
        mean_by : int
            Index into ``MeanBy.all`` selecting the averaging strategy.
        """
        # numpy axis to average over: 1 -> one mean per row, 0 -> one mean
        # per column, None -> a single mean of the whole matrix.
        self.axis = {
            MeanBy.ROWS: 1,
            MeanBy.COLUMNS: 0,
            MeanBy.VALUES: None}[MeanBy.all[mean_by]]
        self.mean_by = mean_by

    @property
    def name(self):
        return 'Mean by ' + MeanBy.all[self.mean_by].lower()

    def __getattr__(self, attr):
        """Return ``self`` for any unknown attribute (mocking a fusion fit).

        Dunder names are excluded so that protocol lookups (e.g. by ``copy``
        or ``pickle``) fail with AttributeError as usual instead of receiving
        this non-callable object.
        """
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return self

    def retrain(self):
        """Mean is deterministic, return the same Completer."""
        return self

    def can_complete(self, relation):
        """MeanFuser can complete any relation."""
        return True

    def complete(self, relation):
        """Mock ``skfusion.fusion.FusionFit.complete()``.

        Return a copy of ``relation.data`` in which every masked entry is
        replaced by a mean of the unmasked, non-NaN entries. The mean is taken
        over the whole matrix, the entry's row, or the entry's column,
        depending on ``mean_by``. Rows/columns without any usable value fall
        back to the mean of the whole matrix. Unmasked data is returned as an
        unchanged copy.
        """
        assert isinstance(relation, fusion.Relation)
        A = relation.data.copy()
        if not np.ma.is_masked(A):
            return A
        mask = np.ma.getmaskarray(A)
        # Exclude NaNs as well as masked entries from the averages.
        valid = np.ma.masked_where(np.isnan(np.ma.getdata(A)), A)
        # A fully masked matrix gives a masked mean; turn it into NaN.
        mean_value = float(np.ma.filled(valid.mean(), np.nan))
        if self.axis is None:
            # Replace the mask with mean of the matrix
            A[mask] = mean_value
            return A
        # Per-line means; lines with no usable value come out masked, so fill
        # them (and any NaN) with the mean of the whole matrix.
        means = np.ma.filled(valid.mean(axis=self.axis), np.nan).astype(float)
        means[np.isnan(means)] = mean_value
        rows, cols = mask.nonzero()
        # axis=1 gives one mean per row, axis=0 one mean per column.
        A[mask] = means[rows if self.axis == 1 else cols]
        return A


class OWMeanFuser(widget.OWWidget):
    """Widget sending a ``MeanFuser`` and, optionally, a completed relation."""
    name = 'Mean Fuser'
    priority = 55000
    icon = 'icons/MeanFuser.svg'

    class Inputs:
        fusion_graph = Input('Fusion graph', FusionGraph)
        relation = Input('Relation', Relation, multiple=True)

    class Outputs:
        fuser = Output('Mean-fitted fusion graph', MeanFuser, default=True)
        relation = Output('Relation', Relation)

    want_main_area = False

    mean_by = settings.Setting(0)
    selected_relation = settings.Setting(0)

    def __init__(self):
        super().__init__()
        # Reference count per relation: the same relation may arrive both in
        # the fusion graph and on one or more Relation inputs.
        self.relations = defaultdict(int)
        # Relation input signal id -> relation received on that connection
        self.id_relations = {}
        self.graph = None
        self._create_layout()
        self.commit()

    def _create_layout(self):
        # gui helpers insert themselves into the parent's layout (as the
        # widgetBox below does), so the combo must not be re-added.
        # Re-adding it would move it out of its 'Mean fuser' box.
        gui.comboBox(self.controlArea, self, 'mean_by',
                     box='Mean fuser',
                     label='Calculate masked values as mean by:',
                     items=MeanBy.all, callback=self.commit)
        box = gui.widgetBox(self.controlArea, 'Output completed relation')

        class TableView(gui.TableView):
            def __init__(self, parent):
                super().__init__(parent, selectionMode=self.SingleSelection)
                self._parent = parent
                self.bold_font = self.BoldFontDelegate(self)  # member because PyQt sometimes unrefs too early
                self.setItemDelegateForColumn(2, self.bold_font)
                self.setItemDelegateForColumn(4, self.bold_font)
                self.horizontalHeader().setVisible(False)

            def selectionChanged(self, *args):
                super().selectionChanged(*args)
                # Re-send the completed relation for the new selection
                self._parent.commit()

        table = self.table = TableView(self)
        model = self.model = PyTableModel(parent=self)
        table.setModel(model)
        box.layout().addWidget(table)
        self.controlArea.layout().addStretch(1)

    def commit(self, item=None):
        """Send a new MeanFuser and the completed selected relation (if any).

        ``item`` is unused; it only allows use as a generic callback.
        """
        self.fuser = MeanFuser(self.mean_by)
        self.Outputs.fuser.send(self.fuser)
        rows = [i.row() for i in self.table.selectionModel().selectedRows()]
        if self.model.rowCount() and rows:
            # Column 0 of each table row holds the relation object itself
            relation = self.model[rows[0]][0]
            data = Relation.create(self.fuser.complete(relation),
                                   relation.row_type,
                                   relation.col_type,
                                   self.graph)
        else:
            data = None
        self.Outputs.relation.send(data)

    def update_table(self):
        """Rebuild the table from ``self.relations``.

        Column 0 stores the relation object and is hidden.
        """
        self.model.wrap([([rel, rel_shape(rel.data)] +
                          rel_cols(rel) +
                          ['(not masked)' if not np.ma.is_masked(rel.data) else ''])
                         for rel in self.relations])
        self.table.hideColumn(0)

    def _add_relation(self, relation):
        self.relations[relation] += 1

    def _remove_relation(self, relation):
        # Drop the relation once no input references it any more
        self.relations[relation] -= 1
        if not self.relations[relation]:
            del self.relations[relation]

    @Inputs.fusion_graph
    def on_fusion_graph_change(self, graph):
        """Replace the relations contributed by the fusion graph input."""
        # Remove the previous graph's relations before adding the new ones.
        # Otherwise replacing a graph leaves stale rows behind, and clearing
        # the input would dereference a graph that was already reset to None.
        if self.graph is not None:
            for rel in self.graph.relations:
                self._remove_relation(rel)
        self.graph = graph if graph else None
        if self.graph is not None:
            for rel in self.graph.relations:
                self._add_relation(rel)
        self.update_table()
        self.commit()

    @Inputs.relation
    def on_relation_change(self, relation, id):
        """Handle a new, updated or removed relation on input connection ``id``."""
        previous = self.id_relations.pop(id, None)
        if previous is not None:
            self._remove_relation(previous)
        if relation:
            self.id_relations[id] = relation.relation
            self._add_relation(relation.relation)
        self.update_table()
        self.commit()


def main():
    from AnyQt.QtWidgets import QApplication
    t1 = fusion.ObjectType('Users', 10)
    t2 = fusion.ObjectType('Movies', 30)
    t3 = fusion.ObjectType('Actors', 40)

    # test that MeanFuser completes correctly
    R = np.ma.array([[1, 1, 0],
                     [3, 0, 0]], mask=[[0, 0, 1],
                                       [0, 1, 1]], dtype=float)
    rel = fusion.Relation(R, t1, t2)
    for mean_by, expected in ((0, [[1, 1, 5/3], [3, 1, 5/3]]),
                              (1, [[1, 1, 1], [3, 3, 3]]),
                              (2, [[1, 1, 5/3], [3, 5/3, 5/3]])):
        completed = MeanFuser(mean_by).complete(rel)
        # Comparisons with masked entries are ignored by .all(), so also
        # require that no entry is left masked.
        assert not np.ma.is_masked(completed)
        assert np.allclose(np.ma.getdata(completed), expected)

    R1 = np.ma.array(np.random.random((20, 20)))
    R2 = np.ma.array(np.random.random((40, 40)),
                     mask=np.random.random((40, 40)) > .8)
    relations = [
        fusion.Relation(R1, t1, t2, name='like'),
        fusion.Relation(R2, t3, t2, name='feature in'),
    ]
    G = fusion.FusionGraph()
    G.add_relations_from(relations)
    app = QApplication([])
    w = OWMeanFuser()
    w.on_fusion_graph_change(G)
    w.show()
    app.exec()


if __name__ == "__main__":
    main()