1import sys
2import pkg_resources
3
4import numpy as np
5
6from AnyQt.QtCore import Signal
7from AnyQt.QtWidgets import (
8    QApplication, QVBoxLayout, QHBoxLayout, QFormLayout, QSpinBox, QComboBox,
9    QButtonGroup, QLabel, QDoubleSpinBox, QGroupBox, QCheckBox, QRadioButton
10)
11
12import Orange.widgets.data.owpreprocess
13from Orange.data import DiscreteVariable
14from Orange.widgets.data.owpreprocess import (
15    PreprocessAction, Description, index_to_enum, enum_to_index
16)
17from Orange.widgets.data.utils.preprocess import (
18    ParametersRole, DescriptionRole, Controller, BaseEditor
19)
20from Orange.widgets.settings import (
21    DomainContextHandler, ContextSetting, Setting
22)
23from Orange.widgets.utils.itemmodels import DomainModel
24from Orange.widgets.utils.sql import check_sql_input
25from Orange.widgets.widget import Input, Output, Msg
26
27from orangecontrib.single_cell.preprocess.scpreprocess import (
28    LogarithmicScale, Binarize, Normalize, NormalizeSamples, Standardize,
29    SelectMostVariableGenes, NormalizeGroups
30)
31
32
def icon_path(basename):
    """Return the resource path of the icon file *basename* in ``icons/``."""
    relative_name = "icons/" + basename
    return pkg_resources.resource_filename(__name__, relative_name)
35
36
class ScBaseEditor(BaseEditor):
    """Common base class of the single-cell preprocessor editors.

    ``ScController`` constructs every editor with a ``master`` keyword
    argument; this base accepts it so subclasses that do not need the
    master widget can simply ignore it.
    """

    def __init__(self, parent=None, master=None, **kwargs):
        # ``master`` is intentionally not forwarded to BaseEditor.
        super(ScBaseEditor, self).__init__(parent, **kwargs)
40
41
class LogarithmicScaleEditor(ScBaseEditor):
    """Editor for :class:`LogarithmicScale`; lets the user pick the base."""

    DEFAULT_BASE = LogarithmicScale.BinaryLog

    def __init__(self, parent=None, **kwargs):
        super().__init__(parent, **kwargs)
        self.setLayout(QVBoxLayout())

        # The combo index is converted with index_to_enum / enum_to_index,
        # so the item order is assumed to follow LogarithmicScale.Base.
        self.base_cb = QComboBox()
        self.base_cb.addItems([
            "2 (Binary Logarithm)",
            "e (Natural Logarithm)",
            "10 (Common Logarithm)",
        ])
        self.base_cb.currentIndexChanged.connect(self.changed)
        self.base_cb.activated.connect(self.edited)

        form = QFormLayout()
        form.addRow("Logarithm Base:", self.base_cb)
        self.layout().addLayout(form)

    def setParameters(self, params):
        """Select the combo entry matching ``params["base"]``."""
        self.base_cb.setCurrentIndex(enum_to_index(
            LogarithmicScale.Base, params.get("base", self.DEFAULT_BASE)))

    def parameters(self):
        """Return the currently selected base as ``{"base": ...}``."""
        selected = self.base_cb.currentIndex()
        return {"base": index_to_enum(LogarithmicScale.Base, selected)}

    @staticmethod
    def createinstance(params):
        """Build a :class:`LogarithmicScale` from stored parameters."""
        return LogarithmicScale(
            params.get("base", LogarithmicScaleEditor.DEFAULT_BASE))

    def __repr__(self):
        return "Base: {}".format(self.base_cb.currentText())
