1
2from collections import defaultdict
3
4from Orange.widgets import widget, gui, settings
5from Orange.widgets.utils.itemmodels import PyTableModel
6from Orange.widgets.widget import Input, Output
7
8from skfusion import fusion
9from orangecontrib.datafusion.models import Relation, FusionGraph, RelationCompleter
10from orangecontrib.datafusion.widgets.owfusiongraph import rel_shape, rel_cols
11
12import numpy as np
13
14
class MeanBy:
    """Names of the ways :class:`MeanFuser` can average a relation.

    ``MeanBy.all`` fixes their order; an integer ``mean_by`` index selects
    one of them.
    """
    COLUMNS = 'Columns'
    ROWS = 'Rows'
    VALUES = 'All values'
    # Order matters: positions in this tuple are the integer mean_by codes.
    all = COLUMNS, ROWS, VALUES
20
21
class MeanFuser(RelationCompleter):
    """Relation completer that fills masked entries with means.

    It mimics a fitted ``skfusion`` model (see :meth:`complete`) so it can be
    sent wherever a fusion fit is expected.

    Parameters
    ----------
    mean_by : int
        Index into ``MeanBy.all``: 0 averages by columns, 1 by rows and
        2 over all values of the relation.
    """
    def __init__(self, mean_by):
        # numpy axis to average over; None means the whole matrix
        self.axis = {
            MeanBy.ROWS: 1,
            MeanBy.COLUMNS: 0,
            MeanBy.VALUES: None}[MeanBy.all[mean_by]]
        self.mean_by = mean_by

    @property
    def name(self):
        return 'Mean by ' + MeanBy.all[self.mean_by].lower()

    def __getattr__(self, attr):
        """Answer any unknown public attribute with ``self`` (fusion-fit mock).

        Special (dunder) names raise ``AttributeError`` instead, so protocol
        probes (e.g. ``copy.deepcopy`` looking up ``__deepcopy__``, pickle or
        numpy ``__array__*`` lookups) are not handed a non-callable ``self``.
        """
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return self

    def retrain(self):
        """Mean is deterministic, return the same Completer."""
        return self

    def can_complete(self, relation):
        """MeanFuser can complete any relation."""
        return True

    def complete(self, relation):
        """Mock ``skfusion.fusion.FusionFit.complete()``.

        Return a copy of ``relation.data`` in which every masked entry is
        replaced by the mean of the unmasked values of its column, its row or
        the whole matrix, depending on ``mean_by``. Row/column means that come
        out NaN fall back to the whole-matrix mean. Data without masked
        entries is returned as an unchanged copy.
        """
        assert isinstance(relation, fusion.Relation)
        A = relation.data.copy()
        if not np.ma.is_masked(A):
            return A
        if not np.issubdtype(A.dtype, np.inexact):
            # Integer/boolean matrices would silently truncate the fractional
            # means on assignment; fill a float copy instead (mask is kept).
            A = A.astype(float)
        mean_value = np.nanmean(A, axis=None)
        if self.axis is None:
            # Replace the mask with mean of the matrix
            A[A.mask] = mean_value
        else:
            # Replace the mask with mean by axes
            mean = np.nanmean(A, axis=self.axis)
            # Replace any NaNs in mean with mean of the matrix
            mean[np.isnan(mean)] = mean_value
            # mask.nonzero() is (row_indices, col_indices). Column means
            # (axis=0) are looked up by column index, row means (axis=1) by
            # row index.
            A[A.mask] = np.take(mean, A.mask.nonzero()[1 - self.axis])
        return A
