1import math
2import enum
3from itertools import chain, zip_longest
4
5from typing import (
6    Optional, List, NamedTuple, Sequence, Tuple, Dict, Union, Iterable
7)
8
9import numpy as np
10
11from AnyQt.QtCore import (
12    Signal, Property, Qt, QRectF, QSizeF, QEvent, QPointF, QObject
13)
14from AnyQt.QtGui import QPixmap, QPalette, QPen, QColor, QFontMetrics
15from AnyQt.QtWidgets import (
16    QGraphicsWidget, QSizePolicy, QGraphicsGridLayout, QGraphicsRectItem,
17    QApplication, QGraphicsSceneMouseEvent, QGraphicsLinearLayout,
18    QGraphicsItem, QGraphicsSimpleTextItem, QGraphicsLayout,
19    QGraphicsLayoutItem
20)
21
22import pyqtgraph as pg
23
24from Orange.clustering import hierarchical
25from Orange.clustering.hierarchical import Tree
26from Orange.widgets.utils import apply_all
27from Orange.widgets.utils.colorpalettes import DefaultContinuousPalette
28from Orange.widgets.utils.graphicslayoutitem import SimpleLayoutItem, scaled
29from Orange.widgets.utils.graphicsflowlayout import GraphicsFlowLayout
30from Orange.widgets.utils.graphicspixmapwidget import GraphicsPixmapWidget
31from Orange.widgets.utils.image import qimage_from_array
32
33from Orange.widgets.utils.graphicstextlist import TextListWidget
34from Orange.widgets.utils.dendrogram import DendrogramWidget
35
36
def leaf_indices(tree: Tree) -> Sequence[int]:
    """
    Return the `value.index` of every leaf of `tree`, in the order the
    leaves are yielded by `hierarchical.leaves`.
    """
    indices = []
    for node in hierarchical.leaves(tree):
        indices.append(node.value.index)
    return indices
39
40
class ColorMap:
    """
    Abstract base class for color maps.

    Both methods only raise `NotImplementedError` here; concrete color maps
    are expected to override them.
    """

    def apply(self, data: np.ndarray) -> np.ndarray:
        """Map `data` to colors. Not implemented in the base class."""
        raise NotImplementedError()

    def replace(self, **kwargs) -> 'ColorMap':
        """
        Return a new color map with the parameters named in `kwargs`
        replaced. Not implemented in the base class.
        """
        raise NotImplementedError()
49
50
class CategoricalColorMap(ColorMap):
    """Color map that looks up a color for every integer category code."""
    #: The color lookup table. A (N, 3) uint8 ndarray
    colortable: np.ndarray
    #: The N category names (one for every `colortable` row)
    names: Sequence[str]

    def __init__(self, colortable, names):
        self.colortable = np.asarray(colortable)
        self.names = names
        assert len(colortable) == len(names)

    def apply(self, data) -> np.ndarray:
        """
        Return the `colortable` rows selected by `data`.

        `data` is converted to an integer array and used to index the
        color table.
        """
        return self.colortable[np.asarray(data, dtype=int)]

    def replace(self, **kwargs) -> 'CategoricalColorMap':
        """
        Return a new `CategoricalColorMap` using this map's `colortable`
        and `names` for any of the two not supplied in `kwargs`.
        """
        params = {"colortable": self.colortable, "names": self.names}
        params.update(kwargs)
        return CategoricalColorMap(**params)
72
73
class GradientColorMap(ColorMap):
    """
    Color map for the heatmap.

    Values are mapped linearly from the (thresholded and optionally
    centered) data range onto the rows of `colortable`.
    """
    #: A color table. A (N, 3) uint8 ndarray
    colortable: np.ndarray
    #: The data range (min, max)
    span: Optional[Tuple[float, float]] = None
    #: Lower and upper thresholding operator parameters. Expressed as relative
    #: to the data span (range) so (0, 1) applies no thresholding, while
    #: (0.05, 0.95) squeezes the effective range by 5% from both ends
    thresholds: Tuple[float, float] = (0., 1.)
    #: Should the color map be centered and if so around which value.
    center: Optional[float] = None

    def __init__(self, colortable, thresholds=thresholds, center=None, span=None):
        self.colortable = np.asarray(colortable)
        self.thresholds = thresholds
        assert thresholds[0] <= thresholds[1]
        self.center = center
        self.span = span

    def adjust_levels(self, low: float, high: float) -> Tuple[float, float]:
        """
        Adjust the data low, high levels by applying the thresholding and
        centering.

        Returns
        -------
        levels: Tuple[float, float]
            The adjusted (low, high) levels, or (nan, nan) if either input
            is NaN.

        Raises
        ------
        ValueError
            If `low > high`.
        """
        if np.any(np.isnan([low, high])):
            return np.nan, np.nan
        elif low > high:
            raise ValueError(f"low > high ({low} > {high})")
        threshold_low, threshold_high = self.thresholds
        lt = low + (high - low) * threshold_low
        ht = low + (high - low) * threshold_high
        if self.center is not None:
            # Make the range symmetric around the center, wide enough to
            # include both thresholded levels.
            center = self.center
            maxoff = max(abs(center - lt), abs(center - ht))
            lt = center - maxoff
            ht = center + maxoff
        return lt, ht

    def apply(self, data) -> np.ndarray:
        """
        Map `data` to colors from the `colortable`.

        The levels are taken from `span` (or from the NaN-ignoring min/max
        of `data` if `span` is None) and adjusted by `adjust_levels`. Values
        outside the levels are clipped to the first/last color.

        Returns
        -------
        colors: np.ndarray
            The `colortable` rows selected for every element of `data`.
            Entries where `data` is NaN are set to 0.
        """
        if self.span is None:
            low, high = np.nanmin(data), np.nanmax(data)
        else:
            low, high = self.span
        low, high = self.adjust_levels(low, high)
        mask = np.isnan(data)
        normalized = data - low
        finfo = np.finfo(normalized.dtype)
        if high - low <= 1 / finfo.max:
            n_fact = finfo.max
        else:
            n_fact = 1. / (high - low)
        # over/underflow to inf are expected and clipped with the rest in the
        # next step
        with np.errstate(over="ignore", under="ignore"):
            normalized *= n_fact
        normalized = np.clip(normalized, 0, 1, out=normalized)
        # Zero initialize the index table. The multiply below skips the NaN
        # positions (`where=~mask`); an uninitialized table would leave
        # arbitrary bytes there, which can index past the end of a color
        # table with fewer than 256 entries.
        table = np.zeros_like(normalized, dtype=np.uint8)
        ncolors = len(self.colortable)
        assert ncolors - 1 <= np.iinfo(table.dtype).max
        table = np.multiply(
            normalized, ncolors - 1, out=table, where=~mask, casting="unsafe",
        )
        colors = self.colortable[table]
        colors[mask] = 0
        return colors

    def replace(self, **kwargs) -> 'GradientColorMap':
        """
        Return a new `GradientColorMap` using this map's `colortable`,
        `thresholds`, `center` and `span` for any parameter not supplied in
        `kwargs`.
        """
        kwargs.setdefault("colortable", self.colortable)
        kwargs.setdefault("thresholds", self.thresholds)
        kwargs.setdefault("center", self.center)
        kwargs.setdefault("span", self.span)
        return GradientColorMap(**kwargs)
147
148
def normalized_indices(item: Union['RowItem', 'ColumnItem']) -> np.ndarray:
    """
    Return `item.indices` as an integer array.

    If `item.cluster` is not None the indices are reordered by the leaf
    indices of the cluster tree.
    """
    cluster = item.cluster
    if cluster is None:
        return np.asarray(item.indices, dtype=int)
    order = np.array(leaf_indices(cluster), dtype=int)
    return np.asarray(item.indices, dtype=int)[order]
156
157
class GridLayout(QGraphicsGridLayout):
    """
    A grid layout that emits its parent `HeatmapGridWidget`'s
    `layoutDidActivate` signal after every `setGeometry` call.
    """
    def setGeometry(self, rect: QRectF) -> None:
        super().setGeometry(rect)
        owner = self.parentLayoutItem()
        if not isinstance(owner, HeatmapGridWidget):
            return
        owner.layoutDidActivate.emit()
164
165
def grid_layout_row_geometry(layout: QGraphicsGridLayout, row: int) -> QRectF:
    """
    Return the geometry of the `row` in the grid layout.

    The result is the layout's geometry with the top and bottom replaced by
    the topmost and bottommost edges of the row's items that have a valid
    geometry. If `row` is out of range, or the row has no such item, an
    empty `QRectF` is returned.
    """
    if row < 0 or row >= layout.rowCount():
        return QRectF()

    tops: List[float] = []
    bottoms: List[float] = []
    for column in range(layout.columnCount()):
        cell = layout.itemAt(row, column)
        if cell is None:
            continue
        cellgeom = cell.geometry()
        if cellgeom.isValid():
            tops.append(cellgeom.top())
            bottoms.append(cellgeom.bottom())

    if not tops:
        return QRectF()

    rect = layout.geometry()
    rect.setTop(min(tops))
    rect.setBottom(max(bottoms))
    return rect
