1import enum
2from collections import defaultdict
3from itertools import islice
4from typing import (
5    Iterable, Mapping, Any, TypeVar, NamedTuple, Sequence, Optional,
6    Union, Tuple, List, Callable
7)
8
9import numpy as np
10import scipy.sparse as sp
11
12from AnyQt.QtWidgets import (
13    QGraphicsScene, QGraphicsView, QFormLayout, QComboBox, QGroupBox,
14    QMenu, QAction, QSizePolicy
15)
16from AnyQt.QtGui import QStandardItemModel, QStandardItem, QFont, QKeySequence
17from AnyQt.QtCore import Qt, QSize, QRectF, QObject
18
19from orangewidget.utils.combobox import ComboBox, ComboBoxSearch
20from Orange.data import Domain, Table, Variable, DiscreteVariable, \
21    ContinuousVariable
22from Orange.data.sql.table import SqlTable
23import Orange.distance
24
25from Orange.clustering import hierarchical, kmeans
26from Orange.widgets.utils import colorpalettes, apply_all, enum_get, itemmodels
27from Orange.widgets.utils.itemmodels import DomainModel
28from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView
29from Orange.widgets.utils.graphicsview import GraphicsWidgetView
30from Orange.widgets.utils.colorpalettes import Palette
31
32from Orange.widgets.utils.annotated_data import (create_annotated_table,
33                                                 ANNOTATED_DATA_SIGNAL_NAME)
34from Orange.widgets import widget, gui, settings
35from Orange.widgets.widget import Msg, Input, Output
36
37from Orange.widgets.data.oweditdomain import table_column_data
38from Orange.widgets.visualize.utils.heatmap import HeatmapGridWidget, \
39    ColorMap, CategoricalColorMap, GradientColorMap
40from Orange.widgets.utils.colorgradientselection import ColorGradientSelection
41from Orange.widgets.utils.widgetpreview import WidgetPreview
42
43
# The module exports no names through ``from <module> import *``.
__all__ = []
45
46
def kmeans_compress(X, k=50):
    """
    Fit a k-means model with `k` clusters on `X`.

    The learner is created with a fixed seed (``random_state=42``) and
    ``n_init=5``; the value returned by its `get_model` is passed back
    to the caller unchanged.
    """
    learner = kmeans.KMeans(n_clusters=k, n_init=5, random_state=42)
    model = learner.get_model(X)
    return model
50
51
def split_domain(domain: Domain, split_label: str):
    """
    Group the attributes of `domain` by their `split_label` annotation.

    Each variable in ``domain.attributes`` is looked up by
    ``var.attributes.get(split_label)``. The result is a list of
    ``(value, [variables])`` pairs in order of first appearance. Variables
    without the annotation are collected in a trailing ``("N/A", [...])``
    group, which is only present when there is at least one such variable.
    """
    by_value = defaultdict(list)
    for attr in domain.attributes:
        by_value[attr.attributes.get(split_label)].append(attr)

    # Variables lacking the label are keyed by None; move them to the end.
    unlabeled = by_value.pop(None, None)
    groups = list(by_value.items())
    if unlabeled is not None:
        groups.append(("N/A", unlabeled))
    return groups
64
65
def cbselect(cb: QComboBox, value, role: Qt.ItemDataRole = Qt.EditRole) -> None:
    """
    Make the item of `cb` whose `role` data matches `value` the current one.

    Parameters
    ----------
    cb: QComboBox
        The combo box to update.
    value: Any
        The value to search for.
    role: Qt.ItemDataRole
        The data role in the combo box model to match `value` against.

    Notes
    -----
    The index returned by ``cb.findData`` is passed to ``setCurrentIndex``
    as is, including the "not found" result.
    """
    index = cb.findData(value, role)
    cb.setCurrentIndex(index)
78
79
@enum.unique
class Clustering(enum.IntEnum):
    """Clustering mode selectable for heatmap rows and columns."""
    #: Items are not clustered
    None_ = 0
    #: Items are ordered by hierarchical clustering
    Clustering = 1
    #: Hierarchical clustering followed by optimal leaf ordering
    OrderedClustering = 2
87
88
#: Item data role under which a `Clustering` member is stored
ClusteringRole = Qt.UserRole + 13
#: Item data for clustering method selection models
ClusteringModelData = [
    {Qt.DisplayRole: text, Qt.ToolTipRole: tooltip, ClusteringRole: method}
    for text, tooltip, method in [
        ("None", "No clustering", Clustering.None_),
        ("Clustering", "Apply hierarchical clustering",
         Clustering.Clustering),
        ("Clustering (opt. ordering)",
         "Apply hierarchical clustering with optimal leaf ordering.",
         Clustering.OrderedClustering),
    ]
]
107
#: Item data for the column labels position selection model; the position
#: flag(s) of each entry are stored under Qt.UserRole
ColumnLabelsPosData = [
    {
        Qt.DisplayRole: "None",
        Qt.UserRole: HeatmapGridWidget.NoPosition,
    }, {
        Qt.DisplayRole: "Top",
        Qt.UserRole: HeatmapGridWidget.PositionTop,
    }, {
        Qt.DisplayRole: "Bottom",
        Qt.UserRole: HeatmapGridWidget.PositionBottom,
    }, {
        Qt.DisplayRole: "Top and Bottom",
        Qt.UserRole: (HeatmapGridWidget.PositionTop |
                      HeatmapGridWidget.PositionBottom),
    }
]
118
119
def create_list_model(
        items: Iterable[Mapping[Qt.ItemDataRole, Any]],
        parent: Optional[QObject] = None,
) -> QStandardItemModel:
    """
    Create a list model from an iterable of item data mappings.

    One single-item row is appended per mapping in `items`; every
    ``role -> value`` pair of the mapping is stored on that item with
    ``setData``. `parent` is passed to the `QStandardItemModel` constructor.
    """
    model = QStandardItemModel(parent)
    for roles in items:
        row_item = QStandardItem()
        for role in roles:
            row_item.setData(roles[role], role)
        model.appendRow([row_item])
    return model