62
63
class OWMeanFuser(widget.OWWidget):
    """Widget that completes relations by filling masked values with means.

    It always outputs a :class:`MeanFuser` for the chosen averaging mode and,
    when a row of its table is selected, the completed version of that
    relation.
    """
    name = 'Mean Fuser'
    priority = 55000
    icon = 'icons/MeanFuser.svg'

    class Inputs:
        fusion_graph = Input('Fusion graph', FusionGraph)
        relation = Input('Relation', Relation, multiple=True)

    class Outputs:
        fuser = Output('Mean-fitted fusion graph', MeanFuser, default=True)
        relation = Output('Relation', Relation)

    want_main_area = False

    # Index into MeanBy.all, chosen in the combo box.
    mean_by = settings.Setting(0)
    # NOTE(review): not read anywhere in this widget; kept for settings compatibility.
    selected_relation = settings.Setting(0)

    def __init__(self):
        super().__init__()
        # Reference count per relation: the same relation may arrive from the
        # fusion graph and from relation inputs at the same time.
        self.relations = defaultdict(int)
        # Input connection id -> relation last received on that connection.
        self.id_relations = {}
        self.graph = None
        self._create_layout()
        self.commit()

    def _create_layout(self):
        self.controlArea.layout().addWidget(
            gui.comboBox(self.controlArea, self, 'mean_by',
                         box='Mean fuser',
                         label='Calculate masked values as mean by:',
                         items=MeanBy.all, callback=self.commit))
        box = gui.widgetBox(self.controlArea, 'Output completed relation')

        class TableView(gui.TableView):
            def __init__(self, parent):
                super().__init__(parent, selectionMode=self.SingleSelection)
                self._parent = parent
                self.bold_font = self.BoldFontDelegate(self)   # member because PyQt sometimes unrefs too early
                self.setItemDelegateForColumn(2, self.bold_font)
                self.setItemDelegateForColumn(4, self.bold_font)
                self.horizontalHeader().setVisible(False)

            def selectionChanged(self, *args):
                # Re-send outputs whenever the user picks another relation.
                super().selectionChanged(*args)
                self._parent.commit()

        table = self.table = TableView(self)
        model = self.model = PyTableModel(parent=self)
        table.setModel(model)
        box.layout().addWidget(table)
        self.controlArea.layout().addStretch(1)

    def commit(self, item=None):
        """Send a fresh MeanFuser and the completed selected relation (or None)."""
        self.fuser = MeanFuser(self.mean_by)
        self.Outputs.fuser.send(self.fuser)
        rows = [i.row() for i in self.table.selectionModel().selectedRows()]
        if self.model.rowCount() and rows:
            # Column 0 of the model holds the relation object itself.
            relation = self.model[rows[0]][0]
            data = Relation.create(self.fuser.complete(relation),
                                   relation.row_type,
                                   relation.col_type,
                                   self.graph)
        else:
            data = None
        self.Outputs.relation.send(data)

    def update_table(self):
        """Rebuild the table from the tracked relations (column 0 hidden)."""
        self.model.wrap([([rel, rel_shape(rel.data)] +
                          rel_cols(rel) +
                          ['(not masked)' if not np.ma.is_masked(rel.data) else ''])
                         for rel in self.relations])
        self.table.hideColumn(0)

    def _add_relation(self, relation):
        """Add one reference to *relation*."""
        self.relations[relation] += 1

    def _remove_relation(self, relation):
        """Drop one reference to *relation*; forget it when none remain."""
        remaining = self.relations.get(relation, 0) - 1
        if remaining > 0:
            self.relations[relation] = remaining
        else:
            # pop() also avoids leaving a negative count for an
            # untracked relation (which would then show up in the table).
            self.relations.pop(relation, None)

    @Inputs.fusion_graph
    def on_fusion_graph_change(self, graph):
        """Replace the relations contributed by the previous graph with *graph*'s."""
        # Release the previous graph's relations before storing the new one.
        # Previously self.graph was set to None and then dereferenced, and a
        # replacing graph never released the old graph's relations.
        if self.graph is not None:
            for rel in self.graph.relations:
                self._remove_relation(rel)
        self.graph = graph
        if graph is not None:
            for rel in graph.relations:
                self._add_relation(rel)
        self.update_table()
        self.commit()

    @Inputs.relation
    def on_relation_change(self, relation, id):
        """Track the relation received on input connection *id* (None removes it)."""
        previous = self.id_relations.pop(id, None)
        if previous is not None:
            self._remove_relation(previous)
        if relation is not None:
            self.id_relations[id] = relation.relation
            self._add_relation(relation.relation)
        self.update_table()
        self.commit()
168
169
def main():
    """Self-check MeanFuser on a tiny matrix, then show the widget on a demo graph."""
    from AnyQt.QtWidgets import QApplication
    users = fusion.ObjectType('Users', 10)
    movies = fusion.ObjectType('Movies', 30)
    actors = fusion.ObjectType('Actors', 40)

    # Masked cells must become the column mean, the row mean or the global
    # mean (5/3 of the unmasked values 1, 1, 3), in mean_by order 0, 1, 2.
    toy = np.ma.array([[1, 1, 0],
                       [3, 0, 0]],
                      mask=[[0, 0, 1],
                            [0, 1, 1]],
                      dtype=float)
    toy_relation = fusion.Relation(toy, users, movies)
    expected_by_mode = {
        0: [[1, 1, 5/3],
            [3, 1, 5/3]],
        1: [[1, 1, 1],
            [3, 3, 3]],
        2: [[1,   1, 5/3],
            [3, 5/3, 5/3]],
    }
    for mode, expected in expected_by_mode.items():
        assert (MeanFuser(mode).complete(toy_relation) == expected).all()

    # Demo graph: one fully observed relation and one with ~20% masked cells.
    likes = np.ma.array(np.random.random((20, 20)))
    features = np.random.random((40, 40))
    hidden = np.random.random((40, 40)) > .8
    features_in = np.ma.array(features, mask=hidden)
    demo_relations = [
        fusion.Relation(likes, users, movies, name='like'),
        fusion.Relation(features_in, actors, movies, name='feature in'),
    ]
    demo_graph = fusion.FusionGraph()
    demo_graph.add_relations_from(demo_relations)

    app = QApplication([])
    window = OWMeanFuser()
    window.on_fusion_graph_change(demo_graph)
    window.show()
    app.exec()


# Run the self-check and the interactive demo only when executed as a script.
if __name__ == "__main__":
    main()
206