189
190
191# Positions
class Position(enum.IntFlag):
    """Bit flags for the four sides (combinable with `|`)."""
    NoPosition = 0
    Left = 1
    Top = 2
    Right = 4
    Bottom = 8
195
196
# Module level shorthands for the `Position` flags
Left = Position.Left
Right = Position.Right
Top = Position.Top
Bottom = Position.Bottom


# The largest finite float32 value
FLT_MAX = np.finfo(np.float32).max
202
203
204class HeatmapGridWidget(QGraphicsWidget):
205    """
    A graphics widget with an annotated 2D grid of heatmaps.
207    """
208    class RowItem(NamedTuple):
209        """
210        A row group item
211
212        Attributes
213        ----------
214        title: str
215            Group title
216        indices : (N, ) Sequence[int]
217            Indices in the input data to retrieve the row subset for the group.
218        cluster : Optional[Tree]
219
220        """
221        title: str
222        indices: Sequence[int]
223        cluster: Optional[Tree] = None
224
225        @property
226        def size(self):
227            return len(self.indices)
228
229        @property
230        def normalized_indices(self):
231            return normalized_indices(self)
232
233    class ColumnItem(NamedTuple):
234        """
235        A column group
236
237        Attributes
238        ----------
239        title: str
240            Column group title
241        indices: (N, ) Sequence[int]
242            Indexes the input data to retrieve the column subset for the group.
243        cluster: Optional[Tree]
244        """
245        title: str
246        indices: Sequence[int]
247        cluster: Optional[Tree] = None
248
249        @property
250        def size(self):
251            return len(self.indices)
252
253        @property
254        def normalized_indices(self):
255            return normalized_indices(self)
256
257    class Parts(NamedTuple):
258        #: define the splits of data over rows, and define dendrogram and/or row
259        #: reordering
260        rows: Sequence['RowItem']
261        #: define the splits of data over columns, and define dendrogram and/or
262        #: column reordering
263        columns: Sequence['ColumnItem']
264        #: span (min, max) of the values in `data`
265        span: Tuple[float, float]
266        #: the complete data array (shape (N, M))
267        data: np.ndarray
268        #: Row names (len N)
269        row_names: Optional[Sequence[str]] = None
270        #: Column names (len M)
271        col_names: Optional[Sequence[str]] = None
272
273    # Positions
274    class Position(enum.IntFlag):
275        NoPosition = 0
276        Left, Top, Right, Bottom = 1, 2, 4, 8
277
278    Left, Right = Position.Left, Position.Right
279    Top, Bottom = Position.Top, Position.Bottom
280
281    #: The widget's layout has activated (i.e. did a relayout
282    #: of the widget's contents)
283    layoutDidActivate = Signal()
284
285    #: Signal emitted when the user finished a selection operation
286    selectionFinished = Signal()
287    #: Signal emitted on any change in selection
288    selectionChanged = Signal()
289
290    NoPosition, PositionTop, PositionBottom = 0, Top, Bottom
291
292    # Start row/column where the heatmap items are inserted
293    # (after the titles/legends/dendrograms)
294    Row0 = 5
295    Col0 = 3
296    # The (color) legend row and column
297    LegendRow, LegendCol = 0, 4
298    # The column for the vertical dendrogram
299    DendrogramColumn = 1
300    # Horizontal split title column
301    GroupTitleRow = 1
302    # The row for the horizontal dendrograms
303    DendrogramRow = 2
304    # The row for top column annotation labels
305    TopLabelsRow = 3
306    # Top color annotation row
307    TopAnnotationRow = 4
308    # Vertical split title column
309    GroupTitleColumn = 0
310
311    def __init__(self, parent=None, **kwargs):
312        super().__init__(parent, **kwargs)
313        self.__spacing = 3
314        self.__colormap = GradientColorMap(
315            DefaultContinuousPalette.lookup_table()
316        )
317        self.parts = None  # type: Optional[Parts]
318        self.__averagesVisible = False
319        self.__legendVisible = True
320        self.__aspectRatioMode = Qt.IgnoreAspectRatio
321        self.__columnLabelPosition = Top
322        self.heatmap_widget_grid = []  # type: List[List[GraphicsHeatmapWidget]]
323        self.row_annotation_widgets = []  # type: List[TextListWidget]
324        self.col_annotation_widgets = []  # type: List[TextListWidget]
325        self.col_annotation_widgets_top = []  # type: List[TextListWidget]
326        self.col_annotation_widgets_bottom = []  # type: List[TextListWidget]
327        self.col_dendrograms = []  # type: List[Optional[DendrogramWidget]]
328        self.row_dendrograms = []  # type: List[Optional[DendrogramWidget]]
329        self.right_side_colors = []  # type: List[Optional[GraphicsPixmapWidget]]
330        self.top_side_colors = []  # type: List[Optional[GraphicsPixmapWidget]]
331        self.heatmap_colormap_legend = None
332        self.bottom_legend_container = None
333        self.__layout = GridLayout()
334        self.__layout.setSpacing(self.__spacing)
335        self.setLayout(self.__layout)
336        self.__selection_manager = SelectionManager(self)
337        self.__selection_manager.selection_changed.connect(
338            self.__update_selection_geometry
339        )
340        self.__selection_manager.selection_finished.connect(
341            self.selectionFinished
342        )
343        self.__selection_manager.selection_changed.connect(
344            self.selectionChanged
345        )
346        self.selection_rects = []
347
348    def clear(self):
349        """Clear the widget."""
350        for i in reversed(range(self.__layout.count())):
351            item = self.__layout.itemAt(i)
352            self.__layout.removeAt(i)
353            if item is not None and item.graphicsItem() is not None:
354                remove_item(item.graphicsItem())
355
356        self.heatmap_widget_grid = []
357        self.row_annotation_widgets = []
358        self.col_annotation_widgets = []
359        self.col_dendrograms = []
360        self.row_dendrograms = []
361        self.right_side_colors = []
362        self.top_side_colors = []
363        self.heatmap_colormap_legend = None
364        self.bottom_legend_container = None
365        self.parts = None
366        self.updateGeometry()
367
368    def setHeatmaps(self, parts: 'Parts') -> None:
369        """Set the heatmap parts for display"""
370        self.clear()
371        grid = self.__layout
372        N, M = len(parts.rows), len(parts.columns)
373
374        # Start row/column where the heatmap items are inserted
375        # (after the titles/legends/dendrograms)
376        Row0 = self.Row0
377        Col0 = self.Col0
378        # The column for the vertical dendrograms
379        DendrogramColumn = self.DendrogramColumn
380        # The row for the horizontal dendrograms
381        DendrogramRow = self.DendrogramRow
382        RightLabelColumn = Col0 + 2 * M + 1
383        TopAnnotationRow = self.TopAnnotationRow
384        TopLabelsRow = self.TopLabelsRow
385        BottomLabelsRow = Row0 + N
386        colormap = self.__colormap
387        column_dendrograms: List[Optional[DendrogramWidget]] = [None] * M
388        row_dendrograms: List[Optional[DendrogramWidget]] = [None] * N
389        right_side_colors: List[Optional[GraphicsPixmapWidget]] = [None] * N
390        top_side_colors: List[Optional[GraphicsPixmapWidget]] = [None] * M
391
392        data = parts.data
393        if parts.col_names is None:
394            col_names = np.full(data.shape[1], "", dtype=object)
395        else:
396            col_names = np.asarray(parts.col_names, dtype=object)
397        if parts.row_names is None:
398            row_names = np.full(data.shape[0], "", dtype=object)
399        else:
400            row_names = np.asarray(parts.row_names, dtype=object)
401
402        assert len(col_names) == data.shape[1]
403        assert len(row_names) == data.shape[0]
404
405        for i, rowitem in enumerate(parts.rows):
406            if rowitem.title:
407                item = QGraphicsSimpleTextItem(rowitem.title, parent=self)
408                item.setTransform(item.transform().rotate(-90))
409                item = SimpleLayoutItem(item, parent=grid, anchor=(0, 1),
410                                        anchorItem=(0, 0))
411                item.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Maximum)
412                grid.addItem(item, Row0 + i, self.GroupTitleColumn,
413                             alignment=Qt.AlignCenter)
414            if rowitem.cluster:
415                dendrogram = DendrogramWidget(
416                    parent=self,
417                    selectionMode=DendrogramWidget.NoSelection,
418                    hoverHighlightEnabled=True,
419                )
420                dendrogram.set_root(rowitem.cluster)
421                dendrogram.setMaximumWidth(100)
422                dendrogram.setMinimumWidth(100)
423                # Ignore dendrogram vertical size hint (heatmap's size
424                # should define the  row's vertical size).
425                dendrogram.setSizePolicy(
426                    QSizePolicy.Expanding, QSizePolicy.Ignored)
427                dendrogram.itemClicked.connect(
428                    lambda item, partindex=i:
429                    self.__select_by_cluster(item, partindex)
430                )
431                grid.addItem(dendrogram, Row0 + i, DendrogramColumn)
432                row_dendrograms[i] = dendrogram
433
434        for j, colitem in enumerate(parts.columns):
435            if colitem.title:
436                item = SimpleLayoutItem(
437                    QGraphicsSimpleTextItem(colitem.title, parent=self),
438                    parent=grid, anchor=(0.5, 0.5), anchorItem=(0.5, 0.5)
439                )
440                item.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Fixed)
441                grid.addItem(item, self.GroupTitleRow, Col0 + 2 * j + 1)
442
443            if colitem.cluster:
444                dendrogram = DendrogramWidget(
445                    parent=self,
446                    orientation=DendrogramWidget.Top,
447                    selectionMode=DendrogramWidget.NoSelection,
448                    hoverHighlightEnabled=False
449                )
450                dendrogram.set_root(colitem.cluster)
451                dendrogram.setMaximumHeight(100)
452                dendrogram.setMinimumHeight(100)
453                # Ignore dendrogram horizontal size hint (heatmap's width
454                # should define the column width).
455                dendrogram.setSizePolicy(
456                    QSizePolicy.Ignored, QSizePolicy.Expanding)
457                grid.addItem(dendrogram, DendrogramRow, Col0 + 2 * j + 1)
458                column_dendrograms[j] = dendrogram
459
460        heatmap_widgets = []
461        for i in range(N):
462            heatmap_row = []
463            for j in range(M):
464                row_ix = parts.rows[i].normalized_indices
465                col_ix = parts.columns[j].normalized_indices
466                X_part = data[np.ix_(row_ix, col_ix)]
467                hw = GraphicsHeatmapWidget(
468                    aspectRatioMode=self.__aspectRatioMode,
469                    data=X_part, span=parts.span, colormap=colormap,
470                )
471                sp = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
472                sp.setHeightForWidth(True)
473                hw.setSizePolicy(sp)
474
475                avgimg = GraphicsHeatmapWidget(
476                    data=np.nanmean(X_part, axis=1, keepdims=True),
477                    span=parts.span, colormap=colormap,
478                    visible=self.__averagesVisible,
479                    minimumSize=QSizeF(5, -1)
480                )
481                avgimg.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Ignored)
482                grid.addItem(avgimg, Row0 + i, Col0 + 2 * j)
483                grid.addItem(hw, Row0 + i, Col0 + 2 * j + 1)
484
485                heatmap_row.append(hw)
486            heatmap_widgets.append(heatmap_row)
487
488        for j in range(M):
489            grid.setColumnStretchFactor(Col0 + 2 * j, 1)
490            grid.setColumnStretchFactor(
491                Col0 + 2 * j + 1, parts.columns[j].size)
492        grid.setColumnStretchFactor(RightLabelColumn - 1, 1)
493
494        for i in range(N):
495            grid.setRowStretchFactor(Row0 + i, parts.rows[i].size)
496
497        row_annotation_widgets = []
498        col_annotation_widgets = []
499        col_annotation_widgets_top = []
500        col_annotation_widgets_bottom = []
501
502        for i, rowitem in enumerate(parts.rows):
503            # Right row annotations
504            indices = np.asarray(rowitem.normalized_indices, dtype=np.intp)
505            labels = row_names[indices]
506            labelslist = TextListWidget(
507                items=labels, parent=self, orientation=Qt.Vertical,
508                alignment=Qt.AlignLeft | Qt.AlignVCenter,
509                sizePolicy=QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Ignored),
510                autoScale=True,
511                objectName="row-labels-right"
512            )
513            labelslist.setMaximumWidth(300)
514            rowauxsidecolor = GraphicsPixmapWidget(
515                parent=self, visible=False,
516                scaleContents=True, aspectMode=Qt.IgnoreAspectRatio,
517                sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Ignored),
518                minimumSize=QSizeF(10, -1)
519            )
520            grid.addItem(rowauxsidecolor, Row0 + i, RightLabelColumn - 1)
521            grid.addItem(labelslist, Row0 + i, RightLabelColumn, Qt.AlignLeft)
522            row_annotation_widgets.append(labelslist)
523            right_side_colors[i] = rowauxsidecolor
524
525        for j, colitem in enumerate(parts.columns):
526            # Top attr annotations
527            indices = np.asarray(colitem.normalized_indices, dtype=np.intp)
528            labels = col_names[indices]
529            sp = QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Fixed)
530            sp.setHeightForWidth(True)
531            labelslist = TextListWidget(
532                items=labels, parent=self,
533                alignment=Qt.AlignLeft | Qt.AlignVCenter,
534                orientation=Qt.Horizontal,
535                autoScale=True,
536                sizePolicy=sp,
537                visible=self.__columnLabelPosition & Position.Top,
538                objectName="column-labels-top",
539            )
540            colauxsidecolor = GraphicsPixmapWidget(
541                parent=self, visible=False,
542                scaleContents=True, aspectMode=Qt.IgnoreAspectRatio,
543                sizePolicy=QSizePolicy(QSizePolicy.Ignored,
544                                       QSizePolicy.Maximum),
545                minimumSize=QSizeF(-1, 10)
546            )
547
548            grid.addItem(labelslist, TopLabelsRow, Col0 + 2 * j + 1,
549                         Qt.AlignBottom | Qt.AlignLeft)
550            grid.addItem(colauxsidecolor, TopAnnotationRow, Col0 + 2 * j + 1)
551            col_annotation_widgets.append(labelslist)
552            col_annotation_widgets_top.append(labelslist)
553            top_side_colors[j] = colauxsidecolor
554
555            # Bottom attr annotations
556            labelslist = TextListWidget(
557                items=labels, parent=self,
558                alignment=Qt.AlignRight | Qt.AlignVCenter,
559                orientation=Qt.Horizontal,
560                autoScale=True,
561                sizePolicy=sp,
562                visible=self.__columnLabelPosition & Position.Bottom,
563                objectName="column-labels-bottom",
564            )
565            grid.addItem(labelslist, BottomLabelsRow, Col0 + 2 * j + 1)
566            col_annotation_widgets.append(labelslist)
567            col_annotation_widgets_bottom.append(labelslist)
568
569        row_color_annotation_header = QGraphicsSimpleTextItem("", self)
570        row_color_annotation_header.setTransform(
571            row_color_annotation_header.transform().rotate(-90))
572
573        grid.addItem(SimpleLayoutItem(
574            row_color_annotation_header, anchor=(0, 1),
575            aspectMode=Qt.KeepAspectRatio,
576            sizePolicy=QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Preferred),
577            ),
578            0, RightLabelColumn - 1, self.TopLabelsRow + 1, 1,
579            alignment=Qt.AlignBottom
580        )
581
582        col_color_annotation_header = QGraphicsSimpleTextItem("", self)
583        grid.addItem(SimpleLayoutItem(
584            col_color_annotation_header, anchor=(1, 1), anchorItem=(1, 1),
585            aspectMode=Qt.KeepAspectRatio,
586            sizePolicy=QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed),
587        ),
588            TopAnnotationRow, 0, 1, Col0, alignment=Qt.AlignRight
589        )
590
591        legend = GradientLegendWidget(
592            parts.span[0], parts.span[1],
593            colormap,
594            parent=self,
595            minimumSize=QSizeF(100, 20),
596            visible=self.__legendVisible,
597            sizePolicy=QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Fixed)
598        )
599        legend.setMaximumWidth(300)
600        grid.addItem(legend, self.LegendRow, self.LegendCol, 1, M * 2 - 1)
601
602        def container(parent=None, orientation=Qt.Horizontal, margin=0, spacing=0, **kwargs):
603            widget = QGraphicsWidget(**kwargs)
604            layout = QGraphicsLinearLayout(orientation)
605            layout.setContentsMargins(margin, margin, margin, margin)
606            layout.setSpacing(spacing)
607            widget.setLayout(layout)
608            if parent is not None:
609                widget.setParentItem(parent)
610
611            return widget
612        # Container for color annotation legends
613        legend_container = container(
614            spacing=3,
615            sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed),
616            visible=False, objectName="annotation-legend-container"
617        )
618        legend_container_rows = container(
619            parent=legend_container,
620            sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed),
621            visible=False, objectName="row-annotation-legend-container"
622        )
623        legend_container_cols = container(
624            parent=legend_container,
625            sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed),
626            visible=False, objectName="col-annotation-legend-container"
627        )
628        # ? keep refs to child containers; segfault in scene.clear() ?
629        legend_container._refs = (legend_container_rows, legend_container_cols)
630        legend_container.layout().addItem(legend_container_rows)
631        legend_container.layout().addItem(legend_container_cols)
632
633        grid.addItem(legend_container, BottomLabelsRow + 1, Col0 + 1, 1, M * 2 - 1,
634                     alignment=Qt.AlignRight)
635
636        self.heatmap_widget_grid = heatmap_widgets
637        self.row_annotation_widgets = row_annotation_widgets
638        self.col_annotation_widgets = col_annotation_widgets
639        self.col_annotation_widgets_top = col_annotation_widgets_top
640        self.col_annotation_widgets_bottom = col_annotation_widgets_bottom
641        self.col_dendrograms = column_dendrograms
642        self.row_dendrograms = row_dendrograms
643        self.right_side_colors = right_side_colors
644        self.top_side_colors = top_side_colors
645        self.heatmap_colormap_legend = legend
646        self.bottom_legend_container = legend_container
647        self.parts = parts
648        self.__selection_manager.set_heatmap_widgets(heatmap_widgets)
649
650    def legendVisible(self) -> bool:
651        """Is the colormap legend visible."""
652        return self.__legendVisible
653
654    def setLegendVisible(self, visible: bool) -> None:
655        """Set colormap legend visible state."""
656        self.__legendVisible = visible
657        legends = [
658            self.heatmap_colormap_legend,
659            self.bottom_legend_container
660        ]
661        apply_all(filter(None, legends), lambda item: item.setVisible(visible))
662
    #: The legend visibility as a Qt property (`legendVisible` is the getter,
    #: `setLegendVisible` the setter)
    legendVisible_ = Property(bool, legendVisible, setLegendVisible)