132
133
class OWHeatMap(widget.OWWidget):
    # Widget metadata
    name = "Heat Map"
    description = "Plot a data matrix heatmap."
    icon = "icons/Heatmap.svg"
    priority = 260
    keywords = []

    class Inputs:
        # The table to display
        data = Input("Data", Table)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    # Version number of the stored settings
    settings_version = 3

    settingsHandler = settings.DomainContextHandler()

    # Disable clustering for inputs bigger than this
    MaxClustering = 25000
    # Disable cluster leaf ordering for inputs bigger than this
    MaxOrderedClustering = 1000

    # Color map thresholds and center (passed to ColorGradientSelection and
    # GradientColorMap)
    threshold_low = settings.Setting(0.0)
    threshold_high = settings.Setting(1.0)
    color_center = settings.Setting(0)

    # Merge rows by k-means, and the requested number of clusters
    merge_kmeans = settings.Setting(False)
    merge_kmeans_k = settings.Setting(50)

    # Display column with averages
    averages: bool = settings.Setting(True)
    # Display legend
    legend: bool = settings.Setting(True)
    # Annotations
    #: text row annotation (row names)
    annotation_var = settings.ContextSetting(None)
    #: color row annotation
    annotation_color_var = settings.ContextSetting(None)
    #: (type name, variable name) of the column color annotation variable;
    #: written by `_save_state_for_serialization`
    column_annotation_color_key: Optional[Tuple[str, str]] = settings.ContextSetting(None)

    # Discrete variable used to split that data/heatmaps (vertically)
    split_by_var = settings.ContextSetting(None)
    # Split heatmap columns by 'key' (horizontal)
    split_columns_key: Optional[Tuple[str, str]] = settings.ContextSetting(None)
    # Selected row/column clustering method (name)
    col_clustering_method: str = settings.Setting(Clustering.None_.name)
    row_clustering_method: str = settings.Setting(Clustering.None_.name)

    # Key of the selected continuous palette
    palette_name = settings.Setting(colorpalettes.DefaultContinuousPaletteName)
    # Index into `ColumnLabelsPosData`
    column_label_pos: int = settings.Setting(1)
    # Selected rows, stored with the workflow only (schema_only)
    selected_rows: List[int] = settings.Setting(None, schema_only=True)

    auto_commit = settings.Setting(True)

    # NOTE(review): presumably names the attribute used by the base widget
    # for saving/reporting the plot -- confirm against OWWidget.
    graph_name = "scene"

    class Information(widget.OWWidget.Information):
        sampled = Msg("Data has been sampled")
        discrete_ignored = Msg("{} categorical feature{} ignored")
        row_clust = Msg("{}")
        col_clust = Msg("{}")
        sparse_densified = Msg("Showing this data may require a lot of memory")

    class Error(widget.OWWidget.Error):
        no_continuous = Msg("No numeric features")
        not_enough_features = Msg("Not enough features for column clustering")
        not_enough_instances = Msg("Not enough instances for clustering")
        not_enough_instances_k_means = Msg(
            "Not enough instances for k-means merging")
        not_enough_memory = Msg("Not enough memory to show this data")

    class Warning(widget.OWWidget.Warning):
        empty_clusters = Msg("Empty clusters were removed")

    UserAdviceMessages = [
        widget.Message(
            "For data with a meaningful mid-point, "
            "choose one of diverging palettes.",
            "diverging_palette")]