76
77
class BinarizeEditor(ScBaseEditor):
    """Editor for :class:`Binarize`: comparison condition and threshold."""

    DEFAULT_CONDITION = Binarize.GreaterOrEqual
    DEFAULT_THRESHOLD = 1

    def __init__(self, parent=None, **kwargs):
        super().__init__(parent, **kwargs)
        # Cached threshold, kept in sync with the spin box by _set_threshold.
        self._threshold = self.DEFAULT_THRESHOLD

        self.setLayout(QVBoxLayout())

        self.cond_cb = QComboBox()
        self.cond_cb.addItems(["Greater or Equal", "Greater"])
        self.cond_cb.currentIndexChanged.connect(self.changed)
        self.cond_cb.activated.connect(self.edited)

        self.thr_spin = QDoubleSpinBox(
            minimum=0, singleStep=0.5, decimals=1, value=self._threshold
        )
        self.thr_spin.valueChanged[float].connect(self._set_threshold)
        self.thr_spin.editingFinished.connect(self.edited)

        form = QFormLayout()
        for label, widget in (("Condition:", self.cond_cb),
                              ("Threshold:", self.thr_spin)):
            form.addRow(label, widget)
        self.layout().addLayout(form)

    def _set_threshold(self, t):
        """Store threshold *t*, update the spin box and emit ``changed``."""
        # Nothing to do when unchanged; this also stops the recursion caused
        # by setValue re-emitting valueChanged.
        if self._threshold == t:
            return
        self._threshold = t
        self.thr_spin.setValue(t)
        self.changed.emit()

    def setParameters(self, params):
        condition = params.get("condition", self.DEFAULT_CONDITION)
        self.cond_cb.setCurrentIndex(
            enum_to_index(Binarize.Condition, condition))
        self._set_threshold(params.get("threshold", self.DEFAULT_THRESHOLD))

    def parameters(self):
        condition = index_to_enum(Binarize.Condition,
                                  self.cond_cb.currentIndex())
        return {
            "condition": condition,
            "threshold": self._threshold,
        }

    @staticmethod
    def createinstance(params):
        """Build a :class:`Binarize` from stored parameters."""
        return Binarize(
            params.get("condition", BinarizeEditor.DEFAULT_CONDITION),
            params.get("threshold", BinarizeEditor.DEFAULT_THRESHOLD),
        )

    def __repr__(self):
        condition = self.cond_cb.currentText()
        threshold = self.thr_spin.value()
        return "Condition: {}, Threshold: {}".format(condition, threshold)