664
665    def setAspectRatioMode(self, mode: Qt.AspectRatioMode) -> None:
666        """
667        Set the scale aspect mode.
668
669        The widget will try to keep (hint) the scale ratio via the sizeHint
670        reimplementation.
671        """
672        if self.__aspectRatioMode != mode:
673            self.__aspectRatioMode = mode
674            for hm in chain.from_iterable(self.heatmap_widget_grid):
675                hm.setAspectMode(mode)
676            sp = self.sizePolicy()
677            sp.setHeightForWidth(mode != Qt.IgnoreAspectRatio)
678            self.setSizePolicy(sp)
679
680    def aspectRatioMode(self) -> Qt.AspectRatioMode:
681        return self.__aspectRatioMode
682
    #: The aspect ratio mode as a Qt property (`aspectRatioMode` is the
    #: getter, `setAspectRatioMode` the setter)
    aspectRatioMode_ = Property(
        Qt.AspectRatioMode, aspectRatioMode, setAspectRatioMode
    )
686
687    def setColumnLabelsPosition(self, position: Position) -> None:
688        self.__columnLabelPosition = position
689        top = bool(position & HeatmapGridWidget.PositionTop)
690        bottom = bool(position & HeatmapGridWidget.PositionBottom)
691        for w in self.col_annotation_widgets_top:
692            w.setVisible(top)
693            w.setMaximumHeight(FLT_MAX if top else 0)
694        for w in self.col_annotation_widgets_bottom:
695            w.setVisible(bottom)
696            w.setMaximumHeight(FLT_MAX if bottom else 0)
697
698    def columnLabelPosition(self) -> Position:
699        return self.__columnLabelPosition
700
701    def setColumnLabels(self, data: Optional[Sequence[str]]) -> None:
702        """Set the column labels to display. If None clear the row names."""
703        if data is not None:
704            data = np.asarray(data, dtype=object)
705        for top, bottom, part in zip(self.col_annotation_widgets_top,
706                                     self.col_annotation_widgets_bottom,
707                                     self.parts.columns):
708            if data is not None:
709                top.setItems(data[part.normalized_indices])
710                bottom.setItems(data[part.normalized_indices])
711            else:
712                top.clear()
713                bottom.clear()
714
715    def setRowLabels(self, data: Optional[Sequence[str]]):
716        """
717        Set the row labels to display. If None clear the row names.
718        """
719        if data is not None:
720            data = np.asarray(data, dtype=object)
721        for widget, part in zip(self.row_annotation_widgets, self.parts.rows):
722            if data is not None:
723                widget.setItems(data[part.normalized_indices])
724            else:
725                widget.clear()
726
727    def setRowLabelsVisible(self, visible: bool):
728        """Set row labels visibility"""
729        for widget in self.row_annotation_widgets:
730            widget.setVisible(visible)
731
732    def setRowSideColorAnnotations(
733            self, data: np.ndarray, colormap: ColorMap = None, name=""
734    ):
735        """
736        Set an optional row side color annotations.
737
738        Parameters
739        ----------
740        data: Optional[np.ndarray]
741            A sequence such that it is accepted by `colormap.apply`. If None
742            then the color annotations are cleared.
743        colormap: ColorMap
744        name: str
745            Name/title for the annotation column.
746        """
747        col = self.Col0 + 2 * len(self.parts.columns)
748        legend_layout = self.bottom_legend_container.layout()
749        legend_container = legend_layout.itemAt(1)
750        self.__setColorAnnotationsHelper(
751            data, colormap, name, self.right_side_colors, col, Qt.Vertical,
752            legend_container
753        )
754        legend_container.setVisible(True)
755
756    def setColumnSideColorAnnotations(
757            self, data: np.ndarray, colormap: ColorMap = None, name=""
758    ):
759        """
760        Set an optional column color annotations.
761
762        Parameters
763        ----------
764        data: Optional[np.ndarray]
765            A sequence such that it is accepted by `colormap.apply`. If None
766            then the color annotations are cleared.
767        colormap: ColorMap
768        name: str
769            Name/title for the annotation column.
770        """
771        row = self.TopAnnotationRow
772        legend_layout = self.bottom_legend_container.layout()
773        legend_container = legend_layout.itemAt(0)
774        self.__setColorAnnotationsHelper(
775            data, colormap, name, self.top_side_colors, row, Qt.Horizontal,
776            legend_container)
777        legend_container.setVisible(True)
778
    def __setColorAnnotationsHelper(
            self, data: np.ndarray, colormap: ColorMap, name: str,
            items: List[GraphicsPixmapWidget], position: int,
            orientation: Qt.Orientation, legend_container: QGraphicsWidget):
        """
        Shared implementation for the row/column side color annotations.

        Parameters
        ----------
        data:
            Values to color via `colormap.apply`; when None the annotation
            is cleared (items hidden, grid row/column collapsed, legend
            container hidden) and nothing else is done.
        colormap:
            Color map applied to `data`; also decides which legend type (if
            any) is created.
        name:
            Title shown in the name item and in the legend.
        items:
            One pixmap widget (or None) per part; paired with
            `self.parts.columns` for `Qt.Horizontal` and `self.parts.rows`
            otherwise.
        position:
            Grid row (horizontal) or grid column (vertical) that holds the
            annotation.
        orientation:
            `Qt.Horizontal` for column annotations, otherwise row
            annotations.
        legend_container:
            Widget whose layout is cleared and then receives the new legend.
        """
        layout = self.__layout
        # The layout item displaying the annotation's name/title.
        if orientation == Qt.Horizontal:
            nameitem = layout.itemAt(position, 0)
        else:
            nameitem = layout.itemAt(self.TopLabelsRow, position)
        # Font line spacing is used as the preferred annotation thickness.
        size = QFontMetrics(self.font()).lineSpacing()
        # Always drop any legend left from a previous call.
        layout_clear(legend_container.layout())

        def grid_set_maximum_size(position: int, size: float):
            # Constrain the grid row/column holding the annotation.
            if orientation == Qt.Horizontal:
                layout.setRowMaximumHeight(position, size)
            else:
                layout.setColumnMaximumWidth(position, size)

        def set_minimum_size(item: QGraphicsLayoutItem, size: float):
            # Minimum height (horizontal) or width (vertical) of `item`.
            if orientation == Qt.Horizontal:
                item.setMinimumHeight(size)
            else:
                item.setMinimumWidth(size)
            item.updateGeometry()

        def reset_minimum_size(item: QGraphicsLayoutItem):
            # -1 is passed to unset the explicit minimum.
            set_minimum_size(item, -1)

        def set_hidden(item: GraphicsPixmapWidget):
            item.setVisible(False)
            reset_minimum_size(item,)

        def set_visible(item: GraphicsPixmapWidget):
            item.setVisible(True)
            set_minimum_size(item, 10)

        def set_preferred_size(item, size):
            # Preferred height (horizontal) or width (vertical) of `item`.
            if orientation == Qt.Horizontal:
                item.setPreferredHeight(size)
            else:
                item.setPreferredWidth(size)
            item.updateGeometry()

        if data is None:
            # Clear: hide everything and collapse the grid row/column.
            apply_all(filter(None, items), set_hidden)
            grid_set_maximum_size(position, 0)

            nameitem.item.setVisible(False)
            nameitem.updateGeometry()
            legend_container.setVisible(False)
            return
        else:
            apply_all(filter(None, items), set_visible)
            # FLT_MAX is defined elsewhere in this file; presumably it lifts
            # the maximum size limit again -- NOTE(review): confirm.
            grid_set_maximum_size(position, FLT_MAX)
            legend_container.setVisible(True)

        if orientation == Qt.Horizontal:
            parts = self.parts.columns
        else:
            parts = self.parts.rows
        for p, item in zip(parts, items):
            if item is not None:
                subset = data[p.normalized_indices]
                subset = colormap.apply(subset)
                # One pixel per value: an (N, 1, C) strip for row
                # annotations, reshaped to (1, N, C) for column annotations.
                rgbdata = subset.reshape((-1, 1, subset.shape[-1]))
                if orientation == Qt.Horizontal:
                    rgbdata = rgbdata.reshape((1, -1, rgbdata.shape[-1]))
                img = qimage_from_array(rgbdata)
                item.setPixmap(img)
                item.setVisible(True)
                set_preferred_size(item, size)

        nameitem.item.setText(name)
        nameitem.item.setVisible(True)
        set_preferred_size(nameitem, size)

        # Create the legend matching the color map type; other color map
        # types get no legend.
        container = legend_container.layout()
        if isinstance(colormap, CategoricalColorMap):
            legend = CategoricalColorLegend(
                colormap, title=name,
                orientation=Qt.Horizontal,
                sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Maximum),
                visible=self.__legendVisible,
            )
            container.addItem(legend)
        elif isinstance(colormap, GradientColorMap):
            legend = GradientLegendWidget(
                *colormap.span, colormap, title=name,
                sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Maximum),
            )
            legend.setMinimumWidth(100)
            container.addItem(legend)