214
215    def __init__(self):
216        super().__init__()
217        self.__pending_selection = self.selected_rows
218
219        # A kingdom for a save_state/restore_state
220        self.col_clustering = enum_get(
221            Clustering, self.col_clustering_method, Clustering.None_)
222        self.row_clustering = enum_get(
223            Clustering, self.row_clustering_method, Clustering.None_)
224
225        self.settingsAboutToBePacked.connect(self._save_state_for_serialization)
226        self.keep_aspect = False
227
228        #: The original data with all features (retained to
229        #: preserve the domain on the output)
230        self.input_data = None
231        #: The effective data striped of discrete features, and often
232        #: merged using k-means
233        self.data = None
234        self.effective_data = None
235        #: Source of column annotations (derived from self.data)
236        self.col_annot_data: Optional[Table] = None
237        #: kmeans model used to merge rows of input_data
238        self.kmeans_model = None
239        #: merge indices derived from kmeans
240        #: a list (len==k) of int ndarray where the i-th item contains
241        #: the indices which merge the input_data into the heatmap row i
242        self.merge_indices = None
243        self.parts: Optional[Parts] = None
244        self.__rows_cache = {}
245        self.__columns_cache = {}
246
247        # GUI definition
248        colorbox = gui.vBox(self.controlArea, "Color")
249
250        self.color_map_widget = cmw = ColorGradientSelection(
251            thresholds=(self.threshold_low, self.threshold_high),
252            center=self.color_center
253        )
254        model = itemmodels.ContinuousPalettesModel(parent=self)
255        cmw.setModel(model)
256        idx = cmw.findData(self.palette_name, model.KeyRole)
257        if idx != -1:
258            cmw.setCurrentIndex(idx)
259
260        cmw.activated.connect(self.update_color_schema)
261
262        def _set_thresholds(low, high):
263            self.threshold_low, self.threshold_high = low, high
264            self.update_color_schema()
265        cmw.thresholdsChanged.connect(_set_thresholds)
266
267        def _set_centering(center):
268            self.color_center = center
269            self.update_color_schema()
270        cmw.centerChanged.connect(_set_centering)
271
272        colorbox.layout().addWidget(self.color_map_widget)
273
274        mergebox = gui.vBox(self.controlArea, "Merge",)
275        gui.checkBox(mergebox, self, "merge_kmeans", "Merge by k-means",
276                     callback=self.__update_row_clustering)
277        ibox = gui.indentedBox(mergebox)
278        gui.spin(ibox, self, "merge_kmeans_k", minv=5, maxv=500,
279                 label="Clusters:", keyboardTracking=False,
280                 callbackOnReturn=True, callback=self.update_merge)
281
282        cluster_box = gui.vBox(self.controlArea, "Clustering")
283        # Row clustering
284        self.row_cluster_cb = cb = ComboBox()
285        cb.setModel(create_list_model(ClusteringModelData, self))
286        cbselect(cb, self.row_clustering, ClusteringRole)
287        self.connect_control(
288            "row_clustering",
289            lambda value, cb=cb: cbselect(cb, value, ClusteringRole)
290        )
291        @cb.activated.connect
292        def _(idx, cb=cb):
293            self.set_row_clustering(cb.itemData(idx, ClusteringRole))
294
295        # Column clustering
296        self.col_cluster_cb = cb = ComboBox()
297        cb.setModel(create_list_model(ClusteringModelData, self))
298        cbselect(cb, self.col_clustering, ClusteringRole)
299        self.connect_control(
300            "col_clustering",
301            lambda value, cb=cb: cbselect(cb, value, ClusteringRole)
302        )
303        @cb.activated.connect
304        def _(idx, cb=cb):
305            self.set_col_clustering(cb.itemData(idx, ClusteringRole))
306
307        form = QFormLayout(
308            labelAlignment=Qt.AlignLeft, formAlignment=Qt.AlignLeft,
309            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
310        )
311        form.addRow("Rows:", self.row_cluster_cb)
312        form.addRow("Columns:", self.col_cluster_cb)
313        cluster_box.layout().addLayout(form)
314        box = gui.vBox(self.controlArea, "Split By")
315        form = QFormLayout(
316            formAlignment=Qt.AlignLeft, labelAlignment=Qt.AlignLeft,
317            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow,
318        )
319        box.layout().addLayout(form)
320
321        self.row_split_model = DomainModel(
322            placeholder="(None)",
323            valid_types=(Orange.data.DiscreteVariable,),
324            parent=self,
325        )
326        self.row_split_cb = cb = ComboBoxSearch(
327            enabled=not self.merge_kmeans,
328            sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
329            minimumContentsLength=14,
330            toolTip="Split the heatmap vertically by a categorical column"
331        )
332        self.row_split_cb.setModel(self.row_split_model)
333        self.connect_control(
334            "split_by_var", lambda value, cb=cb: cbselect(cb, value)
335        )
336        self.connect_control(
337            "merge_kmeans", self.row_split_cb.setDisabled
338        )
339        self.split_by_var = None
340
341        self.row_split_cb.activated.connect(
342            self.__on_split_rows_activated
343        )
344        self.col_split_model = DomainModel(
345            placeholder="(None)",
346            order=DomainModel.MIXED,
347            valid_types=(Orange.data.DiscreteVariable,),
348            parent=self,
349        )
350        self.col_split_cb = cb = ComboBoxSearch(
351            sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
352            minimumContentsLength=14,
353            toolTip="Split the heatmap horizontally by column annotation"
354        )
355        self.col_split_cb.setModel(self.col_split_model)
356        self.connect_control(
357            "split_columns_var", lambda value, cb=cb: cbselect(cb, value)
358        )
359        self.split_columns_var = None
360        self.col_split_cb.activated.connect(self.__on_split_cols_activated)
361        form.addRow("Rows:", self.row_split_cb)
362        form.addRow("Columns:", self.col_split_cb)
363
364        box = gui.vBox(self.controlArea, 'Annotation && Legends')
365
366        gui.checkBox(box, self, 'legend', 'Show legend',
367                     callback=self.update_legend)
368
369        gui.checkBox(box, self, 'averages', 'Stripes with averages',
370                     callback=self.update_averages_stripe)
371        gui.separator(box)
372        annotbox = QGroupBox("Row Annotations")
373        form = QFormLayout(
374            annotbox,
375            formAlignment=Qt.AlignLeft,
376            labelAlignment=Qt.AlignLeft,
377            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow
378        )
379        self.annotation_model = DomainModel(placeholder="(None)")
380        self.annotation_text_cb = ComboBoxSearch(
381            minimumContentsLength=12,
382            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon
383        )
384        self.annotation_text_cb.setModel(self.annotation_model)
385        self.annotation_text_cb.activated.connect(self.set_annotation_var)
386        self.connect_control("annotation_var", self.annotation_var_changed)
387
388        self.row_side_color_model = DomainModel(
389            order=(DomainModel.CLASSES, DomainModel.Separator,
390                   DomainModel.METAS),
391            placeholder="(None)", valid_types=DomainModel.PRIMITIVE,
392            flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled,
393            parent=self,
394        )
395        self.row_side_color_cb = ComboBoxSearch(
396            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
397            minimumContentsLength=12
398        )
399        self.row_side_color_cb.setModel(self.row_side_color_model)
400        self.row_side_color_cb.activated.connect(self.set_annotation_color_var)
401        self.connect_control("annotation_color_var", self.annotation_color_var_changed)
402        form.addRow("Text", self.annotation_text_cb)
403        form.addRow("Color", self.row_side_color_cb)
404        box.layout().addWidget(annotbox)
405        annotbox = QGroupBox("Column annotations")
406        form = QFormLayout(
407            annotbox,
408            formAlignment=Qt.AlignLeft,
409            labelAlignment=Qt.AlignLeft,
410            fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow
411        )
412        self.col_side_color_model = DomainModel(
413            placeholder="(None)",
414            valid_types=(DiscreteVariable, ContinuousVariable),
415            parent=self
416        )
417        self.col_side_color_cb = cb = ComboBoxSearch(
418            sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon,
419            minimumContentsLength=12
420        )
421        self.col_side_color_cb.setModel(self.col_side_color_model)
422        self.connect_control(
423            "column_annotation_color_var", self.column_annotation_color_var_changed,
424        )
425        self.column_annotation_color_var = None
426        self.col_side_color_cb.activated.connect(
427            self.__set_column_annotation_color_var_index)
428
429        cb = gui.comboBox(
430            None, self, "column_label_pos",
431            callback=self.update_column_annotations)
432        cb.setModel(create_list_model(ColumnLabelsPosData, parent=self))
433        cb.setCurrentIndex(self.column_label_pos)
434        form.addRow("Position", cb)
435        form.addRow("Color", self.col_side_color_cb)
436        box.layout().addWidget(annotbox)
437
438        gui.checkBox(self.controlArea, self, "keep_aspect",
439                     "Keep aspect ratio", box="Resize",
440                     callback=self.__aspect_mode_changed)
441
442        gui.rubber(self.controlArea)
443
444        gui.auto_send(self.buttonsArea, self, "auto_commit")
445
446        # Scene with heatmap
447        class HeatmapScene(QGraphicsScene):
448            widget: Optional[HeatmapGridWidget] = None
449
450        self.scene = self.scene = HeatmapScene(parent=self)
451        self.view = GraphicsView(
452            self.scene,
453            verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
454            horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
455            viewportUpdateMode=QGraphicsView.FullViewportUpdate,
456            widgetResizable=True,
457        )
458        self.view.setContextMenuPolicy(Qt.CustomContextMenu)
459        self.view.customContextMenuRequested.connect(
460            self._on_view_context_menu
461        )
462        self.mainArea.layout().addWidget(self.view)
463        self.selected_rows = []
464        self.__font_inc = QAction(
465            "Increase Font", self, shortcut=QKeySequence("ctrl+>"))
466        self.__font_dec = QAction(
467            "Decrease Font", self, shortcut=QKeySequence("ctrl+<"))
468        self.__font_inc.triggered.connect(lambda: self.__adjust_font_size(1))
469        self.__font_dec.triggered.connect(lambda: self.__adjust_font_size(-1))
470        if hasattr(QAction, "setShortcutVisibleInContextMenu"):
471            apply_all(
472                [self.__font_inc, self.__font_dec],
473                lambda a: a.setShortcutVisibleInContextMenu(True)
474            )
475        self.addActions([self.__font_inc, self.__font_dec])
476
477    def _save_state_for_serialization(self):
478        def desc(var: Optional[Variable]) -> Optional[Tuple[str, str]]:
479            if var is not None:
480                return type(var).__name__, var.name
481            else:
482                return None
483
484        self.col_clustering_method = self.col_clustering.name
485        self.row_clustering_method = self.row_clustering.name
486
487        self.column_annotation_color_key = desc(self.column_annotation_color_var)
488        self.split_columns_key = desc(self.split_columns_var)
489
490    @property
491    def center_palette(self):
492        palette = self.color_map_widget.currentData()
493        return bool(palette.flags & palette.Diverging)
494
495    @property
496    def _column_label_pos(self) -> HeatmapGridWidget.Position:
497        return ColumnLabelsPosData[self.column_label_pos][Qt.UserRole]
498
499    def annotation_color_var_changed(self, value):
500        cbselect(self.row_side_color_cb, value, Qt.EditRole)
501
502    def annotation_var_changed(self, value):
503        cbselect(self.annotation_text_cb, value, Qt.EditRole)
504
505    def set_row_clustering(self, method: Clustering) -> None:
506        assert isinstance(method, Clustering)
507        if self.row_clustering != method:
508            self.row_clustering = method
509            cbselect(self.row_cluster_cb, method, ClusteringRole)
510            self.__update_row_clustering()
511
512    def set_col_clustering(self, method: Clustering) -> None:
513        assert isinstance(method, Clustering)
514        if self.col_clustering != method:
515            self.col_clustering = method
516            cbselect(self.col_cluster_cb, method, ClusteringRole)
517            self.__update_column_clustering()
518
519    def sizeHint(self) -> QSize:
520        return super().sizeHint().expandedTo(QSize(900, 700))
521
522    def color_palette(self):
523        return self.color_map_widget.currentData().lookup_table()
524
525    def color_map(self) -> GradientColorMap:
526        return GradientColorMap(
527            self.color_palette(), (self.threshold_low, self.threshold_high),
528            self.color_map_widget.center() if self.center_palette else None
529        )
530
531    def clear(self):
532        self.data = None
533        self.input_data = None
534        self.effective_data = None
535        self.kmeans_model = None
536        self.merge_indices = None
537        self.annotation_model.set_domain(None)
538        self.annotation_var = None
539        self.row_side_color_model.set_domain(None)
540        self.col_side_color_model.set_domain(None)
541        self.annotation_color_var = None
542        self.column_annotation_color_var = None
543        self.row_split_model.set_domain(None)
544        self.col_split_model.set_domain(None)
545        self.split_by_var = None
546        self.split_columns_var = None
547        self.parts = None
548        self.clear_scene()
549        self.selected_rows = []
550        self.__columns_cache.clear()
551        self.__rows_cache.clear()
552        self.__update_clustering_enable_state(None)
553
554    def clear_scene(self):
555        if self.scene.widget is not None:
556            self.scene.widget.layoutDidActivate.disconnect(
557                self.__on_layout_activate
558            )
559            self.scene.widget.selectionFinished.disconnect(
560                self.on_selection_finished
561            )
562        self.scene.widget = None
563        self.scene.clear()
564
565        self.view.setSceneRect(QRectF())
566        self.view.setHeaderSceneRect(QRectF())
567        self.view.setFooterSceneRect(QRectF())
568
569    @Inputs.data
570    def set_dataset(self, data=None):
571        """Set the input dataset to display."""
572        self.closeContext()
573        self.clear()
574        self.clear_messages()
575
576        if isinstance(data, SqlTable):
577            if data.approx_len() < 4000:
578                data = Table(data)
579            else:
580                self.Information.sampled()
581                data_sample = data.sample_time(1, no_cache=True)
582                data_sample.download_data(2000, partial=True)
583                data = Table(data_sample)
584
585        if data is not None and not len(data):
586            data = None
587
588        if data is not None and sp.issparse(data.X):
589            try:
590                data = data.to_dense()
591            except MemoryError:
592                data = None
593                self.Error.not_enough_memory()
594            else:
595                self.Information.sparse_densified()
596
597        input_data = data
598
599        # Data contains no attributes or meta attributes only
600        if data is not None and len(data.domain.attributes) == 0:
601            self.Error.no_continuous()
602            input_data = data = None
603
604        # Data contains some discrete attributes which must be filtered
605        if data is not None and \
606                any(var.is_discrete for var in data.domain.attributes):
607            ndisc = sum(var.is_discrete for var in data.domain.attributes)
608            data = data.transform(
609                Domain([var for var in data.domain.attributes
610                        if var.is_continuous],
611                       data.domain.class_vars,
612                       data.domain.metas))
613            if not data.domain.attributes:
614                self.Error.no_continuous()
615                input_data = data = None
616            else:
617                self.Information.discrete_ignored(
618                    ndisc, "s" if ndisc > 1 else "")
619
620        self.data = data
621        self.input_data = input_data
622
623        if data is not None:
624            self.annotation_model.set_domain(self.input_data.domain)
625            self.row_side_color_model.set_domain(self.input_data.domain)
626            self.annotation_var = None
627            self.annotation_color_var = None
628            self.row_split_model.set_domain(data.domain)
629            self.col_annot_data = data.transpose(data[:0].transform(Domain(data.domain.attributes)))
630            self.col_split_model.set_domain(self.col_annot_data.domain)
631            self.col_side_color_model.set_domain(self.col_annot_data.domain)
632            if data.domain.has_discrete_class:
633                self.split_by_var = data.domain.class_var
634            else:
635                self.split_by_var = None
636            self.split_columns_var = None
637            self.column_annotation_color_var = None
638            self.openContext(self.input_data)
639            if self.split_by_var not in self.row_split_model:
640                self.split_by_var = None
641
642            def match(desc: Tuple[str, str], source: Iterable[Variable]):
643                for v in source:
644                    if desc == (type(v).__name__, v.name):
645                        return v
646                return None
647
648            def is_variable(obj):
649                return isinstance(obj, Variable)
650
651            if self.split_columns_key is not None:
652                self.split_columns_var = match(
653                    self.split_columns_key,
654                    filter(is_variable, self.col_split_model)
655                )
656
657            if self.column_annotation_color_key is not None:
658                self.column_annotation_color_var = match(
659                    self.column_annotation_color_key,
660                    filter(is_variable, self.col_side_color_model)
661                )
662
663        self.update_heatmaps()
664        if data is not None and self.__pending_selection is not None:
665            assert self.scene.widget is not None
666            self.scene.widget.selectRows(self.__pending_selection)
667            self.selected_rows = self.__pending_selection
668            self.__pending_selection = None
669
670        self.unconditional_commit()
671
672    def __on_split_rows_activated(self):
673        self.set_split_variable(self.row_split_cb.currentData(Qt.EditRole))
674
675    def set_split_variable(self, var):
676        if var is not self.split_by_var:
677            self.split_by_var = var
678            self.update_heatmaps()
679
680    def __on_split_cols_activated(self):
681        self.set_column_split_var(self.col_split_cb.currentData(Qt.EditRole))
682
683    def set_column_split_var(self, var: Optional[Variable]):
684        if var is not self.split_columns_var:
685            self.split_columns_var = var
686            self.update_heatmaps()
687
688    def update_heatmaps(self):
689        if self.data is not None:
690            self.clear_scene()
691            self.clear_messages()
692            if self.col_clustering != Clustering.None_ and \
693                    len(self.data.domain.attributes) < 2:
694                self.Error.not_enough_features()
695            elif (self.col_clustering != Clustering.None_ or
696                  self.row_clustering != Clustering.None_) and \
697                    len(self.data) < 2:
698                self.Error.not_enough_instances()
699            elif self.merge_kmeans and len(self.data) < 3:
700                self.Error.not_enough_instances_k_means()
701            else:
702                parts = self.construct_heatmaps(self.data, self.split_by_var, self.split_columns_var)
703                self.construct_heatmaps_scene(parts, self.effective_data)
704                self.selected_rows = []
705        else:
706            self.clear()
707
708    def update_merge(self):
709        self.kmeans_model = None
710        self.merge_indices = None
711        if self.data is not None and self.merge_kmeans:
712            self.update_heatmaps()
713            self.commit()
714
715    def _make_parts(self, data, group_var=None, column_split_key=None):
716        """
717        Make initial `Parts` for data, split by group_var, group_key
718        """
719        if group_var is not None:
720            assert group_var.is_discrete
721            _col_data = table_column_data(data, group_var)
722            row_indices = [np.flatnonzero(_col_data == i)
723                           for i in range(len(group_var.values))]
724
725            row_groups = [RowPart(title=name, indices=ind,
726                                  cluster=None, cluster_ordered=None)
727                          for name, ind in zip(group_var.values, row_indices)]
728            if np.any(_col_data.mask):
729                row_groups.append(RowPart(
730                    title="N/A", indices=np.flatnonzero(_col_data.mask),
731                    cluster=None, cluster_ordered=None
732                ))
733        else:
734            row_groups = [RowPart(title=None, indices=range(0, len(data)),
735                                  cluster=None, cluster_ordered=None)]
736
737        if column_split_key is not None:
738            col_groups = split_domain(data.domain, column_split_key)
739            assert len(col_groups) > 0
740            col_indices = [np.array([data.domain.index(var) for var in group])
741                           for _, group in col_groups]
742            col_groups = [ColumnPart(title=str(name), domain=d, indices=ind,
743                                     cluster=None, cluster_ordered=None)
744                          for (name, d), ind in zip(col_groups, col_indices)]
745        else:
746            col_groups = [
747                ColumnPart(
748                    title=None, indices=range(0, len(data.domain.attributes)),
749                    domain=data.domain.attributes, cluster=None, cluster_ordered=None)
750            ]
751
752        minv, maxv = np.nanmin(data.X), np.nanmax(data.X)
753        return Parts(row_groups, col_groups, span=(minv, maxv))
754
755    def cluster_rows(self, data: Table, parts: 'Parts', ordered=False) -> 'Parts':
756        row_groups = []
757        for row in parts.rows:
758            if row.cluster is not None:
759                cluster = row.cluster
760            else:
761                cluster = None
762            if row.cluster_ordered is not None:
763                cluster_ord = row.cluster_ordered
764            else:
765                cluster_ord = None
766
767            if row.can_cluster:
768                matrix = None
769                need_dist = cluster is None or (ordered and cluster_ord is None)
770                if need_dist:
771                    subset = data[row.indices]
772                    matrix = Orange.distance.Euclidean(subset)
773
774                if cluster is None:
775                    cluster = hierarchical.dist_matrix_clustering(
776                        matrix, linkage=hierarchical.WARD
777                    )
778                if ordered and cluster_ord is None:
779                    cluster_ord = hierarchical.optimal_leaf_ordering(
780                        cluster, matrix,
781                    )
782            row_groups.append(row._replace(cluster=cluster, cluster_ordered=cluster_ord))
783
784        return parts._replace(rows=row_groups)
785
786    def cluster_columns(self, data, parts: 'Parts', ordered=False):
787        assert all(var.is_continuous for var in data.domain.attributes)
788        col_groups = []
789        for col in parts.columns:
790            if col.cluster is not None:
791                cluster = col.cluster
792            else:
793                cluster = None
794            if col.cluster_ordered is not None:
795                cluster_ord = col.cluster_ordered
796            else:
797                cluster_ord = None
798            if col.can_cluster:
799                need_dist = cluster is None or (ordered and cluster_ord is None)
800                matrix = None
801                if need_dist:
802                    subset = data.transform(Domain(col.domain))
803                    subset = Orange.distance._preprocess(subset)
804                    matrix = np.asarray(Orange.distance.PearsonR(subset, axis=0))
805                    # nan values break clustering below
806                    matrix = np.nan_to_num(matrix)
807
808                if cluster is None:
809                    assert matrix is not None
810                    cluster = hierarchical.dist_matrix_clustering(
811                        matrix, linkage=hierarchical.WARD
812                    )
813                if ordered and cluster_ord is None:
814                    cluster_ord = hierarchical.optimal_leaf_ordering(cluster, matrix)
815
816            col_groups.append(col._replace(cluster=cluster, cluster_ordered=cluster_ord))
817        return parts._replace(columns=col_groups)
818
    def construct_heatmaps(self, data, group_var=None, column_split_key=None) -> 'Parts':
        """
        Build the row/column parts for the heat map, clustered as requested.

        Parameters
        ----------
        data : Table
            Data to display when rows are not merged by k-means.
        group_var : optional
            Row split variable, forwarded to `_make_parts`. Ignored (reset
            to None) when `self.merge_kmeans` is set.
        column_split_key : optional
            Column split key; its `.name` is forwarded to `_make_parts`.

        Returns
        -------
        parts : Parts
            The parts, with clusterings filled in according to
            `self.row_clustering` / `self.col_clustering`.

        Notes
        -----
        Updates `self.effective_data`, `self.kmeans_model` and
        `self.merge_indices`, and stores the result in the rows cache.
        """
        if self.merge_kmeans:
            if self.kmeans_model is None:
                # k-means is run on `self.input_data` (not on the `data`
                # argument), restricted to continuous attributes.
                effective_data = self.input_data.transform(
                    Orange.data.Domain(
                        [var for var in self.input_data.domain.attributes
                         if var.is_continuous],
                        self.input_data.domain.class_vars,
                        self.input_data.domain.metas))
                nclust = min(self.merge_kmeans_k, len(effective_data) - 1)
                self.kmeans_model = kmeans_compress(effective_data, k=nclust)
                effective_data.domain = self.kmeans_model.domain
                # For every cluster label, the indices of the input rows
                # assigned to it.
                merge_indices = [np.flatnonzero(self.kmeans_model.labels == ind)
                                 for ind in range(nclust)]
                not_empty_indices = [i for i, x in enumerate(merge_indices)
                                     if len(x) > 0]
                self.merge_indices = \
                    [merge_indices[i] for i in not_empty_indices]
                if len(merge_indices) != len(self.merge_indices):
                    self.Warning.empty_clusters()
                # The displayed rows are the centroids of the non-empty
                # clusters.
                effective_data = Orange.data.Table(
                    Orange.data.Domain(effective_data.domain.attributes),
                    self.kmeans_model.centroids[not_empty_indices]
                )
            else:
                # Model already computed; reuse the previously merged data.
                effective_data = self.effective_data

            # Merged rows are never split by a group variable.
            group_var = None
        else:
            self.kmeans_model = None
            self.merge_indices = None
            effective_data = data

        self.effective_data = effective_data

        parts = self._make_parts(
            effective_data, group_var,
            column_split_key.name if column_split_key is not None else None)

        # May downgrade self.row_clustering / self.col_clustering when the
        # estimated cost is too high, so it runs before clustering below.
        self.__update_clustering_enable_state(parts)
        # Restore/update the row/columns items descriptions from cache if
        # available
        rows_cache_key = (group_var,
                          self.merge_kmeans_k if self.merge_kmeans else None)
        if rows_cache_key in self.__rows_cache:
            parts = parts._replace(rows=self.__rows_cache[rows_cache_key].rows)

        # NOTE(review): this method only reads the columns cache; nothing
        # visible here writes to it -- presumably populated elsewhere, verify.
        if column_split_key in self.__columns_cache:
            parts = parts._replace(
                columns=self.__columns_cache[column_split_key].columns)

        if self.row_clustering != Clustering.None_:
            parts = self.cluster_rows(
                effective_data, parts,
                ordered=self.row_clustering == Clustering.OrderedClustering
            )
        if self.col_clustering != Clustering.None_:
            parts = self.cluster_columns(
                effective_data, parts,
                ordered=self.col_clustering == Clustering.OrderedClustering
            )

        # Cache the updated parts
        self.__rows_cache[rows_cache_key] = parts
        return parts