128
129
class NormalizeEditor(ScBaseEditor):
    """Editor for sample normalization, optionally per group of cells.

    The user picks a normalization method and may group cells by a discrete
    meta or class variable of the master widget's current data.
    ``createinstance`` returns :class:`NormalizeGroups` when grouping is
    enabled and a group variable is set, otherwise :class:`NormalizeSamples`.
    """

    DEFAULT_GROUP_BY = False
    DEFAULT_GROUP_VAR = None
    DEFAULT_METHOD = Normalize.CPM

    def __init__(self, parent=None, master=None, **kwargs):
        super().__init__(parent, **kwargs)
        self._group_var = self.DEFAULT_GROUP_VAR
        # Widget that owns the input data (ScController passes it). None is
        # tolerated so the editor can also be built without a master.
        self._master = master
        if master is not None:
            master.input_data_changed.connect(self._set_model)
        self.setLayout(QVBoxLayout())

        form = QFormLayout()
        cpm_b = QRadioButton("Counts per million", checked=True)
        med_b = QRadioButton("Median")
        self.group = QButtonGroup()
        self.group.buttonClicked.connect(self._on_button_clicked)
        for i, button in enumerate([cpm_b, med_b]):
            # Button id is the Method enum value minus one; parameters()
            # passes checkedId() back through index_to_enum.
            index = index_to_enum(Normalize.Method, i).value
            self.group.addButton(button, index - 1)
            form.addRow(button)

        self.group_by_check = QCheckBox("Cell Groups: ",
                                        enabled=self.DEFAULT_GROUP_BY)
        self.group_by_check.clicked.connect(self.edited)
        self.group_by_combo = QComboBox(enabled=self.DEFAULT_GROUP_BY)
        # Candidate grouping variables: discrete metas and class variables.
        self.group_by_model = DomainModel(
            order=(DomainModel.METAS, DomainModel.CLASSES),
            valid_types=DiscreteVariable,
            alphabetical=True
        )
        self.group_by_combo.setModel(self.group_by_model)
        self.group_by_combo.currentIndexChanged.connect(self.changed)
        self.group_by_combo.activated.connect(self.edited)

        form.addRow(self.group_by_check, self.group_by_combo)
        self.layout().addLayout(form)

        self._set_model()

    def _has_group_var(self):
        """Return True if the stored group variable is offered by the model."""
        # indexOf on the model would fail for a variable it does not contain
        # (e.g. one restored for a different domain), so check first.
        return bool(self._group_var) and self._group_var in self.group_by_model

    def _current_group_var(self):
        """Return the variable selected in the combo, or None."""
        index = self.group_by_combo.currentIndex()
        return self.group_by_model[index] if index > -1 else None

    def _set_model(self):
        """Refill the group-variable combo from the master's current data.

        Grouping is unchecked; the checkbox and combo are enabled only when
        the data offers at least one candidate variable. The previously
        stored group variable is reselected if it is still available.
        """
        data = self._master.data if self._master is not None else None
        # Use an explicit None test: an empty Table is falsy, and
        # ``data and data.domain`` would pass the Table itself.
        self.group_by_model.set_domain(
            data.domain if data is not None else None)
        enable = bool(self.group_by_model)
        self.group_by_check.setChecked(False)
        self.group_by_check.setEnabled(enable)
        self.group_by_combo.setEnabled(enable)
        if enable:
            self.group_by_combo.setCurrentIndex(0)
            if self._has_group_var():
                index = self.group_by_model.indexOf(self._group_var)
                self.group_by_combo.setCurrentIndex(index)
        else:
            self.group_by_combo.setCurrentText(None)

    def _on_button_clicked(self):
        self.changed.emit()
        self.edited.emit()

    def setParameters(self, params):
        """Apply stored parameters (method, group_var, group_by) to the UI.

        Grouping is only checked when ``group_by`` is set and the stored
        group variable is present in the current model.
        """
        method = params.get("method", self.DEFAULT_METHOD)
        index = enum_to_index(Normalize.Method, method)
        self.group.buttons()[index].setChecked(True)
        self._group_var = params.get("group_var", self.DEFAULT_GROUP_VAR)
        group = self._has_group_var()
        if group:
            index = self.group_by_model.indexOf(self._group_var)
            self.group_by_combo.setCurrentIndex(index)
        group_by = params.get("group_by", self.DEFAULT_GROUP_BY)
        self.group_by_check.setChecked(bool(group_by and group))

    def parameters(self):
        """Return ``{"group_var", "group_by", "method"}`` from the UI state."""
        method = index_to_enum(Normalize.Method, self.group.checkedId())
        return {"group_var": self._current_group_var(),
                "group_by": self.group_by_check.isChecked(),
                "method": method}

    @staticmethod
    def createinstance(params):
        """Build NormalizeGroups when grouping is requested, else NormalizeSamples."""
        group_var = params.get("group_var")
        group_by = params.get("group_by", NormalizeEditor.DEFAULT_GROUP_BY)
        method = params.get("method", NormalizeEditor.DEFAULT_METHOD)
        return NormalizeGroups(group_var, method) \
            if group_by and group_var else NormalizeSamples(method)

    def __repr__(self):
        method = self.group.button(self.group.checkedId()).text()
        group_var = self._current_group_var()
        group_by = self.group_by_check.isChecked()
        group_text = ", Grouped by: {}".format(group_var) if group_by else ""
        return "Method: {}".format(method) + group_text