871
872    def headerGeometry(self) -> QRectF:
873        """Return the 'header' geometry.
874
875        This is the top part of the widget spanning the top dendrogram,
876        column labels... (can be empty).
877        """
878        layout = self.__layout
879        geom1 = grid_layout_row_geometry(layout, self.DendrogramRow)
880        geom2 = grid_layout_row_geometry(layout, self.TopLabelsRow)
881        first = grid_layout_row_geometry(layout, self.TopLabelsRow + 1)
882        geom = geom1.united(geom2)
883        if geom.isValid():
884            if first.isValid():
885                geom.setBottom(geom.bottom() / 2.0 + first.top() / 2.0)
886            return QRectF(self.geometry().topLeft(), geom.bottomRight())
887        else:
888            return QRectF()
889
890    def footerGeometry(self) -> QRectF:
891        """Return the 'footer' geometry.
892
893        This is the bottom part of the widget spanning the bottom column labels
894        when applicable (can be empty).
895        """
896        layout = self.__layout
897        row = self.Row0 + len(self.heatmap_widget_grid)
898        geom = grid_layout_row_geometry(layout, row)
899        nextolast = grid_layout_row_geometry(layout, row - 1)
900        if geom.isValid():
901            if nextolast.isValid():
902                geom.setTop(geom.top() / 2 + nextolast.bottom() / 2)
903            return QRectF(geom.topLeft(), self.geometry().bottomRight())
904        else:
905            return QRectF()
906
907    def setColorMap(self, colormap: GradientColorMap) -> None:
908        self.__colormap = colormap
909        for hm in chain.from_iterable(self.heatmap_widget_grid):
910            hm.setColorMap(colormap)
911        for item in self.__avgitems():
912            item.setColorMap(colormap)
913        for ch in self.childItems():
914            if isinstance(ch, GradientLegendWidget):
915                ch.setColorMap(colormap)
916
    def colorMap(self) -> ColorMap:
        """Return the color map last set with `setColorMap`."""
        return self.__colormap