884
885    def construct_heatmaps_scene(self, parts: 'Parts', data: Table) -> None:
886        _T = TypeVar("_T", bound=Union[RowPart, ColumnPart])
887
888        def select_cluster(clustering: Clustering, item: _T) -> _T:
889            if clustering == Clustering.None_:
890                return item._replace(cluster=None, cluster_ordered=None)
891            elif clustering == Clustering.Clustering:
892                return item._replace(cluster=item.cluster, cluster_ordered=None)
893            elif clustering == Clustering.OrderedClustering:
894                return item._replace(cluster=item.cluster_ordered, cluster_ordered=None)
895            else:  # pragma: no cover
896                raise TypeError()
897
898        rows = [select_cluster(self.row_clustering, rowitem)
899                for rowitem in parts.rows]
900        cols = [select_cluster(self.col_clustering, colitem)
901                for colitem in parts.columns]
902        parts = Parts(columns=cols, rows=rows, span=parts.span)
903
904        self.setup_scene(parts, data)
905
906    def setup_scene(self, parts, data):
907        # type: (Parts, Table) -> None
908        widget = HeatmapGridWidget()
909        widget.setColorMap(self.color_map())
910        self.scene.addItem(widget)
911        self.scene.widget = widget
912        columns = [v.name for v in data.domain.attributes]
913        parts = HeatmapGridWidget.Parts(
914            rows=[
915                HeatmapGridWidget.RowItem(r.title, r.indices, r.cluster)
916                for r in parts.rows
917            ],
918            columns=[
919                HeatmapGridWidget.ColumnItem(c.title, c.indices, c.cluster)
920                for c in parts.columns
921            ],
922            data=data.X,
923            span=parts.span,
924            row_names=None,
925            col_names=columns,
926        )
927        widget.setHeatmaps(parts)
928
929        side = self.row_side_colors()
930        if side is not None:
931            widget.setRowSideColorAnnotations(side[0], side[1], name=side[2].name)
932
933        side = self.column_side_colors()
934        if side is not None:
935            widget.setColumnSideColorAnnotations(side[0], side[1], name=side[2].name)
936
937        widget.setColumnLabelsPosition(self._column_label_pos)
938        widget.setAspectRatioMode(
939            Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio
940        )
941        widget.setShowAverages(self.averages)
942        widget.setLegendVisible(self.legend)
943
944        widget.layoutDidActivate.connect(self.__on_layout_activate)
945        widget.selectionFinished.connect(self.on_selection_finished)
946
947        self.update_annotations()
948        self.view.setCentralWidget(widget)
949        self.parts = parts
950
951    def __update_scene_rects(self):
952        widget = self.scene.widget
953        if widget is None:
954            return
955        rect = widget.geometry()
956        self.scene.setSceneRect(rect)
957        self.view.setSceneRect(rect)
958        self.view.setHeaderSceneRect(widget.headerGeometry())
959        self.view.setFooterSceneRect(widget.footerGeometry())
960
    def __on_layout_activate(self):
        """Refresh the scene rects; connected to the widget's
        `layoutDidActivate` signal in `setup_scene`."""
        self.__update_scene_rects()