223
224
class StandardizeEditor(ScBaseEditor):
    """Editor for :class:`Standardize` with optional lower/upper clipping."""

    DEFAULT_LOWER_CLIP = False
    DEFAULT_UPPER_CLIP = False
    DEFAULT_LOWER_BOUND = -10
    DEFAULT_UPPER_BOUND = 10

    def __init__(self, parent=None, **kwargs):
        super().__init__(parent, **kwargs)
        # Cached bounds, mirrored into the spin boxes by the _set_* helpers.
        self._lower_bound = self.DEFAULT_LOWER_BOUND
        self._upper_bound = self.DEFAULT_UPPER_BOUND

        self.setLayout(QVBoxLayout())

        self.lower_check = QCheckBox("Lower Bound: ")
        self.lower_check.clicked.connect(self.edited)
        self.lower_spin = QSpinBox(
            minimum=-99, maximum=0, value=self._lower_bound
        )
        self.lower_spin.valueChanged[int].connect(self._set_lower_bound)
        self.lower_spin.editingFinished.connect(self.edited)

        self.upper_check = QCheckBox("Upper Bound: ")
        self.upper_check.clicked.connect(self.edited)
        # No explicit range: the QSpinBox default range applies here.
        self.upper_spin = QSpinBox(value=self._upper_bound)
        self.upper_spin.valueChanged[int].connect(self._set_upper_bound)
        self.upper_spin.editingFinished.connect(self.edited)

        form = QFormLayout()
        form.addRow(self.lower_check, self.lower_spin)
        form.addRow(self.upper_check, self.upper_spin)

        clipping_box = QGroupBox(title="Clipping", flat=True)
        clipping_box.setLayout(form)
        self.layout().addWidget(clipping_box)

    def _set_lower_bound(self, x):
        """Store lower bound *x*, update its spin box, emit ``changed``."""
        if self._lower_bound == x:
            return
        self._lower_bound = x
        self.lower_spin.setValue(x)
        self.changed.emit()

    def _set_upper_bound(self, x):
        """Store upper bound *x*, update its spin box, emit ``changed``."""
        if self._upper_bound == x:
            return
        self._upper_bound = x
        self.upper_spin.setValue(x)
        self.changed.emit()

    def setParameters(self, params):
        self.lower_check.setChecked(
            params.get("lower_clip", self.DEFAULT_LOWER_CLIP))
        self._set_lower_bound(params.get("lower", self.DEFAULT_LOWER_BOUND))
        self.upper_check.setChecked(
            params.get("upper_clip", self.DEFAULT_UPPER_CLIP))
        self._set_upper_bound(params.get("upper", self.DEFAULT_UPPER_BOUND))

    def parameters(self):
        return dict(lower_clip=self.lower_check.isChecked(),
                    lower=self._lower_bound,
                    upper_clip=self.upper_check.isChecked(),
                    upper=self._upper_bound)

    @staticmethod
    def createinstance(params):
        """Build :class:`Standardize`; an unclipped side gets ``None``."""
        editor = StandardizeEditor
        lower = None
        if params.get("lower_clip", editor.DEFAULT_LOWER_CLIP):
            lower = params.get("lower", editor.DEFAULT_LOWER_BOUND)
        upper = None
        if params.get("upper_clip", editor.DEFAULT_UPPER_CLIP):
            upper = params.get("upper", editor.DEFAULT_UPPER_BOUND)
        return Standardize(lower, upper)

    def __repr__(self):
        clips = [
            "{}: {}".format(label, spin.value())
            for label, check, spin in (
                ("Lower Bound", self.lower_check, self.lower_spin),
                ("Upper Bound", self.upper_check, self.upper_spin),
            )
            if check.isChecked()
        ]
        return ", ".join(clips) or "No Clipping"