919
920    def __avgitems(self):
921        if self.parts is None:
922            return
923        N = len(self.parts.rows)
924        M = len(self.parts.columns)
925        layout = self.__layout
926        for i in range(N):
927            for j in range(M):
928                item = layout.itemAt(self.Row0 + i, self.Col0 + 2 * j)
929                if isinstance(item, GraphicsHeatmapWidget):
930                    yield item
931
932    def setShowAverages(self, visible):
933        self.__averagesVisible = visible
934        for item in self.__avgitems():
935            item.setVisible(visible)
936            item.setPreferredWidth(0 if not visible else 10)
937
938    def event(self, event):
939        # type: (QEvent) -> bool
940        rval = super().event(event)
941        if event.type() == QEvent.LayoutRequest and self.layout() is not None:
942            self.__update_selection_geometry()
943        return rval
944
    def setGeometry(self, rect: QRectF) -> None:
        """Reimplemented to also refresh the selection rectangles."""
        super().setGeometry(rect)
        self.__update_selection_geometry()
948
949    def __update_selection_geometry(self):
950        scene = self.scene()
951        self.__selection_manager.update_selection_rects()
952        rects = self.__selection_manager.selection_rects
953        palette = self.palette()
954        pen = QPen(palette.color(QPalette.Foreground), 2)
955        pen.setCosmetic(True)
956        brushcolor = QColor(palette.color(QPalette.Highlight))
957        brushcolor.setAlpha(50)
958        selection_rects = []
959        for rect, item in zip_longest(rects, self.selection_rects):
960            assert rect is not None or item is not None
961            if item is None:
962                item = QGraphicsRectItem(rect, None)
963                item.setPen(pen)
964                item.setBrush(brushcolor)
965                scene.addItem(item)
966                selection_rects.append(item)
967            elif rect is not None:
968                item.setRect(rect)
969                item.setPen(pen)
970                item.setBrush(brushcolor)
971                selection_rects.append(item)
972            else:
973                scene.removeItem(item)
974        self.selection_rects = selection_rects
975
976    def __select_by_cluster(self, item, dendrogramindex):
977        # User clicked on a dendrogram node.
978        # Select all rows corresponding to the cluster item.
979        node = item.node
980        try:
981            hm = self.heatmap_widget_grid[dendrogramindex][0]
982        except IndexError:
983            pass
984        else:
985            key = QApplication.keyboardModifiers()
986            clear = not (key & ((Qt.ControlModifier | Qt.ShiftModifier |
987                                 Qt.AltModifier)))
988            remove = (key & (Qt.ControlModifier | Qt.AltModifier))
989            append = (key & Qt.ControlModifier)
990            self.__selection_manager.selection_add(
991                node.value.first, node.value.last - 1, hm,
992                clear=clear, remove=remove, append=append)
993
994    def heatmapAtPos(self, pos: QPointF) -> Optional['GraphicsHeatmapWidget']:
995        for hw in chain.from_iterable(self.heatmap_widget_grid):
996            if hw.contains(hw.mapFromItem(self, pos)):
997                return hw
998        return None
999
1000    __selecting = False
1001
1002    def mousePressEvent(self, event: QGraphicsSceneMouseEvent) -> None:
1003        pos = event.pos()
1004        heatmap = self.heatmapAtPos(pos)
1005        if heatmap and event.button() & Qt.LeftButton:
1006            row, _ = heatmap.heatmapCellAt(heatmap.mapFromScene(event.scenePos()))
1007            if row != -1:
1008                self.__selection_manager.selection_start(heatmap, event)
1009                self.__selecting = True
1010                event.setAccepted(True)
1011                return
1012        super().mousePressEvent(event)
1013
1014    def mouseMoveEvent(self, event: QGraphicsSceneMouseEvent) -> None:
1015        pos = event.pos()
1016        heatmap = self.heatmapAtPos(pos)
1017        if heatmap and event.buttons() & Qt.LeftButton and self.__selecting:
1018            row, _ = heatmap.heatmapCellAt(heatmap.mapFromScene(pos))
1019            if row != -1:
1020                self.__selection_manager.selection_update(heatmap, event)
1021                event.setAccepted(True)
1022                return
1023        super().mouseMoveEvent(event)
1024
1025    def mouseReleaseEvent(self, event: QGraphicsSceneMouseEvent) -> None:
1026        pos = event.pos()
1027        if event.button() == Qt.LeftButton and self.__selecting:
1028            self.__selection_manager.selection_finish(
1029                self.heatmapAtPos(pos), event)
1030            self.__selecting = False
1031        super().mouseReleaseEvent(event)
1032
1033    def selectedRows(self) -> Sequence[int]:
1034        """Return the current selected rows."""
1035        if self.parts is None:
1036            return []
1037        visual_indices = self.__selection_manager.selections
1038        indices = np.hstack([r.normalized_indices for r in self.parts.rows])
1039        return indices[visual_indices].tolist()
1040
1041    def selectRows(self, selection: Sequence[int]):
1042        """Select the specified rows. Previous selection is cleared."""
1043        if self.parts is not None:
1044            indices = np.hstack([r.normalized_indices for r in self.parts.rows])
1045        else:
1046            indices = []
1047        condition = np.in1d(indices, selection)
1048        visual_indices = np.flatnonzero(condition)
1049        self.__selection_manager.select_rows(visual_indices.tolist())
1050
1051
class GraphicsHeatmapWidget(QGraphicsWidget):
    """
    A graphics widget displaying a 2d array as a color mapped image.

    The data is rendered (one pixel per cell) into a pixmap shown by an
    internal, content scaling `GraphicsPixmapWidget`.
    """
    __aspectMode = Qt.KeepAspectRatio

    def __init__(
            self, parent=None,
            data: Optional[np.ndarray] = None,
            span: Tuple[float, float] = (0., 1.),
            colormap: Optional[ColorMap] = None,
            aspectRatioMode=Qt.KeepAspectRatio,
            **kwargs
    ) -> None:
        """
        Parameters
        ----------
        parent:
            Optional parent item; applied after the widget is set up.
        data:
            Initial heatmap data (see `setHeatmapData`).
        span:
            The (low, high) levels passed to the color map.
        colormap:
            The color map; a `GradientColorMap` over the default continuous
            palette is used when None.
        aspectRatioMode:
            Aspect ratio mode used by `sizeHint`.
        """
        super().__init__(None, **kwargs)
        self.setAcceptHoverEvents(True)
        self.__levels = span
        if colormap is None:
            colormap = GradientColorMap(DefaultContinuousPalette.lookup_table())
        self.__colormap = colormap
        self.__data: Optional[np.ndarray] = None
        self.__pixmap = QPixmap()
        self.__aspectMode = aspectRatioMode

        layout = QGraphicsLinearLayout(Qt.Horizontal)
        layout.setContentsMargins(0, 0, 0, 0)
        self.__pixmapItem = GraphicsPixmapWidget(
            self, scaleContents=True, aspectMode=Qt.IgnoreAspectRatio
        )
        sp = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        sp.setHeightForWidth(True)
        self.__pixmapItem.setSizePolicy(sp)
        layout.addItem(self.__pixmapItem)
        self.setLayout(layout)
        sp = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        sp.setHeightForWidth(True)
        self.setSizePolicy(sp)
        self.setHeatmapData(data)

        if parent is not None:
            self.setParentItem(parent)

    def setAspectMode(self, mode: Qt.AspectRatioMode) -> None:
        """
        Set the aspect ratio mode; height-for-width is enabled unless the
        mode is `Qt.IgnoreAspectRatio`.
        """
        if self.__aspectMode != mode:
            self.__aspectMode = mode
            sp = self.sizePolicy()
            sp.setHeightForWidth(mode != Qt.IgnoreAspectRatio)
            self.setSizePolicy(sp)
            self.updateGeometry()

    def aspectMode(self) -> Qt.AspectRatioMode:
        """Return the current aspect ratio mode."""
        return self.__aspectMode

    def sizeHint(self, which: Qt.SizeHint, constraint=QSizeF(-1, -1)) -> QSizeF:
        """
        Reimplemented; for a width constrained preferred size the
        unconstrained hint is scaled to the constraint's width according to
        the aspect mode.
        """
        if which == Qt.PreferredSize and constraint.width() >= 0:
            sh = super().sizeHint(which)
            return scaled(sh, QSizeF(constraint.width(), -1), self.__aspectMode)
        return super().sizeHint(which, constraint)

    def clear(self):
        """Clear/reset the widget."""
        self.__data = None
        self.__pixmap = QPixmap()
        self.__pixmapItem.setPixmap(self.__pixmap)
        self.updateGeometry()

    def setHeatmapData(self, data):
        """Set the heatmap data for display.

        Nothing is done if `data` is the very same object as the current
        data.
        """
        if self.__data is not data:
            self.clear()
            self.__data = data
            self.__updatePixmap()
            self.update()

    def heatmapData(self) -> Optional[np.ndarray]:
        """Return a read-only view of the heatmap data (None if not set)."""
        if self.__data is not None:
            v = self.__data.view()
            v.flags.writeable = False
            return v
        else:
            return None

    def pixmap(self) -> QPixmap:
        """Return the pixmap displayed by the internal pixmap item."""
        return self.__pixmapItem.pixmap()

    def setLevels(self, levels: Tuple[float, float]) -> None:
        """Set the (low, high) color levels and re-render on change."""
        if levels != self.__levels:
            self.__levels = levels
            self.__updatePixmap()
            self.update()

    def setColorMap(self, colormap: ColorMap):
        """Set the color map and re-render the pixmap."""
        self.__colormap = colormap
        self.__updatePixmap()

    def colorMap(self,) -> ColorMap:
        """Return the current color map."""
        return self.__colormap

    def __updatePixmap(self):
        # Re-render the data with the color map spanning the current levels.
        if self.__data is not None:
            ll, lh = self.__levels
            cmap = self.__colormap.replace(span=(ll, lh))
            rgb = cmap.apply(self.__data)
            # paint missing (NaN) cells in a fixed gray
            rgb[np.isnan(self.__data)] = (100, 100, 100)
            qimage = qimage_from_array(rgb)
            self.__pixmap = QPixmap.fromImage(qimage)
        else:
            self.__pixmap = QPixmap()

        self.__pixmapItem.setPixmap(self.__pixmap)
        self.__updateSizeHints()

    def changeEvent(self, event: QEvent) -> None:
        """Reimplemented to update size hints on font change."""
        super().changeEvent(event)
        if event.type() == QEvent.FontChange:
            self.__updateSizeHints()

    def __updateSizeHints(self):
        # Minimum size is the pixmap size (one pixel per cell); preferred
        # size is the pixmap size scaled by the font's line spacing.
        hmsize = QSizeF(self.__pixmap.size())
        size = QFontMetrics(self.font()).lineSpacing()
        self.__pixmapItem.setMinimumSize(hmsize)
        self.__pixmapItem.setPreferredSize(hmsize * size)

    def heatmapCellAt(self, pos: QPointF) -> Tuple[int, int]:
        """Return the cell row, column from `pos` in local coordinates.

        (-1, -1) is returned if there is no pixmap or `pos` is outside of
        the pixmap item's geometry.
        """
        if self.__pixmap.isNull() or not \
                self.__pixmapItem.geometry().contains(pos):
            return -1, -1
        assert self.__data is not None
        item_clicked = self.__pixmapItem
        pos = self.mapToItem(item_clicked, pos)
        size = self.__pixmapItem.size()

        x, y = pos.x(), pos.y()

        N, M = self.__data.shape
        fx = x / size.width()
        fy = y / size.height()
        # clamp so positions on the far edge map to the last row/column
        i = min(int(math.floor(fy * N)), N - 1)
        j = min(int(math.floor(fx * M)), M - 1)
        return i, j

    def heatmapCellRect(self, row: int, column: int) -> QRectF:
        """Return a rectangle in local coordinates containing the cell
        at `row` and `column`.

        An empty QRectF is returned if `row` or `column` is out of range.
        """
        size = self.__pixmap.size()
        # both the row AND the column must address a valid cell
        if not (0 <= column < size.width() and 0 <= row < size.height()):
            return QRectF()

        topleft = QPointF(column, row)
        bottomright = QPointF(column + 1, row + 1)
        t = self.__pixmapItem.pixmapTransform()
        rect = t.mapRect(QRectF(topleft, bottomright))
        # QRectF.translated returns a new rectangle (it does not modify
        # `rect` in place), so its result must be the one returned.
        return rect.translated(self.__pixmapItem.pos())

    def rowRect(self, row):
        """
        Return a QRectF in local coordinates containing the entire row.
        """
        rect = self.heatmapCellRect(row, 0)
        rect.setLeft(0)
        rect.setRight(self.size().width())
        return rect

    def heatmapCellToolTip(self, row, column):
        """Return the tool tip text for the cell at `row`, `column`."""
        return "{}, {}: {:g}".format(row, column, self.__data[row, column])

    def hoverMoveEvent(self, event):
        """Reimplemented to set the tool tip for the hovered cell."""
        pos = event.pos()
        row, column = self.heatmapCellAt(pos)
        if row != -1:
            tooltip = self.heatmapCellToolTip(row, column)
            self.setToolTip(tooltip)
        return super().hoverMoveEvent(event)
