1import sys
2import copy
3
4from AnyQt.QtCore import Qt
5from AnyQt.QtWidgets import QGridLayout, QSizePolicy
6
7from Orange.data import Table
8from Orange.widgets import widget, gui, settings
9from Orange.widgets.widget import OWWidget, Input, Output
10from skfusion import fusion
11from orangecontrib.datafusion.models import Relation
12
13import numpy as np
14
15
class SampleBy:
    """Names of the supported sampling strategies.

    ``all`` lists the strategies in a fixed order so that a strategy can be
    looked up by its integer position.
    """

    ROWS = "Rows"
    COLS = "Columns"
    ROWS_COLS = "Rows and columns"
    ENTRIES = "Entries"

    # Order matters: consumers index into this list by position.
    all = [
        ROWS,
        COLS,
        ROWS_COLS,
        ENTRIES,
    ]
22
23
def hide_data(table, percentage, sampling_type):
    """Return a boolean mask marking the entries held out of the sample.

    The draw is deterministic: every call uses a generator seeded with 0, so
    the same input shape, proportion and strategy always give the same mask.

    Parameters
    ----------
    table
        Data container exposing the matrix as ``table.X``. The object itself
        is also handed to ``np.isnan`` so that missing (NaN) entries are never
        marked as held out.
        NOTE(review): this assumes ``table`` converts to an array of the same
        shape as ``table.X`` -- confirm for ``Relation`` inputs.
    percentage : float
        Proportion (between 0 and 1) of the data to keep in the sample.
    sampling_type : str
        One of the ``SampleBy`` constants: hold out whole rows, whole columns,
        rows and columns, or individual entries.

    Returns
    -------
    Boolean array shaped like ``table.X``; ``True`` where the entry is
    out-of-sample.

    Raises
    ------
    ValueError
        If ``sampling_type`` is not one of the ``SampleBy`` constants.
    """
    assert not np.ma.is_masked(table)

    if sampling_type == SampleBy.ROWS_COLS:
        # Rows and columns are each sampled at sqrt(percentage). The recursive
        # calls return a single mask each; the previous code tried to unpack
        # them into two values, which raised ValueError for any matrix that
        # did not have exactly two rows.
        axis_share = np.sqrt(percentage)
        row_oos_mask = hide_data(table, axis_share, SampleBy.ROWS)
        col_oos_mask = hide_data(table, axis_share, SampleBy.COLS)
        # An entry is held out only when both its row and its column are.
        # NOTE(review): this holds out about (1 - sqrt(percentage))**2 of the
        # entries rather than 1 - percentage -- confirm this is intended.
        return np.logical_and(row_oos_mask, col_oos_mask)

    # A private generator seeded with 0 draws the same numbers as
    # ``np.random.seed(0)`` followed by ``np.random.rand`` but leaves the
    # process-wide random state untouched.
    rng = np.random.RandomState(0)

    if sampling_type == SampleBy.ROWS:
        # One draw per row, repeated across all columns.
        rand = np.repeat(rng.rand(table.X.shape[0], 1), table.X.shape[1], axis=1)
    elif sampling_type == SampleBy.COLS:
        # One draw per column, repeated across all rows.
        rand = np.repeat(rng.rand(1, table.X.shape[1]), table.X.shape[0], axis=0)
    elif sampling_type == SampleBy.ENTRIES:
        # Independent draw for every entry.
        rand = rng.rand(*table.X.shape)
    else:
        raise ValueError("Unknown sampling method.")

    # Draws at or above the kept proportion are held out; NaN entries carry
    # no data and are therefore never marked.
    return np.logical_and(rand >= percentage, ~np.isnan(table))