301
302
class SelectGenesEditor(ScBaseEditor):
    """Editor for :class:`SelectMostVariableGenes`.

    The user chooses how many genes to keep, the ranking method, and whether
    statistics are computed over a number of gene groups. ``createinstance``
    passes ``n_groups=None`` when statistics are disabled.
    """

    DEFAULT_N_GENS = 1000
    DEFAULT_METHOD = SelectMostVariableGenes.Dispersion
    DEFAULT_COMPUTE_STATS = True
    DEFAULT_N_GROUPS = 20

    def __init__(self, parent=None, **kwargs):
        super().__init__(parent, **kwargs)
        self.setLayout(QVBoxLayout())
        # Cached values, mirrored into the spin boxes by the _set_* helpers.
        self._n_genes = self.DEFAULT_N_GENS
        self._n_groups = self.DEFAULT_N_GROUPS

        form = QFormLayout()
        self.n_genes_spin = QSpinBox(minimum=1, maximum=10 ** 6,
                                     value=self._n_genes)
        self.n_genes_spin.valueChanged[int].connect(self._set_n_genes)
        self.n_genes_spin.editingFinished.connect(self.edited)
        form.addRow("Number of genes:", self.n_genes_spin)
        self.layout().addLayout(form)

        self.group = QButtonGroup()
        self.group.buttonClicked.connect(self._on_button_clicked)
        # Button id is the Method enum value minus one; parameters() passes
        # checkedId() back through index_to_enum.
        for position, label in enumerate(("Dispersion", "Variance", "Mean")):
            button = QRadioButton(label)
            if position == 0:
                button.setChecked(True)
            method = index_to_enum(SelectMostVariableGenes.Method, position)
            self.group.addButton(button, method.value - 1)
            form.addRow(button)

        self.stats_check = QCheckBox("Compute statistics for",
                                     checked=self.DEFAULT_COMPUTE_STATS)
        self.stats_check.clicked.connect(self.edited)
        self.n_groups_spin = QSpinBox(minimum=1, value=self._n_groups)
        self.n_groups_spin.valueChanged[int].connect(self._set_n_groups)
        self.n_groups_spin.editingFinished.connect(self.edited)

        stats_row = QHBoxLayout()
        for widget in (self.stats_check, self.n_groups_spin,
                       QLabel("gene groups.")):
            stats_row.addWidget(widget)
        stats_row.addStretch()
        self.layout().addLayout(stats_row)

    def _set_n_genes(self, n):
        """Store gene count *n*, update its spin box, emit ``changed``."""
        if self._n_genes == n:
            return
        self._n_genes = n
        self.n_genes_spin.setValue(n)
        self.changed.emit()

    def _set_n_groups(self, n):
        """Store group count *n*, update its spin box, emit ``changed``."""
        if self._n_groups == n:
            return
        self._n_groups = n
        self.n_groups_spin.setValue(n)
        self.changed.emit()

    def _on_button_clicked(self):
        self.changed.emit()
        self.edited.emit()

    def setParameters(self, params):
        self._set_n_genes(params.get("n_genes", self.DEFAULT_N_GENS))
        method = params.get("method", self.DEFAULT_METHOD)
        position = enum_to_index(SelectMostVariableGenes.Method, method)
        self.group.buttons()[position].setChecked(True)
        self.stats_check.setChecked(
            params.get("compute_stats", self.DEFAULT_COMPUTE_STATS))
        self._set_n_groups(params.get("n_groups", self.DEFAULT_N_GROUPS))

    def parameters(self):
        checked = self.group.checkedId()
        return {
            "n_genes": self._n_genes,
            "method": index_to_enum(SelectMostVariableGenes.Method, checked),
            "compute_stats": self.stats_check.isChecked(),
            "n_groups": self._n_groups,
        }

    @staticmethod
    def createinstance(params):
        """Build :class:`SelectMostVariableGenes` from stored parameters."""
        editor = SelectGenesEditor
        method = params.get("method", editor.DEFAULT_METHOD)
        n_genes = params.get("n_genes", editor.DEFAULT_N_GENS)
        n_groups = None
        if params.get("compute_stats", editor.DEFAULT_COMPUTE_STATS):
            n_groups = params.get("n_groups", editor.DEFAULT_N_GROUPS)
        return SelectMostVariableGenes(method, n_genes, n_groups)

    def __repr__(self):
        method = self.group.button(self.group.checkedId()).text()
        pieces = ["Method: {}".format(method),
                  "Number of Genes: {}".format(self._n_genes)]
        if self.stats_check.isChecked():
            pieces.append("Number of Groups: {}".format(self._n_groups))
        return ", ".join(pieces)
395
396
# Preprocessors offered by the widget. Each spec is
# (title, qualified name, category, icon file, editor class); the title is
# used both as the action name and as its description title.
PREPROCESS_ACTIONS = [
    PreprocessAction(title, qualname, category,
                     Description(title, icon_path(icon)), editor)
    for title, qualname, category, icon, editor in (
        ("Logarithmic Scale", "preprocess.log_scale", "Value-Based",
         "LogarithmicScale.svg", LogarithmicScaleEditor),
        ("Binarize Expression", "preprocess.binarize", "Value-Based",
         "Binarize.svg", BinarizeEditor),
        ("Normalize Samples", "preprocess.normalize", "Row-Based",
         "Normalize.svg", NormalizeEditor),
        ("Standardize Genes", "preprocess.standardize", "Column-Based",
         "Standardize.svg", StandardizeEditor),
        ("Select Most Variable Genes", "preprocess.select_genes",
         "Column-Based", "SelectGenes.svg", SelectGenesEditor),
    )
]
430
431
class ScController(Controller):
    """Controller that passes its parent to each editor as ``master``."""

    def __init__(self, view, model=None, parent=None):
        super().__init__(view, model, parent)
        # Presumably the owning OWscPreprocess widget; editors such as
        # NormalizeEditor read its ``data`` and ``input_data_changed``.
        self._master = parent

    def createWidgetFor(self, index):
        """Instantiate the editor class described by *index*."""
        editor_class = index.data(DescriptionRole).viewclass
        return editor_class(master=self._master)