1226
1227
def remove_item(item: QGraphicsItem) -> None:
    """
    Remove `item` from its scene; if it is not in a scene clear its parent
    item instead.
    """
    owner = item.scene()
    if owner is None:
        item.setParentItem(None)
        return
    owner.removeItem(item)
1234
1235
class _GradientLegendAxisItem(pg.AxisItem):
    """Axis item used by the gradient legend."""

    def boundingRect(self):
        """
        Reimplemented; horizontal ("top"/"bottom") axes get a bounding rect
        widened symmetrically to leave room for overflowing tick text.
        """
        rect = super().boundingRect()
        if self.orientation not in ["top", "bottom"]:
            return rect
        # pg.AxisItem only has a fixed constant adjustment for tick text
        # over-flow, so extend the rect in the horizontal direction.
        tick_font = self.style.get("tickFont")
        if tick_font is None:
            tick_font = self.font()
        metrics = QFontMetrics(tick_font)
        # bad, should use _tickValues
        half_label = metrics.horizontalAdvance('0.0000000') / 2
        geometry_width = self.geometry().size().width()
        wanted_width = max(geometry_width + 2 * half_label, rect.width())
        if rect.width() < wanted_width:
            margin = (wanted_width - rect.width()) / 2
            rect = rect.adjusted(-margin, 0, margin, 0)
        return rect

    def showEvent(self, event):
        """Reimplemented to refresh the axis' fixed size when shown."""
        super().showEvent(event)
        # When hidden, AxisItem resizes to 0 width/height and it does not
        # update when shown implicitly (i.e. a parent becomes visible).
        # There is no public interface to explicitly recalc the fixed
        # sizes; showLabel(False) should update the size without actually
        # changing anything else.
        self.showLabel(False)
