1""" Qt component for Gene sets """
2from typing import Union
3from collections import defaultdict
4
5import numpy as np
6
7from AnyQt.QtCore import Qt
8from AnyQt.QtWidgets import QWidget, QGroupBox, QTreeView, QTreeWidget, QTreeWidgetItem, QTreeWidgetItemIterator
9
10from orangecontrib.bioinformatics.geneset import GeneSet, GeneSets, list_all, load_gene_sets
11
12# TODO: better handle stored selection
13# TODO: Don't use hardcoded 'Custom sets', use table name if available
14
15
class GeneSetsSelection(QWidget):
    """Checkable tree for browsing and selecting gene set hierarchies.

    The tree widget is placed into ``box``. Items that have children are
    category nodes: their ``hierarchy`` attribute is the plain string key and
    they are never reported or (de)selected individually. Every leaf item
    stores its full hierarchy path from the root as a tuple, e.g.
    ``('GO', 'biological_process')`` or ``('Custom sets',)``.
    """

    def __init__(self, box: Union[QGroupBox, QWidget], parent: QWidget, settings_var: str, **kwargs) -> None:
        """
        :param box: container whose layout receives the hierarchy tree widget
        :param parent: object that owns the stored-selection setting
        :param settings_var: name of the attribute on ``parent`` holding the stored
            selection (presumably a collection of hierarchy tuples -- it is
            compared against leaf hierarchies with ``in``)
        :param kwargs: forwarded to ``QWidget.__init__``
        """
        super().__init__(**kwargs)

        # NOTE: this instance attribute shadows QWidget.parent(); kept because
        # external code may read ``self.parent``.
        self.parent = parent
        self.stored_selection = settings_var
        # gene sets object
        self.gs_object = GeneSets()  # type: GeneSets

        self.hierarchy_tree_widget = QTreeWidget(self)
        self.hierarchy_tree_widget.setHeaderHidden(True)
        self.hierarchy_tree_widget.setEditTriggers(QTreeView.NoEditTriggers)
        box.layout().addWidget(self.hierarchy_tree_widget)

        # hierarchy under which the custom gene sets are stored (None until added)
        self.custom_set_hier = None
        # hierarchies checked when nothing from the stored selection is checked
        self.default_selection = [
            ('GO', 'molecular_function'),
            ('GO', 'biological_process'),
            ('GO', 'cellular_component'),
        ]

    def clear_custom_sets(self):
        """Remove the gene sets stored under the current custom hierarchy.

        Only ``gs_object`` is modified; the tree is not rebuilt here.
        """
        self.gs_object.delete_sets_by_hierarchy(self.custom_set_hier)

    def add_custom_sets(
        self,
        gene_sets_names: np.ndarray,
        gene_names: np.ndarray,
        hierarchy_title=None,
        select_customs_flag: bool = False,
    ) -> None:
        """Create custom gene sets by grouping genes by their set name.

        Custom sets added by a previous call are removed first, also when they
        were stored under a different hierarchy title. Afterwards the tree is
        rebuilt.

        :param gene_sets_names: gene set name of each entry, parallel to ``gene_names``
        :param gene_names: gene of each entry
        :param hierarchy_title: hierarchy the custom sets are stored under
        :param select_customs_flag: if true, check only the custom sets in the tree;
            otherwise restore the stored selection
        """
        previous_hier = self.custom_set_hier
        if previous_hier is not None and previous_hier != hierarchy_title:
            # drop sets from the previous custom hierarchy so they do not linger
            # in the tree next to the new ones
            self.clear_custom_sets()

        self.custom_set_hier = hierarchy_title
        self.clear_custom_sets()

        genes_by_set = defaultdict(list)
        for set_name, gene_name in zip(gene_sets_names, gene_names):
            genes_by_set[set_name].append(gene_name)

        # only ask for the organism when there is something to build
        organism = self.gs_object.common_org() if genes_by_set else None
        g_sets = [
            GeneSet(
                gs_id=set_name,
                hierarchy=self.custom_set_hier,
                organism=organism,
                name=set_name,
                genes=set(genes),
            )
            for set_name, genes in genes_by_set.items()
        ]

        self.gs_object.update(g_sets)
        self.update_gs_hierarchy(select_customs_flag=select_customs_flag)

    def load_gene_sets(self, tax_id: str) -> None:
        """Replace all gene sets with those available for organism ``tax_id``.

        The tree is rebuilt from the available hierarchies, then every
        hierarchy's sets are loaded and the stored selection is restored.
        """
        self.gs_object = GeneSets()
        self.clear()

        hierarchies = list_all(organism=tax_id)
        self.set_hierarchy_model(self.hierarchy_tree_widget, self.hierarchy_tree(hierarchies))

        for hierarchy in hierarchies:
            # module-level ``load_gene_sets`` function, not this method
            self.gs_object.update(list(load_gene_sets(hierarchy, tax_id)))

        self.set_selected_hierarchies()

    def clear_gene_sets(self):
        """Discard all gene sets (the tree is left untouched)."""
        self.gs_object = GeneSets()

    def clear(self):
        """Remove all items from the hierarchy tree widget."""
        self.hierarchy_tree_widget.clear()

    def update_gs_hierarchy(self, select_customs_flag=False):
        """Rebuild the tree from the hierarchies currently in ``gs_object``.

        :param select_customs_flag: if true, check only the custom sets;
            otherwise restore the stored selection
        """
        self.clear()
        self.set_hierarchy_model(self.hierarchy_tree_widget, self.hierarchy_tree(self.gs_object.hierarchies()))
        if select_customs_flag:
            self.set_custom_sets()
        else:
            self.set_selected_hierarchies()

    def set_hierarchy_model(self, tree_widget, sets):
        """Add items for the nested ``sets`` mapping under ``tree_widget``.

        :param tree_widget: QTreeWidget or QTreeWidgetItem to add items to
        :param sets: nested mapping as built by ``hierarchy_tree``; empty
            mappings mark leaves
        """

        def beautify_displayed_text(text):
            # 'biological_process' -> 'Biological Process'; keys without '_' shown verbatim
            return text.replace('_', ' ').title() if '_' in text else text

        for key, children in sets.items():
            item = QTreeWidgetItem(tree_widget, [beautify_displayed_text(key)])
            # keep at most the checkable, selectable and enabled flags
            item.setFlags(item.flags() & (Qt.ItemIsUserCheckable | Qt.ItemIsSelectable | Qt.ItemIsEnabled))
            item.setExpanded(True)
            # category nodes keep the plain string key
            item.hierarchy = key

            if children:
                item.setFlags(item.flags() | Qt.ItemIsTristate)
                self.set_hierarchy_model(item, children)
            else:
                item.hierarchy = self._leaf_hierarchy(item, key)

    @staticmethod
    def _leaf_hierarchy(item, key):
        """Return the full hierarchy tuple of a leaf: its ancestors' keys followed by ``key``."""
        path = [key]
        ancestor = item.parent()
        while ancestor is not None:
            # ancestors are category nodes, so their ``hierarchy`` is their string key
            path.append(ancestor.hierarchy)
            ancestor = ancestor.parent()
        return tuple(reversed(path))

    @staticmethod
    def _is_category(item):
        """True for category nodes, whose ``hierarchy`` is a string key rather than a tuple."""
        return isinstance(item.hierarchy, str)

    def _tree_items(self, flags=QTreeWidgetItemIterator.All):
        """Yield the hierarchy tree's items in tree order, filtered by iterator ``flags``."""
        iterator = QTreeWidgetItemIterator(self.hierarchy_tree_widget, flags)
        while iterator.value():
            yield iterator.value()
            iterator += 1

    def get_hierarchies(self, *, only_selected=False, **_ignored):
        """Return the leaf hierarchies in the tree, in tree order.

        :param only_selected: if true, return only checked leaves that are not disabled
        :return: list of hierarchy tuples
        """
        flags = QTreeWidgetItemIterator.Checked if only_selected else QTreeWidgetItemIterator.All
        hierarchies = []
        for item in self._tree_items(flags):
            # category nodes only group sub-hierarchies; never report them
            if self._is_category(item):
                continue
            if only_selected and item.isDisabled():
                continue
            hierarchies.append(item.hierarchy)
        return hierarchies

    def set_selected_hierarchies(self):
        """Check leaves found in the stored selection and uncheck all other leaves.

        If that leaves nothing selected, the hierarchies in ``default_selection``
        are checked instead.
        """
        # read once so every item is compared against the same selection
        stored = getattr(self.parent, self.stored_selection)
        defaults = []

        for item in self._tree_items():
            if not self._is_category(item):
                item.setCheckState(0, Qt.Checked if item.hierarchy in stored else Qt.Unchecked)
            if item.hierarchy in self.default_selection:
                defaults.append(item)

        if not self.get_hierarchies(only_selected=True):
            for item in defaults:
                item.setCheckState(0, Qt.Checked)

    def set_custom_sets(self):
        """Check only the leaf holding the custom gene sets; uncheck every other leaf."""
        for item in self._tree_items():
            if self._is_category(item):
                continue
            is_custom = item.hierarchy == self.custom_set_hier
            item.setCheckState(0, Qt.Checked if is_custom else Qt.Unchecked)

    @staticmethod
    def hierarchy_tree(gene_sets):
        """Build a nested mapping from an iterable of hierarchy tuples.

        For example ``[('GO', 'biological_process'), ('KEGG',)]`` becomes
        ``{'GO': {'biological_process': {}}, 'KEGG': {}}`` (as nested
        defaultdicts); leaves are empty mappings.
        """

        def tree():
            return defaultdict(tree)

        collection = tree()
        for hierarchy in gene_sets:
            node = collection
            for level in hierarchy:
                # indexing a defaultdict creates the missing child node
                node = node[level]
        return collection
200