1import math 2import enum 3from itertools import chain, zip_longest 4 5from typing import ( 6 Optional, List, NamedTuple, Sequence, Tuple, Dict, Union, Iterable 7) 8 9import numpy as np 10 11from AnyQt.QtCore import ( 12 Signal, Property, Qt, QRectF, QSizeF, QEvent, QPointF, QObject 13) 14from AnyQt.QtGui import QPixmap, QPalette, QPen, QColor, QFontMetrics 15from AnyQt.QtWidgets import ( 16 QGraphicsWidget, QSizePolicy, QGraphicsGridLayout, QGraphicsRectItem, 17 QApplication, QGraphicsSceneMouseEvent, QGraphicsLinearLayout, 18 QGraphicsItem, QGraphicsSimpleTextItem, QGraphicsLayout, 19 QGraphicsLayoutItem 20) 21 22import pyqtgraph as pg 23 24from Orange.clustering import hierarchical 25from Orange.clustering.hierarchical import Tree 26from Orange.widgets.utils import apply_all 27from Orange.widgets.utils.colorpalettes import DefaultContinuousPalette 28from Orange.widgets.utils.graphicslayoutitem import SimpleLayoutItem, scaled 29from Orange.widgets.utils.graphicsflowlayout import GraphicsFlowLayout 30from Orange.widgets.utils.graphicspixmapwidget import GraphicsPixmapWidget 31from Orange.widgets.utils.image import qimage_from_array 32 33from Orange.widgets.utils.graphicstextlist import TextListWidget 34from Orange.widgets.utils.dendrogram import DendrogramWidget 35 36 37def leaf_indices(tree: Tree) -> Sequence[int]: 38 return [leaf.value.index for leaf in hierarchical.leaves(tree)] 39 40 41class ColorMap: 42 """Base color map class.""" 43 44 def apply(self, data: np.ndarray) -> np.ndarray: 45 raise NotImplementedError 46 47 def replace(self, **kwargs) -> 'ColorMap': 48 raise NotImplementedError 49 50 51class CategoricalColorMap(ColorMap): 52 """A categorical color map.""" 53 #: A color table. A (N, 3) uint8 ndarray 54 colortable: np.ndarray 55 #: A N sequence of categorical names 56 names: Sequence[str] 57 58 def __init__(self, colortable, names): 59 self.colortable = np.asarray(colortable) 60 self.names = names 61 assert len(colortable) == len(names) 62 63 def apply(self, data) -> np.ndarray: 64 data = np.asarray(data, dtype=int) 65 table = self.colortable[data] 66 return table 67 68 def replace(self, **kwargs) -> 'CategoricalColorMap': 69 kwargs.setdefault("colortable", self.colortable) 70 kwargs.setdefault("names", self.names) 71 return CategoricalColorMap(**kwargs) 72 73 74class GradientColorMap(ColorMap): 75 """Color map for the heatmap.""" 76 #: A color table. A (N, 3) uint8 ndarray 77 colortable: np.ndarray 78 #: The data range (min, max) 79 span: Optional[Tuple[float, float]] = None 80 #: Lower and upper thresholding operator parameters. Expressed as relative 81 #: to the data span (range) so (0, 1) applies no thresholding, while 82 #: (0.05, 0.95) squeezes the effective range by 5% from both ends 83 thresholds: Tuple[float, float] = (0., 1.) 84 #: Should the color map be center and if so around which value. 85 center: Optional[float] = None 86 87 def __init__(self, colortable, thresholds=thresholds, center=None, span=None): 88 self.colortable = np.asarray(colortable) 89 self.thresholds = thresholds 90 assert thresholds[0] <= thresholds[1] 91 self.center = center 92 self.span = span 93 94 def adjust_levels(self, low: float, high: float) -> Tuple[float, float]: 95 """ 96 Adjust the data low, high levels by applying the thresholding and 97 centering. 98 """ 99 if np.any(np.isnan([low, high])): 100 return np.nan, np.nan 101 elif low > high: 102 raise ValueError(f"low > high ({low} > {high})") 103 threshold_low, threshold_high = self.thresholds 104 lt = low + (high - low) * threshold_low 105 ht = low + (high - low) * threshold_high 106 if self.center is not None: 107 center = self.center 108 maxoff = max(abs(center - lt), abs(center - ht)) 109 lt = center - maxoff 110 ht = center + maxoff 111 return lt, ht 112 113 def apply(self, data) -> np.ndarray: 114 if self.span is None: 115 low, high = np.nanmin(data), np.nanmax(data) 116 else: 117 low, high = self.span 118 low, high = self.adjust_levels(low, high) 119 mask = np.isnan(data) 120 normalized = data - low 121 finfo = np.finfo(normalized.dtype) 122 if high - low <= 1 / finfo.max: 123 n_fact = finfo.max 124 else: 125 n_fact = 1. / (high - low) 126 # over/underflow to inf are expected and cliped with the rest in the 127 # next step 128 with np.errstate(over="ignore", under="ignore"): 129 normalized *= n_fact 130 normalized = np.clip(normalized, 0, 1, out=normalized) 131 table = np.empty_like(normalized, dtype=np.uint8) 132 ncolors = len(self.colortable) 133 assert ncolors - 1 <= np.iinfo(table.dtype).max 134 table = np.multiply( 135 normalized, ncolors - 1, out=table, where=~mask, casting="unsafe", 136 ) 137 colors = self.colortable[table] 138 colors[mask] = 0 139 return colors 140 141 def replace(self, **kwargs) -> 'GradientColorMap': 142 kwargs.setdefault("colortable", self.colortable) 143 kwargs.setdefault("thresholds", self.thresholds) 144 kwargs.setdefault("center", self.center) 145 kwargs.setdefault("span", self.span) 146 return GradientColorMap(**kwargs) 147 148 149def normalized_indices(item: Union['RowItem', 'ColumnItem']) -> np.ndarray: 150 if item.cluster is None: 151 return np.asarray(item.indices, dtype=int) 152 else: 153 reorder = np.array(leaf_indices(item.cluster), dtype=int) 154 indices = np.asarray(item.indices, dtype=int) 155 return indices[reorder] 156 157 158class GridLayout(QGraphicsGridLayout): 159 def setGeometry(self, rect: QRectF) -> None: 160 super().setGeometry(rect) 161 parent = self.parentLayoutItem() 162 if isinstance(parent, HeatmapGridWidget): 163 parent.layoutDidActivate.emit() 164 165 166def grid_layout_row_geometry(layout: QGraphicsGridLayout, row: int) -> QRectF: 167 """ 168 Return the geometry of the `row` in the grid layout. 169 170 If the row is empty return an empty geometry 171 """ 172 if not 0 <= row < layout.rowCount(): 173 return QRectF() 174 175 columns = layout.columnCount() 176 geometries: List[QRectF] = [] 177 for item in (layout.itemAt(row, column) for column in range(columns)): 178 if item is not None: 179 itemgeom = item.geometry() 180 if itemgeom.isValid(): 181 geometries.append(itemgeom) 182 if geometries: 183 rect = layout.geometry() 184 rect.setTop(min(g.top() for g in geometries)) 185 rect.setBottom(max(g.bottom() for g in geometries)) 186 return rect 187 else: 188 return QRectF() 189 190 191# Positions 192class Position(enum.IntFlag): 193 NoPosition = 0 194 Left, Top, Right, Bottom = 1, 2, 4, 8 195 196 197Left, Right = Position.Left, Position.Right 198Top, Bottom = Position.Top, Position.Bottom 199 200 201FLT_MAX = np.finfo(np.float32).max 202 203 204class HeatmapGridWidget(QGraphicsWidget): 205 """ 206 A graphics widget with a annotated 2D grid of heatmaps. 207 """ 208 class RowItem(NamedTuple): 209 """ 210 A row group item 211 212 Attributes 213 ---------- 214 title: str 215 Group title 216 indices : (N, ) Sequence[int] 217 Indices in the input data to retrieve the row subset for the group. 218 cluster : Optional[Tree] 219 220 """ 221 title: str 222 indices: Sequence[int] 223 cluster: Optional[Tree] = None 224 225 @property 226 def size(self): 227 return len(self.indices) 228 229 @property 230 def normalized_indices(self): 231 return normalized_indices(self) 232 233 class ColumnItem(NamedTuple): 234 """ 235 A column group 236 237 Attributes 238 ---------- 239 title: str 240 Column group title 241 indices: (N, ) Sequence[int] 242 Indexes the input data to retrieve the column subset for the group. 243 cluster: Optional[Tree] 244 """ 245 title: str 246 indices: Sequence[int] 247 cluster: Optional[Tree] = None 248 249 @property 250 def size(self): 251 return len(self.indices) 252 253 @property 254 def normalized_indices(self): 255 return normalized_indices(self) 256 257 class Parts(NamedTuple): 258 #: define the splits of data over rows, and define dendrogram and/or row 259 #: reordering 260 rows: Sequence['RowItem'] 261 #: define the splits of data over columns, and define dendrogram and/or 262 #: column reordering 263 columns: Sequence['ColumnItem'] 264 #: span (min, max) of the values in `data` 265 span: Tuple[float, float] 266 #: the complete data array (shape (N, M)) 267 data: np.ndarray 268 #: Row names (len N) 269 row_names: Optional[Sequence[str]] = None 270 #: Column names (len M) 271 col_names: Optional[Sequence[str]] = None 272 273 # Positions 274 class Position(enum.IntFlag): 275 NoPosition = 0 276 Left, Top, Right, Bottom = 1, 2, 4, 8 277 278 Left, Right = Position.Left, Position.Right 279 Top, Bottom = Position.Top, Position.Bottom 280 281 #: The widget's layout has activated (i.e. did a relayout 282 #: of the widget's contents) 283 layoutDidActivate = Signal() 284 285 #: Signal emitted when the user finished a selection operation 286 selectionFinished = Signal() 287 #: Signal emitted on any change in selection 288 selectionChanged = Signal() 289 290 NoPosition, PositionTop, PositionBottom = 0, Top, Bottom 291 292 # Start row/column where the heatmap items are inserted 293 # (after the titles/legends/dendrograms) 294 Row0 = 5 295 Col0 = 3 296 # The (color) legend row and column 297 LegendRow, LegendCol = 0, 4 298 # The column for the vertical dendrogram 299 DendrogramColumn = 1 300 # Horizontal split title column 301 GroupTitleRow = 1 302 # The row for the horizontal dendrograms 303 DendrogramRow = 2 304 # The row for top column annotation labels 305 TopLabelsRow = 3 306 # Top color annotation row 307 TopAnnotationRow = 4 308 # Vertical split title column 309 GroupTitleColumn = 0 310 311 def __init__(self, parent=None, **kwargs): 312 super().__init__(parent, **kwargs) 313 self.__spacing = 3 314 self.__colormap = GradientColorMap( 315 DefaultContinuousPalette.lookup_table() 316 ) 317 self.parts = None # type: Optional[Parts] 318 self.__averagesVisible = False 319 self.__legendVisible = True 320 self.__aspectRatioMode = Qt.IgnoreAspectRatio 321 self.__columnLabelPosition = Top 322 self.heatmap_widget_grid = [] # type: List[List[GraphicsHeatmapWidget]] 323 self.row_annotation_widgets = [] # type: List[TextListWidget] 324 self.col_annotation_widgets = [] # type: List[TextListWidget] 325 self.col_annotation_widgets_top = [] # type: List[TextListWidget] 326 self.col_annotation_widgets_bottom = [] # type: List[TextListWidget] 327 self.col_dendrograms = [] # type: List[Optional[DendrogramWidget]] 328 self.row_dendrograms = [] # type: List[Optional[DendrogramWidget]] 329 self.right_side_colors = [] # type: List[Optional[GraphicsPixmapWidget]] 330 self.top_side_colors = [] # type: List[Optional[GraphicsPixmapWidget]] 331 self.heatmap_colormap_legend = None 332 self.bottom_legend_container = None 333 self.__layout = GridLayout() 334 self.__layout.setSpacing(self.__spacing) 335 self.setLayout(self.__layout) 336 self.__selection_manager = SelectionManager(self) 337 self.__selection_manager.selection_changed.connect( 338 self.__update_selection_geometry 339 ) 340 self.__selection_manager.selection_finished.connect( 341 self.selectionFinished 342 ) 343 self.__selection_manager.selection_changed.connect( 344 self.selectionChanged 345 ) 346 self.selection_rects = [] 347 348 def clear(self): 349 """Clear the widget.""" 350 for i in reversed(range(self.__layout.count())): 351 item = self.__layout.itemAt(i) 352 self.__layout.removeAt(i) 353 if item is not None and item.graphicsItem() is not None: 354 remove_item(item.graphicsItem()) 355 356 self.heatmap_widget_grid = [] 357 self.row_annotation_widgets = [] 358 self.col_annotation_widgets = [] 359 self.col_dendrograms = [] 360 self.row_dendrograms = [] 361 self.right_side_colors = [] 362 self.top_side_colors = [] 363 self.heatmap_colormap_legend = None 364 self.bottom_legend_container = None 365 self.parts = None 366 self.updateGeometry() 367 368 def setHeatmaps(self, parts: 'Parts') -> None: 369 """Set the heatmap parts for display""" 370 self.clear() 371 grid = self.__layout 372 N, M = len(parts.rows), len(parts.columns) 373 374 # Start row/column where the heatmap items are inserted 375 # (after the titles/legends/dendrograms) 376 Row0 = self.Row0 377 Col0 = self.Col0 378 # The column for the vertical dendrograms 379 DendrogramColumn = self.DendrogramColumn 380 # The row for the horizontal dendrograms 381 DendrogramRow = self.DendrogramRow 382 RightLabelColumn = Col0 + 2 * M + 1 383 TopAnnotationRow = self.TopAnnotationRow 384 TopLabelsRow = self.TopLabelsRow 385 BottomLabelsRow = Row0 + N 386 colormap = self.__colormap 387 column_dendrograms: List[Optional[DendrogramWidget]] = [None] * M 388 row_dendrograms: List[Optional[DendrogramWidget]] = [None] * N 389 right_side_colors: List[Optional[GraphicsPixmapWidget]] = [None] * N 390 top_side_colors: List[Optional[GraphicsPixmapWidget]] = [None] * M 391 392 data = parts.data 393 if parts.col_names is None: 394 col_names = np.full(data.shape[1], "", dtype=object) 395 else: 396 col_names = np.asarray(parts.col_names, dtype=object) 397 if parts.row_names is None: 398 row_names = np.full(data.shape[0], "", dtype=object) 399 else: 400 row_names = np.asarray(parts.row_names, dtype=object) 401 402 assert len(col_names) == data.shape[1] 403 assert len(row_names) == data.shape[0] 404 405 for i, rowitem in enumerate(parts.rows): 406 if rowitem.title: 407 item = QGraphicsSimpleTextItem(rowitem.title, parent=self) 408 item.setTransform(item.transform().rotate(-90)) 409 item = SimpleLayoutItem(item, parent=grid, anchor=(0, 1), 410 anchorItem=(0, 0)) 411 item.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Maximum) 412 grid.addItem(item, Row0 + i, self.GroupTitleColumn, 413 alignment=Qt.AlignCenter) 414 if rowitem.cluster: 415 dendrogram = DendrogramWidget( 416 parent=self, 417 selectionMode=DendrogramWidget.NoSelection, 418 hoverHighlightEnabled=True, 419 ) 420 dendrogram.set_root(rowitem.cluster) 421 dendrogram.setMaximumWidth(100) 422 dendrogram.setMinimumWidth(100) 423 # Ignore dendrogram vertical size hint (heatmap's size 424 # should define the row's vertical size). 425 dendrogram.setSizePolicy( 426 QSizePolicy.Expanding, QSizePolicy.Ignored) 427 dendrogram.itemClicked.connect( 428 lambda item, partindex=i: 429 self.__select_by_cluster(item, partindex) 430 ) 431 grid.addItem(dendrogram, Row0 + i, DendrogramColumn) 432 row_dendrograms[i] = dendrogram 433 434 for j, colitem in enumerate(parts.columns): 435 if colitem.title: 436 item = SimpleLayoutItem( 437 QGraphicsSimpleTextItem(colitem.title, parent=self), 438 parent=grid, anchor=(0.5, 0.5), anchorItem=(0.5, 0.5) 439 ) 440 item.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Fixed) 441 grid.addItem(item, self.GroupTitleRow, Col0 + 2 * j + 1) 442 443 if colitem.cluster: 444 dendrogram = DendrogramWidget( 445 parent=self, 446 orientation=DendrogramWidget.Top, 447 selectionMode=DendrogramWidget.NoSelection, 448 hoverHighlightEnabled=False 449 ) 450 dendrogram.set_root(colitem.cluster) 451 dendrogram.setMaximumHeight(100) 452 dendrogram.setMinimumHeight(100) 453 # Ignore dendrogram horizontal size hint (heatmap's width 454 # should define the column width). 455 dendrogram.setSizePolicy( 456 QSizePolicy.Ignored, QSizePolicy.Expanding) 457 grid.addItem(dendrogram, DendrogramRow, Col0 + 2 * j + 1) 458 column_dendrograms[j] = dendrogram 459 460 heatmap_widgets = [] 461 for i in range(N): 462 heatmap_row = [] 463 for j in range(M): 464 row_ix = parts.rows[i].normalized_indices 465 col_ix = parts.columns[j].normalized_indices 466 X_part = data[np.ix_(row_ix, col_ix)] 467 hw = GraphicsHeatmapWidget( 468 aspectRatioMode=self.__aspectRatioMode, 469 data=X_part, span=parts.span, colormap=colormap, 470 ) 471 sp = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred) 472 sp.setHeightForWidth(True) 473 hw.setSizePolicy(sp) 474 475 avgimg = GraphicsHeatmapWidget( 476 data=np.nanmean(X_part, axis=1, keepdims=True), 477 span=parts.span, colormap=colormap, 478 visible=self.__averagesVisible, 479 minimumSize=QSizeF(5, -1) 480 ) 481 avgimg.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Ignored) 482 grid.addItem(avgimg, Row0 + i, Col0 + 2 * j) 483 grid.addItem(hw, Row0 + i, Col0 + 2 * j + 1) 484 485 heatmap_row.append(hw) 486 heatmap_widgets.append(heatmap_row) 487 488 for j in range(M): 489 grid.setColumnStretchFactor(Col0 + 2 * j, 1) 490 grid.setColumnStretchFactor( 491 Col0 + 2 * j + 1, parts.columns[j].size) 492 grid.setColumnStretchFactor(RightLabelColumn - 1, 1) 493 494 for i in range(N): 495 grid.setRowStretchFactor(Row0 + i, parts.rows[i].size) 496 497 row_annotation_widgets = [] 498 col_annotation_widgets = [] 499 col_annotation_widgets_top = [] 500 col_annotation_widgets_bottom = [] 501 502 for i, rowitem in enumerate(parts.rows): 503 # Right row annotations 504 indices = np.asarray(rowitem.normalized_indices, dtype=np.intp) 505 labels = row_names[indices] 506 labelslist = TextListWidget( 507 items=labels, parent=self, orientation=Qt.Vertical, 508 alignment=Qt.AlignLeft | Qt.AlignVCenter, 509 sizePolicy=QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Ignored), 510 autoScale=True, 511 objectName="row-labels-right" 512 ) 513 labelslist.setMaximumWidth(300) 514 rowauxsidecolor = GraphicsPixmapWidget( 515 parent=self, visible=False, 516 scaleContents=True, aspectMode=Qt.IgnoreAspectRatio, 517 sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Ignored), 518 minimumSize=QSizeF(10, -1) 519 ) 520 grid.addItem(rowauxsidecolor, Row0 + i, RightLabelColumn - 1) 521 grid.addItem(labelslist, Row0 + i, RightLabelColumn, Qt.AlignLeft) 522 row_annotation_widgets.append(labelslist) 523 right_side_colors[i] = rowauxsidecolor 524 525 for j, colitem in enumerate(parts.columns): 526 # Top attr annotations 527 indices = np.asarray(colitem.normalized_indices, dtype=np.intp) 528 labels = col_names[indices] 529 sp = QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Fixed) 530 sp.setHeightForWidth(True) 531 labelslist = TextListWidget( 532 items=labels, parent=self, 533 alignment=Qt.AlignLeft | Qt.AlignVCenter, 534 orientation=Qt.Horizontal, 535 autoScale=True, 536 sizePolicy=sp, 537 visible=self.__columnLabelPosition & Position.Top, 538 objectName="column-labels-top", 539 ) 540 colauxsidecolor = GraphicsPixmapWidget( 541 parent=self, visible=False, 542 scaleContents=True, aspectMode=Qt.IgnoreAspectRatio, 543 sizePolicy=QSizePolicy(QSizePolicy.Ignored, 544 QSizePolicy.Maximum), 545 minimumSize=QSizeF(-1, 10) 546 ) 547 548 grid.addItem(labelslist, TopLabelsRow, Col0 + 2 * j + 1, 549 Qt.AlignBottom | Qt.AlignLeft) 550 grid.addItem(colauxsidecolor, TopAnnotationRow, Col0 + 2 * j + 1) 551 col_annotation_widgets.append(labelslist) 552 col_annotation_widgets_top.append(labelslist) 553 top_side_colors[j] = colauxsidecolor 554 555 # Bottom attr annotations 556 labelslist = TextListWidget( 557 items=labels, parent=self, 558 alignment=Qt.AlignRight | Qt.AlignVCenter, 559 orientation=Qt.Horizontal, 560 autoScale=True, 561 sizePolicy=sp, 562 visible=self.__columnLabelPosition & Position.Bottom, 563 objectName="column-labels-bottom", 564 ) 565 grid.addItem(labelslist, BottomLabelsRow, Col0 + 2 * j + 1) 566 col_annotation_widgets.append(labelslist) 567 col_annotation_widgets_bottom.append(labelslist) 568 569 row_color_annotation_header = QGraphicsSimpleTextItem("", self) 570 row_color_annotation_header.setTransform( 571 row_color_annotation_header.transform().rotate(-90)) 572 573 grid.addItem(SimpleLayoutItem( 574 row_color_annotation_header, anchor=(0, 1), 575 aspectMode=Qt.KeepAspectRatio, 576 sizePolicy=QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Preferred), 577 ), 578 0, RightLabelColumn - 1, self.TopLabelsRow + 1, 1, 579 alignment=Qt.AlignBottom 580 ) 581 582 col_color_annotation_header = QGraphicsSimpleTextItem("", self) 583 grid.addItem(SimpleLayoutItem( 584 col_color_annotation_header, anchor=(1, 1), anchorItem=(1, 1), 585 aspectMode=Qt.KeepAspectRatio, 586 sizePolicy=QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed), 587 ), 588 TopAnnotationRow, 0, 1, Col0, alignment=Qt.AlignRight 589 ) 590 591 legend = GradientLegendWidget( 592 parts.span[0], parts.span[1], 593 colormap, 594 parent=self, 595 minimumSize=QSizeF(100, 20), 596 visible=self.__legendVisible, 597 sizePolicy=QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Fixed) 598 ) 599 legend.setMaximumWidth(300) 600 grid.addItem(legend, self.LegendRow, self.LegendCol, 1, M * 2 - 1) 601 602 def container(parent=None, orientation=Qt.Horizontal, margin=0, spacing=0, **kwargs): 603 widget = QGraphicsWidget(**kwargs) 604 layout = QGraphicsLinearLayout(orientation) 605 layout.setContentsMargins(margin, margin, margin, margin) 606 layout.setSpacing(spacing) 607 widget.setLayout(layout) 608 if parent is not None: 609 widget.setParentItem(parent) 610 611 return widget 612 # Container for color annotation legends 613 legend_container = container( 614 spacing=3, 615 sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed), 616 visible=False, objectName="annotation-legend-container" 617 ) 618 legend_container_rows = container( 619 parent=legend_container, 620 sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed), 621 visible=False, objectName="row-annotation-legend-container" 622 ) 623 legend_container_cols = container( 624 parent=legend_container, 625 sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed), 626 visible=False, objectName="col-annotation-legend-container" 627 ) 628 # ? keep refs to child containers; segfault in scene.clear() ? 629 legend_container._refs = (legend_container_rows, legend_container_cols) 630 legend_container.layout().addItem(legend_container_rows) 631 legend_container.layout().addItem(legend_container_cols) 632 633 grid.addItem(legend_container, BottomLabelsRow + 1, Col0 + 1, 1, M * 2 - 1, 634 alignment=Qt.AlignRight) 635 636 self.heatmap_widget_grid = heatmap_widgets 637 self.row_annotation_widgets = row_annotation_widgets 638 self.col_annotation_widgets = col_annotation_widgets 639 self.col_annotation_widgets_top = col_annotation_widgets_top 640 self.col_annotation_widgets_bottom = col_annotation_widgets_bottom 641 self.col_dendrograms = column_dendrograms 642 self.row_dendrograms = row_dendrograms 643 self.right_side_colors = right_side_colors 644 self.top_side_colors = top_side_colors 645 self.heatmap_colormap_legend = legend 646 self.bottom_legend_container = legend_container 647 self.parts = parts 648 self.__selection_manager.set_heatmap_widgets(heatmap_widgets) 649 650 def legendVisible(self) -> bool: 651 """Is the colormap legend visible.""" 652 return self.__legendVisible 653 654 def setLegendVisible(self, visible: bool) -> None: 655 """Set colormap legend visible state.""" 656 self.__legendVisible = visible 657 legends = [ 658 self.heatmap_colormap_legend, 659 self.bottom_legend_container 660 ] 661 apply_all(filter(None, legends), lambda item: item.setVisible(visible)) 662 663 legendVisible_ = Property(bool, legendVisible, setLegendVisible) 664 665 def setAspectRatioMode(self, mode: Qt.AspectRatioMode) -> None: 666 """ 667 Set the scale aspect mode. 668 669 The widget will try to keep (hint) the scale ratio via the sizeHint 670 reimplementation. 671 """ 672 if self.__aspectRatioMode != mode: 673 self.__aspectRatioMode = mode 674 for hm in chain.from_iterable(self.heatmap_widget_grid): 675 hm.setAspectMode(mode) 676 sp = self.sizePolicy() 677 sp.setHeightForWidth(mode != Qt.IgnoreAspectRatio) 678 self.setSizePolicy(sp) 679 680 def aspectRatioMode(self) -> Qt.AspectRatioMode: 681 return self.__aspectRatioMode 682 683 aspectRatioMode_ = Property( 684 Qt.AspectRatioMode, aspectRatioMode, setAspectRatioMode 685 ) 686 687 def setColumnLabelsPosition(self, position: Position) -> None: 688 self.__columnLabelPosition = position 689 top = bool(position & HeatmapGridWidget.PositionTop) 690 bottom = bool(position & HeatmapGridWidget.PositionBottom) 691 for w in self.col_annotation_widgets_top: 692 w.setVisible(top) 693 w.setMaximumHeight(FLT_MAX if top else 0) 694 for w in self.col_annotation_widgets_bottom: 695 w.setVisible(bottom) 696 w.setMaximumHeight(FLT_MAX if bottom else 0) 697 698 def columnLabelPosition(self) -> Position: 699 return self.__columnLabelPosition 700 701 def setColumnLabels(self, data: Optional[Sequence[str]]) -> None: 702 """Set the column labels to display. If None clear the row names.""" 703 if data is not None: 704 data = np.asarray(data, dtype=object) 705 for top, bottom, part in zip(self.col_annotation_widgets_top, 706 self.col_annotation_widgets_bottom, 707 self.parts.columns): 708 if data is not None: 709 top.setItems(data[part.normalized_indices]) 710 bottom.setItems(data[part.normalized_indices]) 711 else: 712 top.clear() 713 bottom.clear() 714 715 def setRowLabels(self, data: Optional[Sequence[str]]): 716 """ 717 Set the row labels to display. If None clear the row names. 718 """ 719 if data is not None: 720 data = np.asarray(data, dtype=object) 721 for widget, part in zip(self.row_annotation_widgets, self.parts.rows): 722 if data is not None: 723 widget.setItems(data[part.normalized_indices]) 724 else: 725 widget.clear() 726 727 def setRowLabelsVisible(self, visible: bool): 728 """Set row labels visibility""" 729 for widget in self.row_annotation_widgets: 730 widget.setVisible(visible) 731 732 def setRowSideColorAnnotations( 733 self, data: np.ndarray, colormap: ColorMap = None, name="" 734 ): 735 """ 736 Set an optional row side color annotations. 737 738 Parameters 739 ---------- 740 data: Optional[np.ndarray] 741 A sequence such that it is accepted by `colormap.apply`. If None 742 then the color annotations are cleared. 743 colormap: ColorMap 744 name: str 745 Name/title for the annotation column. 746 """ 747 col = self.Col0 + 2 * len(self.parts.columns) 748 legend_layout = self.bottom_legend_container.layout() 749 legend_container = legend_layout.itemAt(1) 750 self.__setColorAnnotationsHelper( 751 data, colormap, name, self.right_side_colors, col, Qt.Vertical, 752 legend_container 753 ) 754 legend_container.setVisible(True) 755 756 def setColumnSideColorAnnotations( 757 self, data: np.ndarray, colormap: ColorMap = None, name="" 758 ): 759 """ 760 Set an optional column color annotations. 761 762 Parameters 763 ---------- 764 data: Optional[np.ndarray] 765 A sequence such that it is accepted by `colormap.apply`. If None 766 then the color annotations are cleared. 767 colormap: ColorMap 768 name: str 769 Name/title for the annotation column. 770 """ 771 row = self.TopAnnotationRow 772 legend_layout = self.bottom_legend_container.layout() 773 legend_container = legend_layout.itemAt(0) 774 self.__setColorAnnotationsHelper( 775 data, colormap, name, self.top_side_colors, row, Qt.Horizontal, 776 legend_container) 777 legend_container.setVisible(True) 778 779 def __setColorAnnotationsHelper( 780 self, data: np.ndarray, colormap: ColorMap, name: str, 781 items: List[GraphicsPixmapWidget], position: int, 782 orientation: Qt.Orientation, legend_container: QGraphicsWidget): 783 layout = self.__layout 784 if orientation == Qt.Horizontal: 785 nameitem = layout.itemAt(position, 0) 786 else: 787 nameitem = layout.itemAt(self.TopLabelsRow, position) 788 size = QFontMetrics(self.font()).lineSpacing() 789 layout_clear(legend_container.layout()) 790 791 def grid_set_maximum_size(position: int, size: float): 792 if orientation == Qt.Horizontal: 793 layout.setRowMaximumHeight(position, size) 794 else: 795 layout.setColumnMaximumWidth(position, size) 796 797 def set_minimum_size(item: QGraphicsLayoutItem, size: float): 798 if orientation == Qt.Horizontal: 799 item.setMinimumHeight(size) 800 else: 801 item.setMinimumWidth(size) 802 item.updateGeometry() 803 804 def reset_minimum_size(item: QGraphicsLayoutItem): 805 set_minimum_size(item, -1) 806 807 def set_hidden(item: GraphicsPixmapWidget): 808 item.setVisible(False) 809 reset_minimum_size(item,) 810 811 def set_visible(item: GraphicsPixmapWidget): 812 item.setVisible(True) 813 set_minimum_size(item, 10) 814 815 def set_preferred_size(item, size): 816 if orientation == Qt.Horizontal: 817 item.setPreferredHeight(size) 818 else: 819 item.setPreferredWidth(size) 820 item.updateGeometry() 821 822 if data is None: 823 apply_all(filter(None, items), set_hidden) 824 grid_set_maximum_size(position, 0) 825 826 nameitem.item.setVisible(False) 827 nameitem.updateGeometry() 828 legend_container.setVisible(False) 829 return 830 else: 831 apply_all(filter(None, items), set_visible) 832 grid_set_maximum_size(position, FLT_MAX) 833 legend_container.setVisible(True) 834 835 if orientation == Qt.Horizontal: 836 parts = self.parts.columns 837 else: 838 parts = self.parts.rows 839 for p, item in zip(parts, items): 840 if item is not None: 841 subset = data[p.normalized_indices] 842 subset = colormap.apply(subset) 843 rgbdata = subset.reshape((-1, 1, subset.shape[-1])) 844 if orientation == Qt.Horizontal: 845 rgbdata = rgbdata.reshape((1, -1, rgbdata.shape[-1])) 846 img = qimage_from_array(rgbdata) 847 item.setPixmap(img) 848 item.setVisible(True) 849 set_preferred_size(item, size) 850 851 nameitem.item.setText(name) 852 nameitem.item.setVisible(True) 853 set_preferred_size(nameitem, size) 854 855 container = legend_container.layout() 856 if isinstance(colormap, CategoricalColorMap): 857 legend = CategoricalColorLegend( 858 colormap, title=name, 859 orientation=Qt.Horizontal, 860 sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Maximum), 861 visible=self.__legendVisible, 862 ) 863 container.addItem(legend) 864 elif isinstance(colormap, GradientColorMap): 865 legend = GradientLegendWidget( 866 *colormap.span, colormap, title=name, 867 sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Maximum), 868 ) 869 legend.setMinimumWidth(100) 870 container.addItem(legend) 871 872 def headerGeometry(self) -> QRectF: 873 """Return the 'header' geometry. 874 875 This is the top part of the widget spanning the top dendrogram, 876 column labels... (can be empty). 877 """ 878 layout = self.__layout 879 geom1 = grid_layout_row_geometry(layout, self.DendrogramRow) 880 geom2 = grid_layout_row_geometry(layout, self.TopLabelsRow) 881 first = grid_layout_row_geometry(layout, self.TopLabelsRow + 1) 882 geom = geom1.united(geom2) 883 if geom.isValid(): 884 if first.isValid(): 885 geom.setBottom(geom.bottom() / 2.0 + first.top() / 2.0) 886 return QRectF(self.geometry().topLeft(), geom.bottomRight()) 887 else: 888 return QRectF() 889 890 def footerGeometry(self) -> QRectF: 891 """Return the 'footer' geometry. 892 893 This is the bottom part of the widget spanning the bottom column labels 894 when applicable (can be empty). 895 """ 896 layout = self.__layout 897 row = self.Row0 + len(self.heatmap_widget_grid) 898 geom = grid_layout_row_geometry(layout, row) 899 nextolast = grid_layout_row_geometry(layout, row - 1) 900 if geom.isValid(): 901 if nextolast.isValid(): 902 geom.setTop(geom.top() / 2 + nextolast.bottom() / 2) 903 return QRectF(geom.topLeft(), self.geometry().bottomRight()) 904 else: 905 return QRectF() 906 907 def setColorMap(self, colormap: GradientColorMap) -> None: 908 self.__colormap = colormap 909 for hm in chain.from_iterable(self.heatmap_widget_grid): 910 hm.setColorMap(colormap) 911 for item in self.__avgitems(): 912 item.setColorMap(colormap) 913 for ch in self.childItems(): 914 if isinstance(ch, GradientLegendWidget): 915 ch.setColorMap(colormap) 916 917 def colorMap(self) -> ColorMap: 918 return self.__colormap 919 920 def __avgitems(self): 921 if self.parts is None: 922 return 923 N = len(self.parts.rows) 924 M = len(self.parts.columns) 925 layout = self.__layout 926 for i in range(N): 927 for j in range(M): 928 item = layout.itemAt(self.Row0 + i, self.Col0 + 2 * j) 929 if isinstance(item, GraphicsHeatmapWidget): 930 yield item 931 932 def setShowAverages(self, visible): 933 self.__averagesVisible = visible 934 for item in self.__avgitems(): 935 item.setVisible(visible) 936 item.setPreferredWidth(0 if not visible else 10) 937 938 def event(self, event): 939 # type: (QEvent) -> bool 940 rval = super().event(event) 941 if event.type() == QEvent.LayoutRequest and self.layout() is not None: 942 self.__update_selection_geometry() 943 return rval 944 945 def setGeometry(self, rect: QRectF) -> None: 946 super().setGeometry(rect) 947 self.__update_selection_geometry() 948 949 def __update_selection_geometry(self): 950 scene = self.scene() 951 self.__selection_manager.update_selection_rects() 952 rects = self.__selection_manager.selection_rects 953 palette = self.palette() 954 pen = QPen(palette.color(QPalette.Foreground), 2) 955 pen.setCosmetic(True) 956 brushcolor = QColor(palette.color(QPalette.Highlight)) 957 brushcolor.setAlpha(50) 958 selection_rects = [] 959 for rect, item in zip_longest(rects, self.selection_rects): 960 assert rect is not None or item is not None 961 if item is None: 962 item = QGraphicsRectItem(rect, None) 963 item.setPen(pen) 964 item.setBrush(brushcolor) 965 scene.addItem(item) 966 selection_rects.append(item) 967 elif rect is not None: 968 item.setRect(rect) 969 item.setPen(pen) 970 item.setBrush(brushcolor) 971 selection_rects.append(item) 972 else: 973 scene.removeItem(item) 974 self.selection_rects = selection_rects 975 976 def __select_by_cluster(self, item, dendrogramindex): 977 # User clicked on a dendrogram node. 978 # Select all rows corresponding to the cluster item. 979 node = item.node 980 try: 981 hm = self.heatmap_widget_grid[dendrogramindex][0] 982 except IndexError: 983 pass 984 else: 985 key = QApplication.keyboardModifiers() 986 clear = not (key & ((Qt.ControlModifier | Qt.ShiftModifier | 987 Qt.AltModifier))) 988 remove = (key & (Qt.ControlModifier | Qt.AltModifier)) 989 append = (key & Qt.ControlModifier) 990 self.__selection_manager.selection_add( 991 node.value.first, node.value.last - 1, hm, 992 clear=clear, remove=remove, append=append) 993 994 def heatmapAtPos(self, pos: QPointF) -> Optional['GraphicsHeatmapWidget']: 995 for hw in chain.from_iterable(self.heatmap_widget_grid): 996 if hw.contains(hw.mapFromItem(self, pos)): 997 return hw 998 return None 999 1000 __selecting = False 1001 1002 def mousePressEvent(self, event: QGraphicsSceneMouseEvent) -> None: 1003 pos = event.pos() 1004 heatmap = self.heatmapAtPos(pos) 1005 if heatmap and event.button() & Qt.LeftButton: 1006 row, _ = heatmap.heatmapCellAt(heatmap.mapFromScene(event.scenePos())) 1007 if row != -1: 1008 self.__selection_manager.selection_start(heatmap, event) 1009 self.__selecting = True 1010 event.setAccepted(True) 1011 return 1012 super().mousePressEvent(event) 1013 1014 def mouseMoveEvent(self, event: QGraphicsSceneMouseEvent) -> None: 1015 pos = event.pos() 1016 heatmap = self.heatmapAtPos(pos) 1017 if heatmap and event.buttons() & Qt.LeftButton and self.__selecting: 1018 row, _ = heatmap.heatmapCellAt(heatmap.mapFromScene(pos)) 1019 if row != -1: 1020 self.__selection_manager.selection_update(heatmap, event) 1021 event.setAccepted(True) 1022 return 1023 super().mouseMoveEvent(event) 1024 1025 def mouseReleaseEvent(self, event: QGraphicsSceneMouseEvent) -> None: 1026 pos = event.pos() 1027 if event.button() == Qt.LeftButton and self.__selecting: 1028 self.__selection_manager.selection_finish( 1029 self.heatmapAtPos(pos), event) 1030 self.__selecting = False 1031 super().mouseReleaseEvent(event) 1032 1033 def selectedRows(self) -> Sequence[int]: 1034 """Return the current selected rows.""" 1035 if self.parts is None: 1036 return [] 1037 visual_indices = self.__selection_manager.selections 1038 indices = np.hstack([r.normalized_indices for r in self.parts.rows]) 1039 return indices[visual_indices].tolist() 1040 1041 def selectRows(self, selection: Sequence[int]): 1042 """Select the specified rows. Previous selection is cleared.""" 1043 if self.parts is not None: 1044 indices = np.hstack([r.normalized_indices for r in self.parts.rows]) 1045 else: 1046 indices = [] 1047 condition = np.in1d(indices, selection) 1048 visual_indices = np.flatnonzero(condition) 1049 self.__selection_manager.select_rows(visual_indices.tolist()) 1050 1051 1052class GraphicsHeatmapWidget(QGraphicsWidget): 1053 __aspectMode = Qt.KeepAspectRatio 1054 1055 def __init__( 1056 self, parent=None, 1057 data: Optional[np.ndarray] = None, 1058 span: Tuple[float, float] = (0., 1.), 1059 colormap: Optional[ColorMap] = None, 1060 aspectRatioMode=Qt.KeepAspectRatio, 1061 **kwargs 1062 ) -> None: 1063 super().__init__(None, **kwargs) 1064 self.setAcceptHoverEvents(True) 1065 self.__levels = span 1066 if colormap is None: 1067 colormap = GradientColorMap(DefaultContinuousPalette.lookup_table()) 1068 self.__colormap = colormap 1069 self.__data: Optional[np.ndarray] = None 1070 self.__pixmap = QPixmap() 1071 self.__aspectMode = aspectRatioMode 1072 1073 layout = QGraphicsLinearLayout(Qt.Horizontal) 1074 layout.setContentsMargins(0, 0, 0, 0) 1075 self.__pixmapItem = GraphicsPixmapWidget( 1076 self, scaleContents=True, aspectMode=Qt.IgnoreAspectRatio 1077 ) 1078 sp = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 1079 sp.setHeightForWidth(True) 1080 self.__pixmapItem.setSizePolicy(sp) 1081 layout.addItem(self.__pixmapItem) 1082 self.setLayout(layout) 1083 sp = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 1084 sp.setHeightForWidth(True) 1085 self.setSizePolicy(sp) 1086 self.setHeatmapData(data) 1087 1088 if parent is not None: 1089 self.setParentItem(parent) 1090 1091 def setAspectMode(self, mode: Qt.AspectRatioMode) -> None: 1092 if self.__aspectMode != mode: 1093 self.__aspectMode = mode 1094 sp = self.sizePolicy() 1095 sp.setHeightForWidth(mode != Qt.IgnoreAspectRatio) 1096 self.setSizePolicy(sp) 1097 self.updateGeometry() 1098 1099 def aspectMode(self) -> Qt.AspectRatioMode: 1100 return self.__aspectMode 1101 1102 def sizeHint(self, which: Qt.SizeHint, constraint=QSizeF(-1, -1)) -> QSizeF: 1103 if which == Qt.PreferredSize and constraint.width() >= 0: 1104 sh = super().sizeHint(which) 1105 return scaled(sh, QSizeF(constraint.width(), -1), self.__aspectMode) 1106 return super().sizeHint(which, constraint) 1107 1108 def clear(self): 1109 """Clear/reset the widget.""" 1110 self.__data = None 1111 self.__pixmap = QPixmap() 1112 self.__pixmapItem.setPixmap(self.__pixmap) 1113 self.updateGeometry() 1114 1115 def setHeatmapData(self, data): 1116 """Set the heatmap data for display.""" 1117 if self.__data is not data: 1118 self.clear() 1119 self.__data = data 1120 self.__updatePixmap() 1121 self.update() 1122 1123 def heatmapData(self) -> Optional[np.ndarray]: 1124 if self.__data is not None: 1125 v = self.__data.view() 1126 v.flags.writeable = False 1127 return v 1128 else: 1129 return None 1130 1131 def pixmap(self) -> QPixmap: 1132 return self.__pixmapItem.pixmap() 1133 1134 def setLevels(self, levels: Tuple[float, float]) -> None: 1135 if levels != self.__levels: 1136 self.__levels = levels 1137 self.__updatePixmap() 1138 self.update() 1139 1140 def setColorMap(self, colormap: ColorMap): 1141 self.__colormap = colormap 1142 self.__updatePixmap() 1143 1144 def colorMap(self,) -> ColorMap: 1145 return self.__colormap 1146 1147 def __updatePixmap(self): 1148 if self.__data is not None: 1149 ll, lh = self.__levels 1150 cmap = self.__colormap.replace(span=(ll, lh)) 1151 rgb = cmap.apply(self.__data) 1152 rgb[np.isnan(self.__data)] = (100, 100, 100) 1153 qimage = qimage_from_array(rgb) 1154 self.__pixmap = QPixmap.fromImage(qimage) 1155 else: 1156 self.__pixmap = QPixmap() 1157 1158 self.__pixmapItem.setPixmap(self.__pixmap) 1159 self.__updateSizeHints() 1160 1161 def changeEvent(self, event: QEvent) -> None: 1162 super().changeEvent(event) 1163 if event.type() == QEvent.FontChange: 1164 self.__updateSizeHints() 1165 1166 def __updateSizeHints(self): 1167 hmsize = QSizeF(self.__pixmap.size()) 1168 size = QFontMetrics(self.font()).lineSpacing() 1169 self.__pixmapItem.setMinimumSize(hmsize) 1170 self.__pixmapItem.setPreferredSize(hmsize * size) 1171 1172 def heatmapCellAt(self, pos: QPointF) -> Tuple[int, int]: 1173 """Return the cell row, column from `pos` in local coordinates. 1174 """ 1175 if self.__pixmap.isNull() or not \ 1176 self.__pixmapItem.geometry().contains(pos): 1177 return -1, -1 1178 assert self.__data is not None 1179 item_clicked = self.__pixmapItem 1180 pos = self.mapToItem(item_clicked, pos) 1181 size = self.__pixmapItem.size() 1182 1183 x, y = pos.x(), pos.y() 1184 1185 N, M = self.__data.shape 1186 fx = x / size.width() 1187 fy = y / size.height() 1188 i = min(int(math.floor(fy * N)), N - 1) 1189 j = min(int(math.floor(fx * M)), M - 1) 1190 return i, j 1191 1192 def heatmapCellRect(self, row: int, column: int) -> QRectF: 1193 """Return a rectangle in local coordinates containing the cell 1194 at `row` and `column`. 1195 """ 1196 size = self.__pixmap.size() 1197 if not (0 <= column < size.width() or 0 <= row < size.height()): 1198 return QRectF() 1199 1200 topleft = QPointF(column, row) 1201 bottomright = QPointF(column + 1, row + 1) 1202 t = self.__pixmapItem.pixmapTransform() 1203 rect = t.mapRect(QRectF(topleft, bottomright)) 1204 rect.translated(self.__pixmapItem.pos()) 1205 return rect 1206 1207 def rowRect(self, row): 1208 """ 1209 Return a QRectF in local coordinates containing the entire row. 1210 """ 1211 rect = self.heatmapCellRect(row, 0) 1212 rect.setLeft(0) 1213 rect.setRight(self.size().width()) 1214 return rect 1215 1216 def heatmapCellToolTip(self, row, column): 1217 return "{}, {}: {:g}".format(row, column, self.__data[row, column]) 1218 1219 def hoverMoveEvent(self, event): 1220 pos = event.pos() 1221 row, column = self.heatmapCellAt(pos) 1222 if row != -1: 1223 tooltip = self.heatmapCellToolTip(row, column) 1224 self.setToolTip(tooltip) 1225 return super().hoverMoveEvent(event) 1226 1227 1228def remove_item(item: QGraphicsItem) -> None: 1229 scene = item.scene() 1230 if scene is not None: 1231 scene.removeItem(item) 1232 else: 1233 item.setParentItem(None) 1234 1235 1236class _GradientLegendAxisItem(pg.AxisItem): 1237 def boundingRect(self): 1238 br = super().boundingRect() 1239 if self.orientation in ["top", "bottom"]: 1240 # adjust brect (extend in horizontal direction). pg.AxisItem has 1241 # only fixed constant adjustment for tick text over-flow. 1242 font = self.style.get("tickFont") 1243 if font is None: 1244 font = self.font() 1245 fm = QFontMetrics(font) 1246 w = fm.horizontalAdvance('0.0000000') / 2 # bad, should use _tickValues 1247 geomw = self.geometry().size().width() 1248 maxw = max(geomw + 2 * w, br.width()) 1249 if br.width() < maxw: 1250 adjust = (maxw - br.width()) / 2 1251 br = br.adjusted(-adjust, 0, adjust, 0) 1252 return br 1253 1254 def showEvent(self, event): 1255 super().showEvent(event) 1256 # AxisItem resizes to 0 width/height when hidden, does not update when 1257 # shown implicitly (i.e. a parent becomes visible). 1258 # Use showLabel(False) which should update the size without actually 1259 # changing anything else (no public interface to explicitly recalc 1260 # fixed sizes). 1261 self.showLabel(False) 1262 1263 1264class GradientLegendWidget(QGraphicsWidget): 1265 def __init__( 1266 self, low, high, colormap: GradientColorMap, parent=None, title="", 1267 **kwargs 1268 ): 1269 kwargs.setdefault( 1270 "sizePolicy", QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) 1271 ) 1272 super().__init__(None, **kwargs) 1273 self.low = low 1274 self.high = high 1275 self.colormap = colormap 1276 self.title = title 1277 1278 layout = QGraphicsLinearLayout(Qt.Vertical) 1279 layout.setContentsMargins(0, 0, 0, 0) 1280 layout.setSpacing(0) 1281 self.setLayout(layout) 1282 if title: 1283 titleitem = SimpleLayoutItem( 1284 QGraphicsSimpleTextItem(title, self), parent=layout, 1285 anchor=(0.5, 1.), anchorItem=(0.5, 1.0) 1286 ) 1287 layout.addItem(titleitem) 1288 self.__axis = axis = _GradientLegendAxisItem( 1289 orientation="top", maxTickLength=3) 1290 axis.setRange(low, high) 1291 layout.addItem(axis) 1292 pen = QPen(self.palette().color(QPalette.Text)) 1293 axis.setPen(pen) 1294 self.__pixitem = GraphicsPixmapWidget( 1295 parent=self, scaleContents=True, aspectMode=Qt.IgnoreAspectRatio 1296 ) 1297 self.__pixitem.setSizePolicy(QSizePolicy.Ignored, QSizePolicy.Preferred) 1298 self.__pixitem.setMinimumHeight(12) 1299 layout.addItem(self.__pixitem) 1300 self.__update() 1301 1302 if parent is not None: 1303 self.setParentItem(parent) 1304 1305 def setRange(self, low, high): 1306 if self.low != low or self.high != high: 1307 self.low = low 1308 self.high = high 1309 self.__update() 1310 1311 def setColorMap(self, colormap: ColorMap) -> None: 1312 """Set the color map""" 1313 self.colormap = colormap 1314 self.__update() 1315 1316 def colorMap(self) -> ColorMap: 1317 return self.colormap 1318 1319 def __update(self): 1320 low, high = self.low, self.high 1321 data = np.linspace(low, high, num=1000).reshape((1, -1)) 1322 cmap = self.colormap.replace(span=(low, high)) 1323 qimg = qimage_from_array(cmap.apply(data)) 1324 self.__pixitem.setPixmap(QPixmap.fromImage(qimg)) 1325 if self.colormap.center is not None \ 1326 and low < self.colormap.center < high: 1327 tick_values = [low, self.colormap.center, high] 1328 else: 1329 tick_values = [low, high] 1330 tickformat = "{:.6g}".format 1331 ticks = [(val, tickformat(val)) for val in tick_values] 1332 self.__axis.setRange(low, high) 1333 self.__axis.setTicks([ticks]) 1334 1335 self.updateGeometry() 1336 1337 def changeEvent(self, event: QEvent) -> None: 1338 if event.type() == QEvent.PaletteChange: 1339 pen = QPen(self.palette().color(QPalette.Text)) 1340 self.__axis.setPen(pen) 1341 super().changeEvent(event) 1342 1343 1344class CategoricalColorLegend(QGraphicsWidget): 1345 def __init__( 1346 self, colormap: CategoricalColorMap, title="", 1347 orientation=Qt.Vertical, parent=None, **kwargs, 1348 ) -> None: 1349 self.__colormap = colormap 1350 self.__title = title 1351 self.__names = colormap.names 1352 self.__layout = QGraphicsLinearLayout(Qt.Vertical) 1353 self.__flow = GraphicsFlowLayout() 1354 self.__layout.addItem(self.__flow) 1355 self.__flow.setHorizontalSpacing(4) 1356 self.__flow.setVerticalSpacing(4) 1357 self.__orientation = orientation 1358 kwargs.setdefault( 1359 "sizePolicy", QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Maximum) 1360 ) 1361 super().__init__(None, **kwargs) 1362 self.setLayout(self.__layout) 1363 self._setup() 1364 1365 if parent is not None: 1366 self.setParent(parent) 1367 1368 def setOrientation(self, orientation): 1369 if self.__orientation != orientation: 1370 self._clear() 1371 self._setup() 1372 1373 def orientation(self): 1374 return self.__orientation 1375 1376 def _clear(self): 1377 items = list(layout_items(self.__flow)) 1378 layout_clear(self.__flow) 1379 for item in items: 1380 if isinstance(item, SimpleLayoutItem): 1381 remove_item(item.item) 1382 # remove 'title' item if present 1383 items = [item for item in layout_items(self.__layout) 1384 if item is not self.__flow] 1385 for item in items: 1386 self.__layout.removeItem(item) 1387 if isinstance(item, SimpleLayoutItem): 1388 remove_item(item.item) 1389 1390 def _setup(self): 1391 # setup the layout 1392 colors = self.__colormap.colortable 1393 names = self.__colormap.names 1394 title = self.__title 1395 layout = self.__layout 1396 flow = self.__flow 1397 assert flow.count() == 0 1398 font = self.font() 1399 fm = QFontMetrics(font) 1400 size = fm.horizontalAdvance("X") 1401 headeritem = None 1402 if title: 1403 headeritem = QGraphicsSimpleTextItem(title) 1404 headeritem.setFont(font) 1405 1406 def centered(item): 1407 return SimpleLayoutItem(item, anchor=(0.5, 0.5), anchorItem=(0.5, 0.5)) 1408 1409 def legend_item_pair(color: QColor, size: float, text: str): 1410 coloritem = QGraphicsRectItem(0, 0, size, size) 1411 coloritem.setBrush(color) 1412 textitem = QGraphicsSimpleTextItem() 1413 textitem.setFont(font) 1414 textitem.setText(text) 1415 layout = QGraphicsLinearLayout(Qt.Horizontal) 1416 layout.setSpacing(2) 1417 layout.addItem(centered(coloritem)) 1418 layout.addItem(SimpleLayoutItem(textitem)) 1419 return coloritem, textitem, layout 1420 1421 items = [legend_item_pair(QColor(*color), size, name) 1422 for color, name in zip(colors, names)] 1423 1424 for sym, label, pair_layout in items: 1425 flow.addItem(pair_layout) 1426 1427 if headeritem: 1428 layout.insertItem(0, centered(headeritem)) 1429 1430 def changeEvent(self, event: QEvent) -> None: 1431 if event.type() == QEvent.FontChange: 1432 self._updateFont(self.font()) 1433 super().changeEvent(event) 1434 1435 def _updateFont(self, font): 1436 w = QFontMetrics(font).horizontalAdvance("X") 1437 for item in filter( 1438 lambda item: isinstance(item, SimpleLayoutItem), 1439 layout_items_recursive(self.__layout) 1440 ): 1441 if isinstance(item.item, QGraphicsSimpleTextItem): 1442 item.item.setFont(font) 1443 elif isinstance(item.item, QGraphicsRectItem): 1444 item.item.setRect(QRectF(0, 0, w, w)) 1445 item.updateGeometry() 1446 1447 1448def layout_items(layout: QGraphicsLayout) -> Iterable[QGraphicsLayoutItem]: 1449 for item in map(layout.itemAt, range(layout.count())): 1450 if item is not None: 1451 yield item 1452 1453 1454def layout_items_recursive(layout: QGraphicsLayout): 1455 for item in map(layout.itemAt, range(layout.count())): 1456 if item is not None: 1457 if item.isLayout(): 1458 assert isinstance(item, QGraphicsLayout) 1459 yield from layout_items_recursive(item) 1460 else: 1461 yield item 1462 1463 1464def layout_clear(layout: QGraphicsLayout) -> None: 1465 for i in reversed(range(layout.count())): 1466 item = layout.itemAt(i) 1467 layout.removeAt(i) 1468 if item is not None and item.graphicsItem() is not None: 1469 remove_item(item.graphicsItem()) 1470 1471 1472class SelectionManager(QObject): 1473 """ 1474 Selection manager for heatmap rows 1475 """ 1476 selection_changed = Signal() 1477 selection_finished = Signal() 1478 1479 def __init__(self, parent=None, **kwargs): 1480 super().__init__(parent, **kwargs) 1481 self.selections = [] 1482 self.selection_ranges = [] 1483 self.selection_ranges_temp = [] 1484 self.selection_rects = [] 1485 self.heatmaps = [] 1486 self._heatmap_ranges: Dict[GraphicsHeatmapWidget, Tuple[int, int]] = {} 1487 self._start_row = 0 1488 1489 def clear(self): 1490 self.remove_rows(self.selection) 1491 1492 def set_heatmap_widgets(self, widgets): 1493 # type: (Sequence[Sequence[GraphicsHeatmapWidget]] )-> None 1494 self.remove_rows(self.selections) 1495 self.heatmaps = list(zip(*widgets)) 1496 1497 # Compute row ranges for all heatmaps 1498 self._heatmap_ranges = {} 1499 for group in zip(*widgets): 1500 start = end = 0 1501 for heatmap in group: 1502 end += heatmap.heatmapData().shape[0] 1503 self._heatmap_ranges[heatmap] = (start, end) 1504 start = end 1505 1506 def select_rows(self, rows, heatmap=None, clear=True): 1507 """Add `rows` to selection. If `heatmap` is provided the rows 1508 are mapped from the local indices to global heatmap indices. If `clear` 1509 then remove previous rows. 1510 """ 1511 if heatmap is not None: 1512 start, _ = self._heatmap_ranges[heatmap] 1513 rows = [start + r for r in rows] 1514 1515 old_selection = list(self.selections) 1516 if clear: 1517 self.selections = rows 1518 else: 1519 self.selections = sorted(set(self.selections + rows)) 1520 1521 if self.selections != old_selection: 1522 self.update_selection_rects() 1523 self.selection_changed.emit() 1524 1525 def remove_rows(self, rows): 1526 """Remove `rows` from the selection. 1527 """ 1528 old_selection = list(self.selections) 1529 self.selections = sorted(set(self.selections) - set(rows)) 1530 if old_selection != self.selections: 1531 self.update_selection_rects() 1532 self.selection_changed.emit() 1533 1534 def combined_ranges(self, ranges): 1535 combined_ranges = set() 1536 for start, end in ranges: 1537 if start <= end: 1538 rng = range(start, end + 1) 1539 else: 1540 rng = range(start, end - 1, -1) 1541 combined_ranges.update(rng) 1542 return sorted(combined_ranges) 1543 1544 def selection_start(self, heatmap_widget, event): 1545 """ Selection started by `heatmap_widget` due to `event`. 1546 """ 1547 pos = heatmap_widget.mapFromScene(event.scenePos()) 1548 row, _ = heatmap_widget.heatmapCellAt(pos) 1549 1550 start, _ = self._heatmap_ranges[heatmap_widget] 1551 row = start + row 1552 self._start_row = row 1553 range = (row, row) 1554 self.selection_ranges_temp = [] 1555 if event.modifiers() & Qt.ControlModifier: 1556 self.selection_ranges_temp = self.selection_ranges 1557 self.selection_ranges = self.remove_range( 1558 self.selection_ranges, row, row, append=True) 1559 elif event.modifiers() & Qt.ShiftModifier: 1560 self.selection_ranges.append(range) 1561 elif event.modifiers() & Qt.AltModifier: 1562 self.selection_ranges = self.remove_range( 1563 self.selection_ranges, row, row, append=False) 1564 else: 1565 self.selection_ranges = [range] 1566 self.select_rows(self.combined_ranges(self.selection_ranges)) 1567 1568 def selection_update(self, heatmap_widget, event): 1569 """ Selection updated by `heatmap_widget due to `event` (mouse drag). 1570 """ 1571 pos = heatmap_widget.mapFromScene(event.scenePos()) 1572 row, _ = heatmap_widget.heatmapCellAt(pos) 1573 if row < 0: 1574 return 1575 1576 start, _ = self._heatmap_ranges[heatmap_widget] 1577 row = start + row 1578 if event.modifiers() & Qt.ControlModifier: 1579 self.selection_ranges = self.remove_range( 1580 self.selection_ranges_temp, self._start_row, row, append=True) 1581 elif event.modifiers() & Qt.AltModifier: 1582 self.selection_ranges = self.remove_range( 1583 self.selection_ranges, self._start_row, row, append=False) 1584 else: 1585 if self.selection_ranges: 1586 self.selection_ranges[-1] = (self._start_row, row) 1587 else: 1588 self.selection_ranges = [(row, row)] 1589 1590 self.select_rows(self.combined_ranges(self.selection_ranges)) 1591 1592 def selection_finish(self, heatmap_widget, event): 1593 """ Selection finished by `heatmap_widget due to `event`. 1594 """ 1595 if heatmap_widget is not None: 1596 pos = heatmap_widget.mapFromScene(event.scenePos()) 1597 row, _ = heatmap_widget.heatmapCellAt(pos) 1598 start, _ = self._heatmap_ranges[heatmap_widget] 1599 row = start + row 1600 if event.modifiers() & Qt.ControlModifier: 1601 pass 1602 elif event.modifiers() & Qt.AltModifier: 1603 self.selection_ranges = self.remove_range( 1604 self.selection_ranges, self._start_row, row, append=False) 1605 else: 1606 if len(self.selection_ranges) > 0: 1607 self.selection_ranges[-1] = (self._start_row, row) 1608 self.select_rows(self.combined_ranges(self.selection_ranges)) 1609 self.selection_finished.emit() 1610 1611 def selection_add(self, start, end, heatmap=None, clear=True, 1612 remove=False, append=False): 1613 """ Add/remove a selection range from `start` to `end`. 1614 """ 1615 if heatmap is not None: 1616 _start, _ = self._heatmap_ranges[heatmap] 1617 start = _start + start 1618 end = _start + end 1619 1620 if clear: 1621 self.selection_ranges = [] 1622 if remove: 1623 self.selection_ranges = self.remove_range( 1624 self.selection_ranges, start, end, append=append) 1625 else: 1626 self.selection_ranges.append((start, end)) 1627 self.select_rows(self.combined_ranges(self.selection_ranges)) 1628 self.selection_finished.emit() 1629 1630 def remove_range(self, ranges, start, end, append=False): 1631 if start > end: 1632 start, end = end, start 1633 comb_ranges = [i for i in self.combined_ranges(ranges) 1634 if i > end or i < start] 1635 if append: 1636 comb_ranges += [i for i in range(start, end + 1) 1637 if i not in self.combined_ranges(ranges)] 1638 comb_ranges = sorted(comb_ranges) 1639 return self.combined_to_ranges(comb_ranges) 1640 1641 def combined_to_ranges(self, comb_ranges): 1642 ranges = [] 1643 if len(comb_ranges) > 0: 1644 i, start, end = 0, comb_ranges[0], comb_ranges[0] 1645 for val in comb_ranges[1:]: 1646 i += 1 1647 if start + i < val: 1648 ranges.append((start, end)) 1649 i, start = 0, val 1650 end = val 1651 ranges.append((start, end)) 1652 return ranges 1653 1654 def update_selection_rects(self): 1655 """ Update the selection rects. 1656 """ 1657 def group_selections(selections): 1658 """Group selections along with heatmaps. 1659 """ 1660 rows2hm = self.rows_to_heatmaps() 1661 selections = iter(selections) 1662 try: 1663 start = end = next(selections) 1664 except StopIteration: 1665 return 1666 end_heatmaps = rows2hm[end] 1667 try: 1668 while True: 1669 new_end = next(selections) 1670 new_end_heatmaps = rows2hm[new_end] 1671 if new_end > end + 1 or new_end_heatmaps != end_heatmaps: 1672 yield start, end, end_heatmaps 1673 start = end = new_end 1674 end_heatmaps = new_end_heatmaps 1675 else: 1676 end = new_end 1677 1678 except StopIteration: 1679 yield start, end, end_heatmaps 1680 1681 def selection_rect(start, end, heatmaps): 1682 rect = QRectF() 1683 for heatmap in heatmaps: 1684 h_start, _ = self._heatmap_ranges[heatmap] 1685 rect |= heatmap.mapToScene(heatmap.rowRect(start - h_start)).boundingRect() 1686 rect |= heatmap.mapToScene(heatmap.rowRect(end - h_start)).boundingRect() 1687 return rect 1688 1689 self.selection_rects = [] 1690 for start, end, heatmaps in group_selections(self.selections): 1691 rect = selection_rect(start, end, heatmaps) 1692 self.selection_rects.append(rect) 1693 1694 def rows_to_heatmaps(self): 1695 heatmap_groups = zip(*self.heatmaps) 1696 rows2hm = {} 1697 for heatmaps in heatmap_groups: 1698 hm = heatmaps[0] 1699 start, end = self._heatmap_ranges[hm] 1700 rows2hm.update(dict.fromkeys(range(start, end), heatmaps)) 1701 return rows2hm 1702 1703 1704Parts = HeatmapGridWidget.Parts 1705RowItem = HeatmapGridWidget.RowItem 1706ColumnItem = HeatmapGridWidget.ColumnItem 1707