46
47
class OWSampleMatrix(OWWidget):
    """Widget that splits a relation matrix into in-sample and out-of-sample parts.

    The input is either a ``Relation`` or a plain ``Table`` (whose ``X`` matrix
    is wrapped in a new relation). A held-out mask is computed with
    ``hide_data`` and two relations are emitted: one with the held-out entries
    masked (in-sample) and one with everything else masked (out-of-sample).
    """
    name = "Matrix Sampler"
    description = "Sample a relation matrix."
    priority = 60000
    icon = "icons/MatrixSampler.svg"
    want_main_area = False
    resizing_enabled = False

    class Inputs:
        data = Input("Data", Table, default=True)

    class Outputs:
        in_sample_data = Output("In-sample Data", Relation)
        out_of_sample_data = Output("Out-of-sample Data", Relation)

    # Persisted settings.
    # percent: proportion of data kept in the sample, in percent.
    percent = settings.Setting(90)
    # method: index of the selected strategy in ``SampleBy.all``.
    method = settings.Setting(0)
    # bools: not read anywhere in this class; kept so stored settings still load.
    bools = settings.Setting([])

    def __init__(self):
        """Build the control area and emit the initial output."""
        super().__init__()
        self.data = None

        form = QGridLayout()

        # Object type names; edited by the user or taken from the input data.
        self.row_type = ""
        gui.lineEdit(self.controlArea, self, "row_type", "Row Type", callback=self.send_output)

        self.col_type = ""
        gui.lineEdit(self.controlArea, self, "col_type", "Column Type", callback=self.send_output)

        methodbox = gui.radioButtonsInBox(
            self.controlArea, self, "method", [],
            box=self.tr("Sampling method"), orientation=form)

        # Buttons are appended in the same order as ``SampleBy.all`` so that
        # the selected index maps directly onto a sampling strategy.
        rows = gui.appendRadioButton(methodbox, "Rows", addToLayout=False)
        form.addWidget(rows, 0, 0, Qt.AlignLeft)

        cols = gui.appendRadioButton(methodbox, "Columns", addToLayout=False)
        form.addWidget(cols, 0, 1, Qt.AlignLeft)

        rows_and_cols = gui.appendRadioButton(methodbox, "Rows and columns", addToLayout=False)
        form.addWidget(rows_and_cols, 1, 0, Qt.AlignLeft)

        entries = gui.appendRadioButton(methodbox, "Entries", addToLayout=False)
        form.addWidget(entries, 1, 1, Qt.AlignLeft)

        sample_size = gui.widgetBox(self.controlArea, "Proportion of data in the sample")
        gui.hSlider(sample_size, self, 'percent', minValue=1, maxValue=100, step=5, ticks=10,
                    labelFormat=" %d%%")

        gui.button(self.controlArea, self, "&Apply",
                   callback=self.send_output, default=True)

        self.setSizePolicy(QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed))

        self.setMinimumWidth(250)
        self.send_output()

    @Inputs.data
    def set_data(self, data):
        """Store the new input, adopt its row/column types if present, and resend."""
        self.data = data

        if hasattr(self.data, 'row_type'):
            self.row_type = self.data.row_type

        if hasattr(self.data, 'col_type'):
            self.col_type = self.data.col_type

        self.send_output()

    @staticmethod
    def _mask_relation(relation, mask):
        """Return a shallow copy of ``relation`` with ``mask`` applied to its data.

        Entries where ``mask`` is True are hidden. Entries already masked in
        ``relation.data`` stay hidden as well.
        """
        if np.ma.is_masked(relation.data):
            mask = np.logical_or(mask, relation.data.mask)
        masked = copy.copy(relation)
        masked.data = np.ma.array(masked.data, mask=mask)
        return masked

    def send_output(self):
        """Sample the current input and emit the in-sample and out-of-sample relations.

        With no input, ``None`` is sent on both outputs so downstream widgets
        do not keep results from a previous input.
        """
        if self.data is None:
            self.Outputs.in_sample_data.send(None)
            self.Outputs.out_of_sample_data.send(None)
            return

        if isinstance(self.data, Relation):
            # Work on a shallow copy so that overriding the object types below
            # cannot modify the relation object held by the input.
            relation_ = Relation(copy.copy(self.data.relation))
            if self.row_type:
                relation_.relation.row_type = fusion.ObjectType(self.row_type)
            if self.col_type:
                relation_.relation.col_type = fusion.ObjectType(self.col_type)
        else:
            relation_ = Relation.create(
                self.data.X,
                fusion.ObjectType(self.row_type or "Unknown"),
                fusion.ObjectType(self.col_type or "Unknown"))

        oos_mask = hide_data(relation_,
                             self.percent / 100,
                             SampleBy.all[self.method])

        # Hiding the held-out entries leaves the in-sample data; hiding the
        # complement leaves the out-of-sample data. Previously the same masked
        # relation was sent to both outputs.
        # NaN entries are never marked as held out, so the complement hides
        # them in the out-of-sample output.
        in_sample = self._mask_relation(relation_.relation, oos_mask)
        out_of_sample = self._mask_relation(relation_.relation,
                                            np.logical_not(oos_mask))

        self.Outputs.in_sample_data.send(Relation(in_sample))
        self.Outputs.out_of_sample_data.send(Relation(out_of_sample))
147
148
if __name__ == "__main__":
    # Stand-alone preview: open the widget in its own Qt application.
    from AnyQt.QtWidgets import QApplication

    application = QApplication(sys.argv)
    sampler = OWSampleMatrix()
    # To preview with real data, pass a table to ``sampler.set_data(...)`` here.
    sampler.send_output()
    sampler.show()
    application.exec_()
157