1262
1263
class GradientLegendWidget(QGraphicsWidget):
    """
    A legend for a gradient color map: an optional title, a tick axis and
    a color gradient bar stacked vertically.
    """

    def __init__(
            self, low, high, colormap: GradientColorMap, parent=None, title="",
            **kwargs
    ):
        """
        Parameters
        ----------
        low, high:
            The value range displayed by the legend.
        colormap:
            The color map rendered in the gradient bar.
        parent:
            Optional parent item; applied after the widget is set up.
        title:
            Optional title displayed above the axis.
        """
        default_policy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        kwargs.setdefault("sizePolicy", default_policy)
        super().__init__(None, **kwargs)
        self.low = low
        self.high = high
        self.colormap = colormap
        self.title = title

        column = QGraphicsLinearLayout(Qt.Vertical)
        column.setContentsMargins(0, 0, 0, 0)
        column.setSpacing(0)
        self.setLayout(column)
        if title:
            text = QGraphicsSimpleTextItem(title, self)
            column.addItem(SimpleLayoutItem(
                text, parent=column, anchor=(0.5, 1.), anchorItem=(0.5, 1.0)
            ))
        self.__axis = _GradientLegendAxisItem(
            orientation="top", maxTickLength=3)
        self.__axis.setRange(low, high)
        column.addItem(self.__axis)
        self.__axis.setPen(QPen(self.palette().color(QPalette.Text)))
        self.__pixitem = GraphicsPixmapWidget(
            parent=self, scaleContents=True, aspectMode=Qt.IgnoreAspectRatio
        )
        self.__pixitem.setSizePolicy(QSizePolicy.Ignored, QSizePolicy.Preferred)
        self.__pixitem.setMinimumHeight(12)
        column.addItem(self.__pixitem)
        self.__update()

        if parent is not None:
            self.setParentItem(parent)

    def setRange(self, low, high):
        """Set the displayed value range; the legend is redrawn on change."""
        if (self.low, self.high) != (low, high):
            self.low = low
            self.high = high
            self.__update()

    def setColorMap(self, colormap: ColorMap) -> None:
        """Set the color map"""
        self.colormap = colormap
        self.__update()

    def colorMap(self) -> ColorMap:
        """Return the current color map."""
        return self.colormap

    def __update(self):
        # Redraw the gradient bar and reset the axis range and ticks.
        low, high = self.low, self.high
        gradient = np.linspace(low, high, num=1000).reshape((1, -1))
        spanned = self.colormap.replace(span=(low, high))
        image = qimage_from_array(spanned.apply(gradient))
        self.__pixitem.setPixmap(QPixmap.fromImage(image))

        # Ticks at both ends, plus the color map's center when it lies
        # strictly inside the range.
        tick_values = [low, high]
        center = self.colormap.center
        if center is not None and low < center < high:
            tick_values.insert(1, center)
        ticks = [(value, "{:.6g}".format(value)) for value in tick_values]
        self.__axis.setRange(low, high)
        self.__axis.setTicks([ticks])

        self.updateGeometry()

    def changeEvent(self, event: QEvent) -> None:
        """Reimplemented to update the axis pen on palette change."""
        if event.type() == QEvent.PaletteChange:
            self.__axis.setPen(QPen(self.palette().color(QPalette.Text)))
        super().changeEvent(event)