963
964    def __aspect_mode_changed(self):
965        widget = self.scene.widget
966        if widget is None:
967            return
968        widget.setAspectRatioMode(
969            Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio
970        )
971        # when aspect fixed the vertical sh is fixex, when not, it can
972        # shrink vertically
973        sp = widget.sizePolicy()
974        if self.keep_aspect:
975            sp.setVerticalPolicy(QSizePolicy.Fixed)
976        else:
977            sp.setVerticalPolicy(QSizePolicy.Preferred)
978        widget.setSizePolicy(sp)
979
980    def __update_clustering_enable_state(self, parts: Optional['Parts']):
981        def c_cost(sizes: Iterable[int]) -> int:
982            """Estimated cost for clustering of `sizes`"""
983            return sum(n ** 2 for n in sizes)
984
985        def co_cost(sizes: Iterable[int]) -> int:
986            """Estimated cost for cluster ordering of `sizes`"""
987            # ~O(N ** 3) but O(N ** 4) worst case.
988            return sum(n ** 4 for n in sizes)
989
990        if parts is not None:
991            Ns = [len(p.indices) for p in parts.rows]
992            Ms = [len(p.indices) for p in parts.columns]
993        else:
994            Ns = Ms = [0]
995
996        rc_enabled = c_cost(Ns) <= c_cost([self.MaxClustering])
997        rco_enabled = co_cost(Ns) <= co_cost([self.MaxOrderedClustering])
998        cc_enabled = c_cost(Ms) <= c_cost([self.MaxClustering])
999        cco_enabled = co_cost(Ms) <= co_cost([self.MaxOrderedClustering])
1000        row_clust, col_clust = self.row_clustering, self.col_clustering
1001
1002        row_clust_msg = ""
1003        col_clust_msg = ""
1004
1005        if not rco_enabled and row_clust == Clustering.OrderedClustering:
1006            row_clust = Clustering.Clustering
1007            row_clust_msg = "Row cluster ordering was disabled due to the " \
1008                            "estimated runtime cost"
1009        if not rc_enabled and row_clust == Clustering.Clustering:
1010            row_clust = Clustering.None_
1011            row_clust_msg = "Row clustering was was disabled due to the " \
1012                            "estimated runtime cost"
1013
1014        if not cco_enabled and col_clust == Clustering.OrderedClustering:
1015            col_clust = Clustering.Clustering
1016            col_clust_msg = "Column cluster ordering was disabled due to " \
1017                            "estimated runtime cost"
1018        if not cc_enabled and col_clust == Clustering.Clustering:
1019            col_clust = Clustering.None_
1020            col_clust_msg = "Column clustering was disabled due to the " \
1021                            "estimated runtime cost"
1022
1023        self.col_clustering = col_clust
1024        self.row_clustering = row_clust
1025
1026        self.Information.row_clust(row_clust_msg, shown=bool(row_clust_msg))
1027        self.Information.col_clust(col_clust_msg, shown=bool(col_clust_msg))
1028
1029        # Disable/enable the combobox items for the clustering methods
1030        def setenabled(cb: QComboBox, clu: bool, clu_op: bool):
1031            model = cb.model()
1032            assert isinstance(model, QStandardItemModel)
1033            idx = cb.findData(Clustering.OrderedClustering, ClusteringRole)
1034            assert idx != -1
1035            model.item(idx).setEnabled(clu_op)
1036            idx = cb.findData(Clustering.Clustering, ClusteringRole)
1037            assert idx != -1
1038            model.item(idx).setEnabled(clu)
1039
1040        setenabled(self.row_cluster_cb, rc_enabled, rco_enabled)
1041        setenabled(self.col_cluster_cb, cc_enabled, cco_enabled)
1042
1043    def update_averages_stripe(self):
1044        """Update the visibility of the averages stripe.
1045        """
1046        widget = self.scene.widget
1047        if widget is not None:
1048            widget.setShowAverages(self.averages)
1049
1050    def update_color_schema(self):
1051        self.palette_name = self.color_map_widget.currentData().name
1052        w = self.scene.widget
1053        if w is not None:
1054            w.setColorMap(self.color_map())
1055
    def __update_column_clustering(self):
        """Rebuild the heat maps and re-send the outputs."""
        self.update_heatmaps()
        self.commit()
