import sys
import copy

from AnyQt.QtCore import Qt
from AnyQt.QtWidgets import QGridLayout, QSizePolicy

from Orange.data import Table
from Orange.widgets import widget, gui, settings
from Orange.widgets.widget import OWWidget, Input, Output
from skfusion import fusion
from orangecontrib.datafusion.models import Relation

import numpy as np


class SampleBy:
    """Names of the supported sampling strategies.

    The position of a name in ``all`` is the value stored in the widget's
    ``method`` setting, so the order of ``all`` must not change.
    """
    ROWS = 'Rows'
    COLS = 'Columns'
    ROWS_COLS = 'Rows and columns'
    ENTRIES = 'Entries'
    all = [ROWS, COLS, ROWS_COLS, ENTRIES]


def hide_data(table, percentage, sampling_type):
    """Choose which entries of ``table.X`` are left out of the sample.

    The draw is deterministic: every call uses a private generator seeded
    with 0, so the same arguments always give the same mask and the global
    NumPy random state is left untouched.

    Parameters
    ----------
    table
        Object exposing the data matrix as a 2-D ``X`` attribute.
    percentage
        Expected fraction (0..1) of rows, columns or entries kept in the
        sample.
    sampling_type
        One of the names listed in ``SampleBy.all``.

    Returns
    -------
    Boolean array shaped like ``table.X``; ``True`` marks an entry that holds
    a value (is not NaN) and falls outside the sample.

    Raises
    ------
    ValueError
        If ``sampling_type`` is not a known sampling method.
    """
    if sampling_type == SampleBy.ROWS_COLS:
        # Keeping sqrt(p) of the rows and sqrt(p) of the columns keeps about
        # p of the entries: an entry is in the sample only when both its row
        # and its column were kept, hence it is out-of-sample as soon as
        # either of them was dropped.
        # NOTE: both draws start from the same seed, so the row and column
        # selections are correlated (same behaviour as seeding per call).
        axis_share = np.sqrt(percentage)
        row_oos_mask = hide_data(table, axis_share, SampleBy.ROWS)
        col_oos_mask = hide_data(table, axis_share, SampleBy.COLS)
        return np.logical_or(row_oos_mask, col_oos_mask)

    shape = table.X.shape
    rng = np.random.RandomState(0)
    if sampling_type == SampleBy.ROWS:
        # One draw per row, repeated across all of its columns.
        rand = np.repeat(rng.rand(shape[0], 1), shape[1], axis=1)
    elif sampling_type == SampleBy.COLS:
        # One draw per column, repeated across all of its rows.
        rand = np.repeat(rng.rand(1, shape[1]), shape[0], axis=0)
    elif sampling_type == SampleBy.ENTRIES:
        rand = rng.rand(*shape)
    else:
        raise ValueError("Unknown sampling method.")

    # Look at the raw values so the result is a plain boolean array even when
    # X is a masked array; missing (NaN) entries are never out-of-sample.
    has_value = ~np.isnan(np.ma.getdata(table.X))
    return np.logical_and(rand >= percentage, has_value)


def _mask_relation(relation, mask):
    """Return a shallow copy of ``relation`` whose data hides ``mask`` entries.

    Entries that are already masked in ``relation.data`` stay masked.
    """
    if np.ma.is_masked(relation.data):
        mask = np.logical_or(mask, relation.data.mask)
    masked = copy.copy(relation)
    masked.data = np.ma.array(masked.data, mask=mask)
    return masked


class OWSampleMatrix(OWWidget):
    """Widget that splits a relation matrix into two complementary samples."""
    name = "Matrix Sampler"
    description = "Sample a relation matrix."
    priority = 60000
    icon = "icons/MatrixSampler.svg"
    want_main_area = False
    resizing_enabled = False

    class Inputs:
        data = Input("Data", Table, default=True)

    class Outputs:
        in_sample_data = Output("In-sample Data", Relation)
        out_of_sample_data = Output("Out-of-sample Data", Relation)

    # Share of the data kept in the sample, in percent.
    percent = settings.Setting(90)
    # Index into SampleBy.all.
    method = settings.Setting(0)
    bools = settings.Setting([])

    def __init__(self):
        super().__init__()
        self.data = None

        form = QGridLayout()

        self.row_type = ""
        gui.lineEdit(self.controlArea, self, "row_type", "Row Type",
                     callback=self.send_output)

        self.col_type = ""
        gui.lineEdit(self.controlArea, self, "col_type", "Column Type",
                     callback=self.send_output)

        methodbox = gui.radioButtonsInBox(
            self.controlArea, self, "method", [],
            box=self.tr("Sampling method"), orientation=form)

        # Buttons are created in SampleBy.all order (their index is the value
        # of the ``method`` setting) and laid out on a two-column grid.
        for index, label in enumerate(SampleBy.all):
            button = gui.appendRadioButton(methodbox, label, addToLayout=False)
            row, column = divmod(index, 2)
            form.addWidget(button, row, column, Qt.AlignLeft)

        sample_size = gui.widgetBox(self.controlArea,
                                    "Proportion of data in the sample")
        gui.hSlider(sample_size, self, 'percent', minValue=1, maxValue=100,
                    step=5, ticks=10, labelFormat=" %d%%")

        gui.button(self.controlArea, self, "&Apply",
                   callback=self.send_output, default=True)

        self.setSizePolicy(QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed))

        self.setMinimumWidth(250)
        self.send_output()

    @Inputs.data
    def set_data(self, data):
        """Store the input, adopt its row/column types if any, and resample."""
        self.data = data

        if hasattr(self.data, 'row_type'):
            self.row_type = self.data.row_type

        if hasattr(self.data, 'col_type'):
            self.col_type = self.data.col_type

        self.send_output()

    def send_output(self):
        """Sample the current input and send both halves downstream.

        "In-sample Data" hides the out-of-sample entries; "Out-of-sample
        Data" hides everything else (in-sample and missing entries). Both
        outputs are cleared when there is no input.
        """
        if self.data is None:
            self.Outputs.in_sample_data.send(None)
            self.Outputs.out_of_sample_data.send(None)
            return

        if isinstance(self.data, Relation):
            # Work on a copy so that overriding the object types does not
            # modify the relation owned by the upstream widget.
            relation = copy.copy(self.data.relation)
            if self.row_type:
                relation.row_type = fusion.ObjectType(self.row_type)
            if self.col_type:
                relation.col_type = fusion.ObjectType(self.col_type)
            relation_ = Relation(relation)
        else:
            relation_ = Relation.create(
                self.data.X,
                fusion.ObjectType(self.row_type or "Unknown"),
                fusion.ObjectType(self.col_type or "Unknown"))

        oos_mask = hide_data(relation_,
                             self.percent / 100,
                             SampleBy.all[self.method])

        in_sample = _mask_relation(relation_.relation, oos_mask)
        out_of_sample = _mask_relation(relation_.relation,
                                       np.logical_not(oos_mask))

        self.Outputs.in_sample_data.send(Relation(in_sample))
        self.Outputs.out_of_sample_data.send(Relation(out_of_sample))


if __name__ == "__main__":
    from AnyQt.QtWidgets import QApplication
    app = QApplication(sys.argv)
    ow = OWSampleMatrix()
    # ow.set_data(Orange.data.Table("housing.tab"))
    ow.send_output()
    ow.show()
    app.exec_()