1342
1343
class CategoricalColorLegend(QGraphicsWidget):
    """
    A legend for a categorical color map: an optional title above a flow
    layout of (color swatch, name) pairs.
    """

    def __init__(
            self, colormap: CategoricalColorMap, title="",
            orientation=Qt.Vertical, parent=None, **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        colormap:
            The categorical color map; supplies the color table and names.
        title:
            Optional title displayed above the legend items.
        orientation:
            Stored and returned by `orientation()`.
        parent:
            Optional parent item; applied after the widget is set up.
        """
        self.__colormap = colormap
        self.__title = title
        self.__names = colormap.names
        self.__layout = QGraphicsLinearLayout(Qt.Vertical)
        self.__flow = GraphicsFlowLayout()
        self.__layout.addItem(self.__flow)
        self.__flow.setHorizontalSpacing(4)
        self.__flow.setVerticalSpacing(4)
        self.__orientation = orientation
        kwargs.setdefault(
            "sizePolicy", QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Maximum)
        )
        super().__init__(None, **kwargs)
        self.setLayout(self.__layout)
        self._setup()

        if parent is not None:
            # `parent` is a graphics item parent (as in the other legend
            # and heatmap widgets in this file), not a QObject parent.
            self.setParentItem(parent)

    def setOrientation(self, orientation):
        """Set the orientation and rebuild the legend on change."""
        if self.__orientation != orientation:
            # Record the new value; otherwise `orientation()` never changes
            # and every call with it triggers a rebuild.
            self.__orientation = orientation
            self._clear()
            self._setup()

    def orientation(self):
        """Return the current orientation."""
        return self.__orientation

    def _clear(self):
        """Remove all legend items (and the title) from the layouts."""
        # The flow layout holds nested (swatch, text) layouts, so the
        # wrapped graphics items must be collected recursively *before*
        # the flow layout is cleared.
        items = list(layout_items_recursive(self.__flow))
        layout_clear(self.__flow)
        for item in items:
            if isinstance(item, SimpleLayoutItem):
                remove_item(item.item)
        # remove 'title' item if present
        items = [item for item in layout_items(self.__layout)
                 if item is not self.__flow]
        for item in items:
            self.__layout.removeItem(item)
            if isinstance(item, SimpleLayoutItem):
                remove_item(item.item)

    def _setup(self):
        """Populate the (empty) layouts from the color map and title."""
        colors = self.__colormap.colortable
        names = self.__colormap.names
        title = self.__title
        layout = self.__layout
        flow = self.__flow
        assert flow.count() == 0
        font = self.font()
        fm = QFontMetrics(font)
        # color swatches are squares as wide as an "X" in the current font
        size = fm.horizontalAdvance("X")
        headeritem = None
        if title:
            headeritem = QGraphicsSimpleTextItem(title)
            headeritem.setFont(font)

        def centered(item):
            return SimpleLayoutItem(item, anchor=(0.5, 0.5), anchorItem=(0.5, 0.5))

        def legend_item_pair(color: QColor, size: float, text: str):
            # a horizontal [swatch, text] layout for a single category
            coloritem = QGraphicsRectItem(0, 0, size, size)
            coloritem.setBrush(color)
            textitem = QGraphicsSimpleTextItem()
            textitem.setFont(font)
            textitem.setText(text)
            pair_layout = QGraphicsLinearLayout(Qt.Horizontal)
            pair_layout.setSpacing(2)
            pair_layout.addItem(centered(coloritem))
            pair_layout.addItem(SimpleLayoutItem(textitem))
            return coloritem, textitem, pair_layout

        items = [legend_item_pair(QColor(*color), size, name)
                 for color, name in zip(colors, names)]

        for _, _, pair_layout in items:
            flow.addItem(pair_layout)

        if headeritem:
            layout.insertItem(0, centered(headeritem))

    def changeEvent(self, event: QEvent) -> None:
        """Reimplemented to propagate font changes to the legend items."""
        if event.type() == QEvent.FontChange:
            self._updateFont(self.font())
        super().changeEvent(event)

    def _updateFont(self, font):
        """Apply `font` to all text items and resize the color swatches."""
        w = QFontMetrics(font).horizontalAdvance("X")
        for item in filter(
                lambda item: isinstance(item, SimpleLayoutItem),
                layout_items_recursive(self.__layout)
        ):
            if isinstance(item.item, QGraphicsSimpleTextItem):
                item.item.setFont(font)
            elif isinstance(item.item, QGraphicsRectItem):
                item.item.setRect(QRectF(0, 0, w, w))
            item.updateGeometry()
1446
1447
def layout_items(layout: QGraphicsLayout) -> Iterable[QGraphicsLayoutItem]:
    """Yield the (non None) items directly contained in `layout`."""
    for index in range(layout.count()):
        entry = layout.itemAt(index)
        if entry is not None:
            yield entry
1452
1453
def layout_items_recursive(layout: QGraphicsLayout):
    """
    Yield all non-layout items contained in `layout`, descending into
    nested layouts (the nested layouts themselves are not yielded).
    """
    for index in range(layout.count()):
        entry = layout.itemAt(index)
        if entry is None:
            continue
        if not entry.isLayout():
            yield entry
            continue
        assert isinstance(entry, QGraphicsLayout)
        yield from layout_items_recursive(entry)
1462
1463
def layout_clear(layout: QGraphicsLayout) -> None:
    """
    Remove all items from `layout` (last to first); items that have a
    graphics item also get it removed via `remove_item`.
    """
    index = layout.count() - 1
    while index >= 0:
        entry = layout.itemAt(index)
        layout.removeAt(index)
        graphics_item = None if entry is None else entry.graphicsItem()
        if graphics_item is not None:
            remove_item(graphics_item)
        index -= 1
1470
1471
1472class SelectionManager(QObject):
1473    """
1474    Selection manager for heatmap rows
1475    """
1476    selection_changed = Signal()
1477    selection_finished = Signal()
1478
    def __init__(self, parent=None, **kwargs):
        """Create a selection manager with no heatmaps and no selection."""
        super().__init__(parent, **kwargs)
        # The selected rows (global row indices, see `select_rows`)
        self.selections = []
        # (start, end) row ranges making up the current selection
        self.selection_ranges = []
        # The ranges as they were when a Ctrl-modified selection started
        # (see `selection_start`)
        self.selection_ranges_temp = []
        self.selection_rects = []
        # Heatmap widgets as set by `set_heatmap_widgets` (transposed)
        self.heatmaps = []
        # Maps a heatmap widget to its (start, end) global row range
        self._heatmap_ranges: Dict[GraphicsHeatmapWidget, Tuple[int, int]] = {}
        # The (global) row at which the current selection started
        self._start_row = 0
1488
1489    def clear(self):
1490        self.remove_rows(self.selection)
1491
1492    def set_heatmap_widgets(self, widgets):
1493        # type: (Sequence[Sequence[GraphicsHeatmapWidget]] )-> None
1494        self.remove_rows(self.selections)
1495        self.heatmaps = list(zip(*widgets))
1496
1497        # Compute row ranges for all heatmaps
1498        self._heatmap_ranges = {}
1499        for group in zip(*widgets):
1500            start = end = 0
1501            for heatmap in group:
1502                end += heatmap.heatmapData().shape[0]
1503                self._heatmap_ranges[heatmap] = (start, end)
1504                start = end
1505
1506    def select_rows(self, rows, heatmap=None, clear=True):
1507        """Add `rows` to selection. If `heatmap` is provided the rows
1508        are mapped from the local indices to global heatmap indices. If `clear`
1509        then remove previous rows.
1510        """
1511        if heatmap is not None:
1512            start, _ = self._heatmap_ranges[heatmap]
1513            rows = [start + r for r in rows]
1514
1515        old_selection = list(self.selections)
1516        if clear:
1517            self.selections = rows
1518        else:
1519            self.selections = sorted(set(self.selections + rows))
1520
1521        if self.selections != old_selection:
1522            self.update_selection_rects()
1523            self.selection_changed.emit()
1524
1525    def remove_rows(self, rows):
1526        """Remove `rows` from the selection.
1527        """
1528        old_selection = list(self.selections)
1529        self.selections = sorted(set(self.selections) - set(rows))
1530        if old_selection != self.selections:
1531            self.update_selection_rects()
1532            self.selection_changed.emit()
1533
1534    def combined_ranges(self, ranges):
1535        combined_ranges = set()
1536        for start, end in ranges:
1537            if start <= end:
1538                rng = range(start, end + 1)
1539            else:
1540                rng = range(start, end - 1, -1)
1541            combined_ranges.update(rng)
1542        return sorted(combined_ranges)
1543
1544    def selection_start(self, heatmap_widget, event):
1545        """ Selection  started by `heatmap_widget` due to `event`.
1546        """
1547        pos = heatmap_widget.mapFromScene(event.scenePos())
1548        row, _ = heatmap_widget.heatmapCellAt(pos)
1549
1550        start, _ = self._heatmap_ranges[heatmap_widget]
1551        row = start + row
1552        self._start_row = row
1553        range = (row, row)
1554        self.selection_ranges_temp = []
1555        if event.modifiers() & Qt.ControlModifier:
1556            self.selection_ranges_temp = self.selection_ranges
1557            self.selection_ranges = self.remove_range(
1558                self.selection_ranges, row, row, append=True)
1559        elif event.modifiers() & Qt.ShiftModifier:
1560            self.selection_ranges.append(range)
1561        elif event.modifiers() & Qt.AltModifier:
1562            self.selection_ranges = self.remove_range(
1563                self.selection_ranges, row, row, append=False)
1564        else:
1565            self.selection_ranges = [range]
1566        self.select_rows(self.combined_ranges(self.selection_ranges))
1567
1568    def selection_update(self, heatmap_widget, event):
1569        """ Selection updated by `heatmap_widget due to `event` (mouse drag).
1570        """
1571        pos = heatmap_widget.mapFromScene(event.scenePos())
1572        row, _ = heatmap_widget.heatmapCellAt(pos)
1573        if row < 0:
1574            return
1575
1576        start, _ = self._heatmap_ranges[heatmap_widget]
1577        row = start + row
1578        if event.modifiers() & Qt.ControlModifier:
1579            self.selection_ranges = self.remove_range(
1580                self.selection_ranges_temp, self._start_row, row, append=True)
1581        elif event.modifiers() & Qt.AltModifier:
1582            self.selection_ranges = self.remove_range(
1583                self.selection_ranges, self._start_row, row, append=False)
1584        else:
1585            if self.selection_ranges:
1586                self.selection_ranges[-1] = (self._start_row, row)
1587            else:
1588                self.selection_ranges = [(row, row)]
1589
1590        self.select_rows(self.combined_ranges(self.selection_ranges))
1591
1592    def selection_finish(self, heatmap_widget, event):
1593        """ Selection finished by `heatmap_widget due to `event`.
1594        """
1595        if heatmap_widget is not None:
1596            pos = heatmap_widget.mapFromScene(event.scenePos())
1597            row, _ = heatmap_widget.heatmapCellAt(pos)
1598            start, _ = self._heatmap_ranges[heatmap_widget]
1599            row = start + row
1600            if event.modifiers() & Qt.ControlModifier:
1601                pass
1602            elif event.modifiers() & Qt.AltModifier:
1603                self.selection_ranges = self.remove_range(
1604                    self.selection_ranges, self._start_row, row, append=False)
1605            else:
1606                if len(self.selection_ranges) > 0:
1607                    self.selection_ranges[-1] = (self._start_row, row)
1608        self.select_rows(self.combined_ranges(self.selection_ranges))
1609        self.selection_finished.emit()
1610
1611    def selection_add(self, start, end, heatmap=None, clear=True,
1612                      remove=False, append=False):
1613        """ Add/remove a selection range from `start` to `end`.
1614        """
1615        if heatmap is not None:
1616            _start, _ = self._heatmap_ranges[heatmap]
1617            start = _start + start
1618            end = _start + end
1619
1620        if clear:
1621            self.selection_ranges = []
1622        if remove:
1623            self.selection_ranges = self.remove_range(
1624                self.selection_ranges, start, end, append=append)
1625        else:
1626            self.selection_ranges.append((start, end))
1627        self.select_rows(self.combined_ranges(self.selection_ranges))
1628        self.selection_finished.emit()
1629
1630    def remove_range(self, ranges, start, end, append=False):
1631        if start > end:
1632            start, end = end, start
1633        comb_ranges = [i for i in self.combined_ranges(ranges)
1634                       if i > end or i < start]
1635        if append:
1636            comb_ranges += [i for i in range(start, end + 1)
1637                            if i not in self.combined_ranges(ranges)]
1638            comb_ranges = sorted(comb_ranges)
1639        return self.combined_to_ranges(comb_ranges)
1640
1641    def combined_to_ranges(self, comb_ranges):
1642        ranges = []
1643        if len(comb_ranges) > 0:
1644            i, start, end = 0, comb_ranges[0], comb_ranges[0]
1645            for val in comb_ranges[1:]:
1646                i += 1
1647                if start + i < val:
1648                    ranges.append((start, end))
1649                    i, start = 0, val
1650                end = val
1651            ranges.append((start, end))
1652        return ranges
1653
1654    def update_selection_rects(self):
1655        """ Update the selection rects.
1656        """
1657        def group_selections(selections):
1658            """Group selections along with heatmaps.
1659            """
1660            rows2hm = self.rows_to_heatmaps()
1661            selections = iter(selections)
1662            try:
1663                start = end = next(selections)
1664            except StopIteration:
1665                return
1666            end_heatmaps = rows2hm[end]
1667            try:
1668                while True:
1669                    new_end = next(selections)
1670                    new_end_heatmaps = rows2hm[new_end]
1671                    if new_end > end + 1 or new_end_heatmaps != end_heatmaps:
1672                        yield start, end, end_heatmaps
1673                        start = end = new_end
1674                        end_heatmaps = new_end_heatmaps
1675                    else:
1676                        end = new_end
1677
1678            except StopIteration:
1679                yield start, end, end_heatmaps
1680
1681        def selection_rect(start, end, heatmaps):
1682            rect = QRectF()
1683            for heatmap in heatmaps:
1684                h_start, _ = self._heatmap_ranges[heatmap]
1685                rect |= heatmap.mapToScene(heatmap.rowRect(start - h_start)).boundingRect()
1686                rect |= heatmap.mapToScene(heatmap.rowRect(end - h_start)).boundingRect()
1687            return rect
1688
1689        self.selection_rects = []
1690        for start, end, heatmaps in group_selections(self.selections):
1691            rect = selection_rect(start, end, heatmaps)
1692            self.selection_rects.append(rect)
1693
1694    def rows_to_heatmaps(self):
1695        heatmap_groups = zip(*self.heatmaps)
1696        rows2hm = {}
1697        for heatmaps in heatmap_groups:
1698            hm = heatmaps[0]
1699            start, end = self._heatmap_ranges[hm]
1700            rows2hm.update(dict.fromkeys(range(start, end), heatmaps))
1701        return rows2hm
1702
1703
# Module-level aliases for the `Parts`, `RowItem` and `ColumnItem` attributes
# of `HeatmapGridWidget`.
Parts = HeatmapGridWidget.Parts
RowItem = HeatmapGridWidget.RowItem
ColumnItem = HeatmapGridWidget.ColumnItem
1707