1059
    def __update_row_clustering(self):
        """Rebuild the heat maps and re-send the outputs."""
        self.update_heatmaps()
        self.commit()
1063
1064    def update_legend(self):
1065        widget = self.scene.widget
1066        if widget is not None:
1067            widget.setLegendVisible(self.legend)
1068
    def row_annotation_var(self):
        """Return the current row annotation variable (`self.annotation_var`)."""
        return self.annotation_var
1071
1072    def row_annotation_data(self):
1073        var = self.row_annotation_var()
1074        if var is None:
1075            return None
1076        return column_str_from_table(self.input_data, var)
1077
1078    def _merge_row_indices(self):
1079        if self.merge_kmeans and self.kmeans_model is not None:
1080            return self.merge_indices
1081        else:
1082            return None
1083
1084    def set_annotation_var(self, var: Union[None, Variable, int]):
1085        if isinstance(var, int):
1086            var = self.annotation_model[var]
1087        if self.annotation_var is not var:
1088            self.annotation_var = var
1089            self.update_annotations()
1090
1091    def update_annotations(self):
1092        widget = self.scene.widget
1093        if widget is not None:
1094            annot_col = self.row_annotation_data()
1095            merge_indices = self._merge_row_indices()
1096            if merge_indices is not None and annot_col is not None:
1097                join = lambda _1: join_elided(", ", 42, _1, " ({} more)")
1098                annot_col = aggregate_apply(join, annot_col, merge_indices)
1099            if annot_col is not None:
1100                widget.setRowLabels(annot_col)
1101                widget.setRowLabelsVisible(True)
1102            else:
1103                widget.setRowLabelsVisible(False)
1104                widget.setRowLabels(None)
1105
1106    def row_side_colors(self):
1107        var = self.annotation_color_var
1108        if var is None:
1109            return None
1110        column_data = column_data_from_table(self.input_data, var)
1111        merges = self._merge_row_indices()
1112        if merges is not None:
1113            column_data = aggregate(var, column_data, merges)
1114        data, colormap = colorize(var, column_data)
1115        if var.is_continuous:
1116            span = (np.nanmin(column_data), np.nanmax(column_data))
1117            if np.any(np.isnan(span)):
1118                span = 0., 1.
1119            colormap.span = span
1120        return data, colormap, var
1121
1122    def set_annotation_color_var(self, var: Union[None, Variable, int]):
1123        """Set the current side color annotation variable."""
1124        if isinstance(var, int):
1125            var = self.row_side_color_model[var]
1126        if self.annotation_color_var is not var:
1127            self.annotation_color_var = var
1128            self.update_row_side_colors()
1129
1130    def update_row_side_colors(self):
1131        widget = self.scene.widget
1132        if widget is None:
1133            return
1134        colors = self.row_side_colors()
1135        if colors is None:
1136            widget.setRowSideColorAnnotations(None)
1137        else:
1138            widget.setRowSideColorAnnotations(colors[0], colors[1], colors[2].name)
1139
    def __set_column_annotation_color_var_index(self, index: int):
        """Set the column side color variable from the item at `index` of
        the column side color combo box (its `Qt.EditRole` data)."""
        key = self.col_side_color_cb.itemData(index, Qt.EditRole)
        self.set_column_annotation_color_var(key)