441
442
class OWscPreprocess(Orange.widgets.data.owpreprocess.OWPreprocess):
    """Preprocess widget specialised for single-cell expression data.

    Rejects data with discrete attributes, replaces missing values with 0,
    and keeps the Normalize editor's grouping variable in a context setting
    (``group_var``) instead of the stored preprocessor parameters.
    """
    name = "Single Cell Preprocess"
    description = "Preprocess Single Cell data set"
    icon = "icons/SingleCellPreprocess.svg"
    priority = 220

    class Inputs:
        data = Input("Data", Orange.data.Table)

    class Outputs:
        preprocessed_data = Output("Preprocessed Data", Orange.data.Table)

    class Error(Orange.widgets.data.owpreprocess.OWPreprocess.Error):
        discrete_attributes = Msg("Data with discrete attributes "
                                  "can not be preprocessed.")

    class Warning(Orange.widgets.data.owpreprocess.OWPreprocess.Warning):
        missing_values = Msg("Missing values have been replaced with 0.")

    PREPROCESSORS = PREPROCESS_ACTIONS
    DEFAULT_PP = {"preprocessors": [("preprocess.normalize", {}),
                                    ("preprocess.log_scale", {}),
                                    ("preprocess.select_genes", {}),
                                    ("preprocess.standardize", {})]}
    CONTROLLER = ScController
    storedsettings = Setting(DEFAULT_PP)
    group_var = ContextSetting(None)
    settingsHandler = DomainContextHandler()

    # Emitted after new input data has been set and checked; editors that
    # depend on the data domain (NormalizeEditor) refresh on it.
    input_data_changed = Signal()

    @Inputs.data
    @check_sql_input
    def set_data(self, data=None):
        """Set the input dataset."""
        self.closeContext()
        self.data = data
        self.openContext(data)
        self.check_data()
        self.input_data_changed.emit()
        self.load_group_var()

    def check_data(self):
        """Validate ``self.data`` and clean it if needed.

        Data with discrete attributes is rejected: ``self.data`` becomes
        None and an error is shown. NaNs in ``X`` are replaced by 0 on a
        copy of the table, so the input table itself is not modified.
        """
        self.Error.discrete_attributes.clear()
        if self.data and self.data.domain.has_discrete_attributes():
            self.data = None
            self.Error.discrete_attributes()

        self.Warning.missing_values.clear()
        if self.data and np.isnan(self.data.X).any():
            # The input table is shared with the upstream widget; work on
            # a copy instead of overwriting its X in place.
            self.data = self.data.copy()
            self.data.X = np.nan_to_num(self.data.X)
            self.Warning.missing_values()

    def load_group_var(self):
        """Write the context ``group_var`` into every item that has one."""
        for index in range(self.preprocessormodel.rowCount()):
            item = self.preprocessormodel.item(index)
            params = item.data(ParametersRole)
            if params and "group_var" in params:
                # Store a new dict instead of mutating the one held by the
                # model (save() copies the same way).
                params = dict(params)
                params["group_var"] = self.group_var
                item.setData(params, ParametersRole)

    def save(self, model):
        """Serialise *model* into ``{"name": "", "preprocessors": [...]}``.

        A non-None ``group_var`` is moved into the ``group_var`` context
        setting and replaced by None in the stored parameters.
        """
        d = {"name": ""}
        preprocessors = []
        for i in range(model.rowCount()):
            item = model.item(i)
            pp_def = item.data(DescriptionRole)
            params = item.data(ParametersRole)
            group_var = params.get("group_var")
            if group_var is not None:
                self.group_var = group_var
                params = dict(params)
                params["group_var"] = None
            preprocessors.append((pp_def.qualname, params))

        d["preprocessors"] = preprocessors
        return d

    def apply(self):
        """Store settings, run the pipeline on the input and send the result.

        ValueError and ZeroDivisionError raised by the pipeline are shown
        as a widget error and None is sent.
        """
        self.storeSpecificSettings()
        preprocessor = self.buildpreproc()
        # Clear the previous error unconditionally, so it does not linger
        # after the input is removed.
        self.error()
        data = None
        if self.data is not None:
            try:
                data = preprocessor(self.data)
            except (ValueError, ZeroDivisionError) as e:
                self.error(str(e))
        self.Outputs.preprocessed_data.send(data)
532
533
def main(args=None):
    """Show the widget standalone on the ``iris`` dataset for manual testing.

    Settings are saved and the widget is torn down after the Qt event loop
    exits; the event loop's exit code is not returned.
    """
    from Orange.data import Table

    app = QApplication(args if args else [])
    widget = OWscPreprocess()
    widget.set_data(Table("iris"))
    widget.show()
    widget.raise_()
    app.exec_()
    widget.saveSettings()
    widget.onDeleteWidget()


if __name__ == "__main__":
    sys.exit(main())
549