1143
    def column_annotation_color_var_changed(self, value):
        """Sync the column side color combo box with `value` via `cbselect`
        (presumably selects the item whose `Qt.EditRole` data matches)."""
        cbselect(self.col_side_color_cb, value, Qt.EditRole)
1146
1147    def set_column_annotation_color_var(self, var):
1148        if self.column_annotation_color_var is not var:
1149            self.column_annotation_color_var = var
1150            colors = self.column_side_colors()
1151            if colors is not None:
1152                self.scene.widget.setColumnSideColorAnnotations(
1153                    colors[0], colors[1], colors[2].name,
1154                )
1155            else:
1156                self.scene.widget.setColumnSideColorAnnotations(None)
1157
1158    def column_side_colors(self):
1159        var = self.column_annotation_color_var
1160        if var is None:
1161            return None
1162        table = self.col_annot_data
1163        return color_annotation_data(table, var)
1164
1165    def update_column_annotations(self):
1166        widget = self.scene.widget
1167        if self.data is not None and widget is not None:
1168            widget.setColumnLabelsPosition(self._column_label_pos)
1169
1170    def __adjust_font_size(self, diff):
1171        widget = self.scene.widget
1172        if widget is None:
1173            return
1174        curr = widget.font().pointSizeF()
1175        new = curr + diff
1176
1177        self.__font_dec.setEnabled(new > 1.0)
1178        self.__font_inc.setEnabled(new <= 32)
1179        if new > 1.0:
1180            font = QFont()
1181            font.setPointSizeF(new)
1182            widget.setFont(font)
1183
1184    def _on_view_context_menu(self, pos):
1185        widget = self.scene.widget
1186        if widget is None:
1187            return
1188        assert isinstance(widget, HeatmapGridWidget)
1189        menu = QMenu(self.view.viewport())
1190        menu.setAttribute(Qt.WA_DeleteOnClose)
1191        menu.addActions(self.view.actions())
1192        menu.addSeparator()
1193        menu.addActions([self.__font_inc, self.__font_dec])
1194        menu.addSeparator()
1195        a = QAction("Keep aspect ratio", menu, checkable=True)
1196        a.setChecked(self.keep_aspect)
1197
1198        def ontoggled(state):
1199            self.keep_aspect = state
1200            self.__aspect_mode_changed()
1201        a.toggled.connect(ontoggled)
1202        menu.addAction(a)
1203        menu.popup(self.view.viewport().mapToGlobal(pos))
1204
1205    def on_selection_finished(self):
1206        if self.scene.widget is not None:
1207            self.selected_rows = list(self.scene.widget.selectedRows())
1208        else:
1209            self.selected_rows = []
1210        self.commit()
1211
1212    def commit(self):
1213        data = None
1214        indices = None
1215        if self.merge_kmeans:
1216            merge_indices = self.merge_indices
1217        else:
1218            merge_indices = None
1219
1220        if self.input_data is not None and self.selected_rows:
1221            indices = self.selected_rows
1222            if merge_indices is not None:
1223                # expand merged indices
1224                indices = np.hstack([merge_indices[i] for i in indices])
1225
1226            data = self.input_data[indices]
1227
1228        self.Outputs.selected_data.send(data)
1229        self.Outputs.annotated_data.send(create_annotated_table(self.input_data, indices))
1230
    def onDeleteWidget(self):
        """Clear the widget (`self.clear()`) before delegating to the base
        class implementation."""
        self.clear()
        super().onDeleteWidget()
1234
1235    def send_report(self):
1236        self.report_items((
1237            ("Columns:", "Clustering" if self.col_clustering else "No sorting"),
1238            ("Rows:", "Clustering" if self.row_clustering else "No sorting"),
1239            ("Split:",
1240             self.split_by_var is not None and self.split_by_var.name),
1241            ("Row annotation",
1242             self.annotation_var is not None and self.annotation_var.name),
1243        ))
1244        self.report_plot()
1245
1246    @classmethod
1247    def migrate_settings(cls, settings, version):
1248        if version is not None and version < 3:
1249            def st2cl(state: bool) -> Clustering:
1250                return Clustering.OrderedClustering if state else \
1251                    Clustering.None_
1252
1253            rc = settings.pop("row_clustering", False)
1254            cc = settings.pop("col_clustering", False)
1255            settings["row_clustering_method"] = st2cl(rc).name
1256            settings["col_clustering_method"] = st2cl(cc).name
1257
1258
1259# If StickyGraphicsView ever defines qt signals/slots/properties this will
1260# break
class GraphicsView(GraphicsWidgetView, StickyGraphicsView):
    """A graphics view combining `GraphicsWidgetView` and
    `StickyGraphicsView`; adds no behavior of its own."""
    pass
1263
1264
class RowPart(NamedTuple):
    """
    Description of one group of heat map rows.

    Attributes
    ----------
    title: str
        Title of the group.
    indices : (N, ) Sequence[int]
        Indices selecting the group's rows from the input data.
    cluster : hierarchical.Tree optional
        Hierarchical clustering of the group's rows.
    cluster_ordered : hierarchical.Tree optional
        Clustering with optimal leaf ordering applied.
    """
    title: str
    indices: Sequence[int]
    cluster: Optional[hierarchical.Tree] = None
    cluster_ordered: Optional[hierarchical.Tree] = None

    @property
    def can_cluster(self) -> bool:
        """True if the group has more than one row."""
        indices = self.indices
        if isinstance(indices, slice):
            size = indices.stop - indices.start
        else:
            size = len(indices)
        return size > 1
1289
1290
class ColumnPart(NamedTuple):
    """
    A column group

    Attributes
    ----------
    title : str
        Column group title
    indices : (N, ) Sequence[int]
        Indexes the input data to retrieve the column subset for the group.
    domain : Sequence[Variable]
        The variables in the group.
    cluster : hierarchical.Tree optional
        Hierarchical clustering of the group's columns.
    cluster_ordered : hierarchical.Tree optional
        Clustering with optimal leaf ordering applied.
    """
    title: str
    indices: Sequence[int]
    # Holds variables, not ints: it is passed to `Domain(...)` when the
    # columns are clustered.
    domain: Sequence[Variable]
    cluster: Optional[hierarchical.Tree] = None
    cluster_ordered: Optional[hierarchical.Tree] = None

    @property
    def can_cluster(self) -> bool:
        """True if the group has more than one column."""
        if isinstance(self.indices, slice):
            return (self.indices.stop - self.indices.start) > 1
        else:
            return len(self.indices) > 1
1318
1319
class Parts(NamedTuple):
    """
    The row and column groups of a heat map.

    Attributes
    ----------
    rows : Sequence[RowPart]
        The row groups.
    columns : Sequence[ColumnPart]
        The column groups.
    span : Tuple[float, float]
        The (min, max) range of the displayed data values.
    """
    rows: Sequence[RowPart]
    columns: Sequence[ColumnPart]
    span: Tuple[float, float]
1324
1325
1326def join_elided(sep, maxlen, values, elidetemplate="..."):
1327    def generate(sep, ellidetemplate, values):
1328        count = len(values)
1329        length = 0
1330        parts = []
1331        for i, val in enumerate(values):
1332            elide = ellidetemplate.format(count - i) if count - i > 1 else ""
1333            parts.append(val)
1334            length += len(val) + (len(sep) if parts else 0)
1335            yield i, islice(parts, i + 1), length, elide
1336
1337    best = None
1338    for _, parts, length, elide in generate(sep, elidetemplate, values):
1339        if length > maxlen:
1340            if best is None:
1341                best = sep.join(parts) + elide
1342            return best
1343        fulllen = length + len(elide)
1344        if fulllen < maxlen or best is None:
1345            best = sep.join(parts) + elide
1346    return best
1347
1348
def column_str_from_table(
        table: Orange.data.Table,
        column: Union[int, Orange.data.Variable],
) -> np.ndarray:
    """
    Return `column` of `table` as an object array of strings, each value
    formatted with the variable's `str_val`.
    """
    variable = table.domain[column]
    values, _ = table.get_column_view(column)
    strings = list(map(variable.str_val, values))
    return np.asarray(strings, dtype=object)
1356
1357
def column_data_from_table(
        table: Orange.data.Table,
        column: Union[int, Orange.data.Variable],
) -> np.ndarray:
    """
    Return the data of `column` in `table`.

    The data of primitive variables is converted to float when it is not
    stored as float already; other data is returned as is.
    """
    variable = table.domain[column]
    values, _ = table.get_column_view(column)
    needs_cast = variable.is_primitive() and values.dtype.kind != "f"
    return values.astype(float) if needs_cast else values
1367
1368
def color_annotation_data(
        table: Table, var: Union[int, str, Variable]
) -> Tuple[np.ndarray, ColorMap, Variable]:
    """
    Return `(data, colormap, variable)` for coloring by `var` of `table`.

    `var` is resolved through `table.domain`; the column data is then
    colorized with the variable's palette.
    """
    variable = table.domain[var]
    colored, colormap = colorize(
        variable, column_data_from_table(table, variable))
    return colored, colormap, variable
1376
1377
def colorize(var: Variable, data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
    """
    Return `data` prepared for coloring together with a matching color map.

    For a discrete `var` the data is converted to int with missing values
    mapped to -1; an extra "N/A" category is added only when values are
    missing. For a continuous `var` the data is returned unchanged with a
    gradient color map spanning its (nan-ignoring) min and max, or (0, 1)
    when that span is undefined.

    Raises
    ------
    TypeError
        If `var` is neither discrete nor continuous.
    """
    palette = var.palette  # type: Palette
    rgb = np.array(
        [[qc.red(), qc.green(), qc.blue()] for qc in palette.qcolors_w_nan],
        dtype=np.uint8,
    )
    if var.is_discrete:
        missing = np.isnan(data)
        codes = data.astype(int)
        codes[missing] = -1
        if missing.any():
            # keep the trailing nan color for the extra category
            return codes, CategoricalColorMap(rgb, (*var.values, "N/A"))
        return codes, CategoricalColorMap(rgb[: -1], var.values)
    if var.is_continuous:
        low, high = np.nanmin(data), np.nanmax(data)
        span = (0, 1.) if np.any(np.isnan((low, high))) else (low, high)
        return data, GradientColorMap(rgb[:-1], span=span)
    raise TypeError
1401
1402
def aggregate(
        var: Variable, data: np.ndarray, groupindices: Sequence[Sequence[int]],
) -> np.ndarray:
    """
    Aggregate the values of `data` within each group of `groupindices`.

    String variables are joined into one elided label per group,
    continuous variables are averaged (ignoring nan) and discrete
    variables are reduced to their mode (ignoring nan). Empty groups of
    continuous/discrete variables yield nan.

    Raises
    ------
    TypeError
        If `var` is not a string, continuous or discrete variable.
    """
    if var.is_string:
        # collect all original labels for every merged row
        labels = []
        for indices in groupindices:
            strings = [var.str_val(value) for value in data[indices]]
            labels.append(join_elided(", ", 42, strings, " ({} more)"))
        return np.array(labels, dtype=object)
    if var.is_continuous:
        means = []
        for indices in groupindices:
            means.append(np.nanmean(data[indices]) if len(indices) else np.nan)
        return np.array(means, dtype=float)
    if var.is_discrete:
        from Orange.statistics.util import nanmode
        modes = []
        for indices in groupindices:
            modes.append(
                nanmode(data[indices])[0] if len(indices) else np.nan)
        return np.asarray(modes, dtype=float)
    raise TypeError(type(var))
1423
1424
def agg_join_str(var, data, groupindices, maxlen=50, elidetemplate=" ({} more)"):
    """
    Join the string values of `data` within each group of `groupindices`.

    Parameters
    ----------
    var : Variable
        Variable whose `str_val` formats the individual values.
    data : np.ndarray
        The column data.
    groupindices : Sequence[Sequence[int]]
        For every group, the indices into `data`.
    maxlen : int
        Maximum length passed on to `join_elided`.
    elidetemplate : str
        Elide suffix template passed on to `join_elided`.

    Returns
    -------
    joined : List[str]
        One joined (possibly elided) string per group.
    """
    def join(values):
        # Materialize the strings: `join_elided` takes len() of its input,
        # which a lazy `map` object does not support.
        strings = [var.str_val(value) for value in values]
        return join_elided(", ", maxlen, strings, elidetemplate=elidetemplate)

    return aggregate_apply(join, data, groupindices)
1431
1432
# Result type of the callable applied to each group by `aggregate_apply`.
_T = TypeVar("_T")
1434
1435
def aggregate_apply(
        f: Callable[[Sequence], _T],
        data: np.ndarray,
        groupindices: Sequence[Sequence[int]]
) -> Sequence[_T]:
    """
    Apply `f` to the subset of `data` selected by each group of indices.

    Returns a list with one result per group, in group order.
    """
    results = []
    for indices in groupindices:
        results.append(f(data[indices]))
    return results
1442
1443
# Preview the widget on the "brown-selected" data set when run as a script.
if __name__ == "__main__":  # pragma: no cover
    WidgetPreview(OWHeatMap).run(Table("brown-selected.tab"))
1446