import enum
from collections import defaultdict
from itertools import islice
from typing import (
    Iterable, Mapping, Any, TypeVar, NamedTuple, Sequence, Optional,
    Union, Tuple, List, Callable
)

import numpy as np
import scipy.sparse as sp

from AnyQt.QtWidgets import (
    QGraphicsScene, QGraphicsView, QFormLayout, QComboBox, QGroupBox,
    QMenu, QAction, QSizePolicy
)
from AnyQt.QtGui import QStandardItemModel, QStandardItem, QFont, QKeySequence
from AnyQt.QtCore import Qt, QSize, QRectF, QObject

from orangewidget.utils.combobox import ComboBox, ComboBoxSearch
from Orange.data import Domain, Table, Variable, DiscreteVariable, \
    ContinuousVariable
from Orange.data.sql.table import SqlTable
import Orange.distance

from Orange.clustering import hierarchical, kmeans
from Orange.widgets.utils import colorpalettes, apply_all, enum_get, itemmodels
from Orange.widgets.utils.itemmodels import DomainModel
from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView
from Orange.widgets.utils.graphicsview import GraphicsWidgetView
from Orange.widgets.utils.colorpalettes import Palette

from Orange.widgets.utils.annotated_data import (create_annotated_table,
                                                 ANNOTATED_DATA_SIGNAL_NAME)
from Orange.widgets import widget, gui, settings
from Orange.widgets.widget import Msg, Input, Output

from Orange.widgets.data.oweditdomain import table_column_data
from Orange.widgets.visualize.utils.heatmap import HeatmapGridWidget, \
    ColorMap, CategoricalColorMap, GradientColorMap
from Orange.widgets.utils.colorgradientselection import ColorGradientSelection
from Orange.widgets.utils.widgetpreview import WidgetPreview


__all__ = []


def kmeans_compress(X, k=50):
    """
    Fit a k-means model with `k` clusters on `X` and return the model.

    The heat map uses this to merge similar rows. The caller reads the
    returned model's ``labels``, ``centroids`` and ``domain``.

    Parameters
    ----------
    X : Table
        Data to cluster. The widget passes a Table that is restricted to
        its continuous attributes.
    k : int
        Number of clusters to fit.
    """
    # `n_init` and `random_state` are fixed, so re-merging the same
    # data should give the same grouping.
    km = kmeans.KMeans(n_clusters=k, n_init=5, random_state=42)
    return km.get_model(X)


def split_domain(
        domain: Domain, split_label: str
) -> List[Tuple[Any, List[Variable]]]:
    """
    Group the domain's attributes by the value of their `split_label` label.

    Parameters
    ----------
    domain : Domain
        Domain whose ``attributes`` are grouped.
    split_label : str
        Key that is looked up in each attribute's ``attributes`` dict.

    Returns
    -------
    groups : list of (value, list of Variable)
        One ``(value, variables)`` pair per distinct label value, in the
        order the values are first seen.

        Attributes without the label, or whose label value is None, go
        into a final ``("N/A", [...])`` group. That group is omitted when
        no such attributes exist.
    """
    groups = defaultdict(list)
    for var in domain.attributes:
        val = var.attributes.get(split_label)
        groups[val].append(var)
    if None in groups:
        # Move the unlabeled attributes to the end under a readable title.
        na = groups.pop(None)
        return [*groups.items(), ("N/A", na)]
    else:
        return list(groups.items())


def cbselect(cb: QComboBox, value, role: Qt.ItemDataRole = Qt.EditRole) -> None:
    """
    Find and select the `value` in the `cb` QComboBox.

    If `value` is not found, `findData` returns -1. The combo box then
    ends up with no current item.

    Parameters
    ----------
    cb: QComboBox
    value: Any
    role: Qt.ItemDataRole
        The data role in the combo box model to match value against
    """
    cb.setCurrentIndex(cb.findData(value, role))


class Clustering(enum.IntEnum):
    #: No clustering
    None_ = 0
    #: Hierarchical clustering
    Clustering = 1
    #: Hierarchical clustering with optimal leaf ordering
    OrderedClustering = 2


#: Custom item-data role that stores the `Clustering` member of each
#: entry in the clustering method combo boxes.
ClusteringRole = Qt.UserRole + 13
#: Item data for clustering method selection models
ClusteringModelData = [
    {
        Qt.DisplayRole: "None",
        Qt.ToolTipRole: "No clustering",
        ClusteringRole: Clustering.None_,
    }, {
        Qt.DisplayRole: "Clustering",
        Qt.ToolTipRole: "Apply hierarchical clustering",
        ClusteringRole: Clustering.Clustering,
    }, {
        Qt.DisplayRole: "Clustering (opt. ordering)",
        Qt.ToolTipRole: "Apply hierarchical clustering with optimal leaf "
                        "ordering.",
        ClusteringRole: Clustering.OrderedClustering,
    }
]

#: Item data for the column-label position combo box. Qt.UserRole holds
#: the HeatmapGridWidget position flag(s) for each choice.
ColumnLabelsPosData = [
    {Qt.DisplayRole: name, Qt.UserRole: value}
    for name, value in [
        ("None", HeatmapGridWidget.NoPosition),
        ("Top", HeatmapGridWidget.PositionTop),
        ("Bottom", HeatmapGridWidget.PositionBottom),
        ("Top and Bottom", (HeatmapGridWidget.PositionTop |
                            HeatmapGridWidget.PositionBottom)),
    ]
]


def create_list_model(
        items: Iterable[Mapping[Qt.ItemDataRole, Any]],
        parent: Optional[QObject] = None,
) -> QStandardItemModel:
    """
    Create a list model from an item data iterable.

    Each mapping in `items` becomes one row. Every ``role: value`` pair
    in the mapping is set as item data for that role.
    """
    model = QStandardItemModel(parent)
    for item in items:
        sitem = QStandardItem()
        for role, value in item.items():
            sitem.setData(value, role)
        model.appendRow([sitem])
    return model


class OWHeatMap(widget.OWWidget):
    name = "Heat Map"
    description = "Plot a data matrix heatmap."
137 icon = "icons/Heatmap.svg" 138 priority = 260 139 keywords = [] 140 141 class Inputs: 142 data = Input("Data", Table) 143 144 class Outputs: 145 selected_data = Output("Selected Data", Table, default=True) 146 annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) 147 148 settings_version = 3 149 150 settingsHandler = settings.DomainContextHandler() 151 152 # Disable clustering for inputs bigger than this 153 MaxClustering = 25000 154 # Disable cluster leaf ordering for inputs bigger than this 155 MaxOrderedClustering = 1000 156 157 threshold_low = settings.Setting(0.0) 158 threshold_high = settings.Setting(1.0) 159 color_center = settings.Setting(0) 160 161 merge_kmeans = settings.Setting(False) 162 merge_kmeans_k = settings.Setting(50) 163 164 # Display column with averages 165 averages: bool = settings.Setting(True) 166 # Display legend 167 legend: bool = settings.Setting(True) 168 # Annotations 169 #: text row annotation (row names) 170 annotation_var = settings.ContextSetting(None) 171 #: color row annotation 172 annotation_color_var = settings.ContextSetting(None) 173 column_annotation_color_key: Optional[Tuple[str, str]] = settings.ContextSetting(None) 174 175 # Discrete variable used to split that data/heatmaps (vertically) 176 split_by_var = settings.ContextSetting(None) 177 # Split heatmap columns by 'key' (horizontal) 178 split_columns_key: Optional[Tuple[str, str]] = settings.ContextSetting(None) 179 # Selected row/column clustering method (name) 180 col_clustering_method: str = settings.Setting(Clustering.None_.name) 181 row_clustering_method: str = settings.Setting(Clustering.None_.name) 182 183 palette_name = settings.Setting(colorpalettes.DefaultContinuousPaletteName) 184 column_label_pos: int = settings.Setting(1) 185 selected_rows: List[int] = settings.Setting(None, schema_only=True) 186 187 auto_commit = settings.Setting(True) 188 189 graph_name = "scene" 190 191 class Information(widget.OWWidget.Information): 192 sampled = Msg("Data has 
been sampled") 193 discrete_ignored = Msg("{} categorical feature{} ignored") 194 row_clust = Msg("{}") 195 col_clust = Msg("{}") 196 sparse_densified = Msg("Showing this data may require a lot of memory") 197 198 class Error(widget.OWWidget.Error): 199 no_continuous = Msg("No numeric features") 200 not_enough_features = Msg("Not enough features for column clustering") 201 not_enough_instances = Msg("Not enough instances for clustering") 202 not_enough_instances_k_means = Msg( 203 "Not enough instances for k-means merging") 204 not_enough_memory = Msg("Not enough memory to show this data") 205 206 class Warning(widget.OWWidget.Warning): 207 empty_clusters = Msg("Empty clusters were removed") 208 209 UserAdviceMessages = [ 210 widget.Message( 211 "For data with a meaningful mid-point, " 212 "choose one of diverging palettes.", 213 "diverging_palette")] 214 215 def __init__(self): 216 super().__init__() 217 self.__pending_selection = self.selected_rows 218 219 # A kingdom for a save_state/restore_state 220 self.col_clustering = enum_get( 221 Clustering, self.col_clustering_method, Clustering.None_) 222 self.row_clustering = enum_get( 223 Clustering, self.row_clustering_method, Clustering.None_) 224 225 self.settingsAboutToBePacked.connect(self._save_state_for_serialization) 226 self.keep_aspect = False 227 228 #: The original data with all features (retained to 229 #: preserve the domain on the output) 230 self.input_data = None 231 #: The effective data striped of discrete features, and often 232 #: merged using k-means 233 self.data = None 234 self.effective_data = None 235 #: Source of column annotations (derived from self.data) 236 self.col_annot_data: Optional[Table] = None 237 #: kmeans model used to merge rows of input_data 238 self.kmeans_model = None 239 #: merge indices derived from kmeans 240 #: a list (len==k) of int ndarray where the i-th item contains 241 #: the indices which merge the input_data into the heatmap row i 242 self.merge_indices = None 243 
self.parts: Optional[Parts] = None 244 self.__rows_cache = {} 245 self.__columns_cache = {} 246 247 # GUI definition 248 colorbox = gui.vBox(self.controlArea, "Color") 249 250 self.color_map_widget = cmw = ColorGradientSelection( 251 thresholds=(self.threshold_low, self.threshold_high), 252 center=self.color_center 253 ) 254 model = itemmodels.ContinuousPalettesModel(parent=self) 255 cmw.setModel(model) 256 idx = cmw.findData(self.palette_name, model.KeyRole) 257 if idx != -1: 258 cmw.setCurrentIndex(idx) 259 260 cmw.activated.connect(self.update_color_schema) 261 262 def _set_thresholds(low, high): 263 self.threshold_low, self.threshold_high = low, high 264 self.update_color_schema() 265 cmw.thresholdsChanged.connect(_set_thresholds) 266 267 def _set_centering(center): 268 self.color_center = center 269 self.update_color_schema() 270 cmw.centerChanged.connect(_set_centering) 271 272 colorbox.layout().addWidget(self.color_map_widget) 273 274 mergebox = gui.vBox(self.controlArea, "Merge",) 275 gui.checkBox(mergebox, self, "merge_kmeans", "Merge by k-means", 276 callback=self.__update_row_clustering) 277 ibox = gui.indentedBox(mergebox) 278 gui.spin(ibox, self, "merge_kmeans_k", minv=5, maxv=500, 279 label="Clusters:", keyboardTracking=False, 280 callbackOnReturn=True, callback=self.update_merge) 281 282 cluster_box = gui.vBox(self.controlArea, "Clustering") 283 # Row clustering 284 self.row_cluster_cb = cb = ComboBox() 285 cb.setModel(create_list_model(ClusteringModelData, self)) 286 cbselect(cb, self.row_clustering, ClusteringRole) 287 self.connect_control( 288 "row_clustering", 289 lambda value, cb=cb: cbselect(cb, value, ClusteringRole) 290 ) 291 @cb.activated.connect 292 def _(idx, cb=cb): 293 self.set_row_clustering(cb.itemData(idx, ClusteringRole)) 294 295 # Column clustering 296 self.col_cluster_cb = cb = ComboBox() 297 cb.setModel(create_list_model(ClusteringModelData, self)) 298 cbselect(cb, self.col_clustering, ClusteringRole) 299 self.connect_control( 300 
"col_clustering", 301 lambda value, cb=cb: cbselect(cb, value, ClusteringRole) 302 ) 303 @cb.activated.connect 304 def _(idx, cb=cb): 305 self.set_col_clustering(cb.itemData(idx, ClusteringRole)) 306 307 form = QFormLayout( 308 labelAlignment=Qt.AlignLeft, formAlignment=Qt.AlignLeft, 309 fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow, 310 ) 311 form.addRow("Rows:", self.row_cluster_cb) 312 form.addRow("Columns:", self.col_cluster_cb) 313 cluster_box.layout().addLayout(form) 314 box = gui.vBox(self.controlArea, "Split By") 315 form = QFormLayout( 316 formAlignment=Qt.AlignLeft, labelAlignment=Qt.AlignLeft, 317 fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow, 318 ) 319 box.layout().addLayout(form) 320 321 self.row_split_model = DomainModel( 322 placeholder="(None)", 323 valid_types=(Orange.data.DiscreteVariable,), 324 parent=self, 325 ) 326 self.row_split_cb = cb = ComboBoxSearch( 327 enabled=not self.merge_kmeans, 328 sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon, 329 minimumContentsLength=14, 330 toolTip="Split the heatmap vertically by a categorical column" 331 ) 332 self.row_split_cb.setModel(self.row_split_model) 333 self.connect_control( 334 "split_by_var", lambda value, cb=cb: cbselect(cb, value) 335 ) 336 self.connect_control( 337 "merge_kmeans", self.row_split_cb.setDisabled 338 ) 339 self.split_by_var = None 340 341 self.row_split_cb.activated.connect( 342 self.__on_split_rows_activated 343 ) 344 self.col_split_model = DomainModel( 345 placeholder="(None)", 346 order=DomainModel.MIXED, 347 valid_types=(Orange.data.DiscreteVariable,), 348 parent=self, 349 ) 350 self.col_split_cb = cb = ComboBoxSearch( 351 sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon, 352 minimumContentsLength=14, 353 toolTip="Split the heatmap horizontally by column annotation" 354 ) 355 self.col_split_cb.setModel(self.col_split_model) 356 self.connect_control( 357 "split_columns_var", lambda value, cb=cb: cbselect(cb, value) 358 ) 359 
self.split_columns_var = None 360 self.col_split_cb.activated.connect(self.__on_split_cols_activated) 361 form.addRow("Rows:", self.row_split_cb) 362 form.addRow("Columns:", self.col_split_cb) 363 364 box = gui.vBox(self.controlArea, 'Annotation && Legends') 365 366 gui.checkBox(box, self, 'legend', 'Show legend', 367 callback=self.update_legend) 368 369 gui.checkBox(box, self, 'averages', 'Stripes with averages', 370 callback=self.update_averages_stripe) 371 gui.separator(box) 372 annotbox = QGroupBox("Row Annotations") 373 form = QFormLayout( 374 annotbox, 375 formAlignment=Qt.AlignLeft, 376 labelAlignment=Qt.AlignLeft, 377 fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow 378 ) 379 self.annotation_model = DomainModel(placeholder="(None)") 380 self.annotation_text_cb = ComboBoxSearch( 381 minimumContentsLength=12, 382 sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon 383 ) 384 self.annotation_text_cb.setModel(self.annotation_model) 385 self.annotation_text_cb.activated.connect(self.set_annotation_var) 386 self.connect_control("annotation_var", self.annotation_var_changed) 387 388 self.row_side_color_model = DomainModel( 389 order=(DomainModel.CLASSES, DomainModel.Separator, 390 DomainModel.METAS), 391 placeholder="(None)", valid_types=DomainModel.PRIMITIVE, 392 flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled, 393 parent=self, 394 ) 395 self.row_side_color_cb = ComboBoxSearch( 396 sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon, 397 minimumContentsLength=12 398 ) 399 self.row_side_color_cb.setModel(self.row_side_color_model) 400 self.row_side_color_cb.activated.connect(self.set_annotation_color_var) 401 self.connect_control("annotation_color_var", self.annotation_color_var_changed) 402 form.addRow("Text", self.annotation_text_cb) 403 form.addRow("Color", self.row_side_color_cb) 404 box.layout().addWidget(annotbox) 405 annotbox = QGroupBox("Column annotations") 406 form = QFormLayout( 407 annotbox, 408 formAlignment=Qt.AlignLeft, 409 
labelAlignment=Qt.AlignLeft, 410 fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow 411 ) 412 self.col_side_color_model = DomainModel( 413 placeholder="(None)", 414 valid_types=(DiscreteVariable, ContinuousVariable), 415 parent=self 416 ) 417 self.col_side_color_cb = cb = ComboBoxSearch( 418 sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon, 419 minimumContentsLength=12 420 ) 421 self.col_side_color_cb.setModel(self.col_side_color_model) 422 self.connect_control( 423 "column_annotation_color_var", self.column_annotation_color_var_changed, 424 ) 425 self.column_annotation_color_var = None 426 self.col_side_color_cb.activated.connect( 427 self.__set_column_annotation_color_var_index) 428 429 cb = gui.comboBox( 430 None, self, "column_label_pos", 431 callback=self.update_column_annotations) 432 cb.setModel(create_list_model(ColumnLabelsPosData, parent=self)) 433 cb.setCurrentIndex(self.column_label_pos) 434 form.addRow("Position", cb) 435 form.addRow("Color", self.col_side_color_cb) 436 box.layout().addWidget(annotbox) 437 438 gui.checkBox(self.controlArea, self, "keep_aspect", 439 "Keep aspect ratio", box="Resize", 440 callback=self.__aspect_mode_changed) 441 442 gui.rubber(self.controlArea) 443 444 gui.auto_send(self.buttonsArea, self, "auto_commit") 445 446 # Scene with heatmap 447 class HeatmapScene(QGraphicsScene): 448 widget: Optional[HeatmapGridWidget] = None 449 450 self.scene = self.scene = HeatmapScene(parent=self) 451 self.view = GraphicsView( 452 self.scene, 453 verticalScrollBarPolicy=Qt.ScrollBarAlwaysOn, 454 horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn, 455 viewportUpdateMode=QGraphicsView.FullViewportUpdate, 456 widgetResizable=True, 457 ) 458 self.view.setContextMenuPolicy(Qt.CustomContextMenu) 459 self.view.customContextMenuRequested.connect( 460 self._on_view_context_menu 461 ) 462 self.mainArea.layout().addWidget(self.view) 463 self.selected_rows = [] 464 self.__font_inc = QAction( 465 "Increase Font", self, 
shortcut=QKeySequence("ctrl+>")) 466 self.__font_dec = QAction( 467 "Decrease Font", self, shortcut=QKeySequence("ctrl+<")) 468 self.__font_inc.triggered.connect(lambda: self.__adjust_font_size(1)) 469 self.__font_dec.triggered.connect(lambda: self.__adjust_font_size(-1)) 470 if hasattr(QAction, "setShortcutVisibleInContextMenu"): 471 apply_all( 472 [self.__font_inc, self.__font_dec], 473 lambda a: a.setShortcutVisibleInContextMenu(True) 474 ) 475 self.addActions([self.__font_inc, self.__font_dec]) 476 477 def _save_state_for_serialization(self): 478 def desc(var: Optional[Variable]) -> Optional[Tuple[str, str]]: 479 if var is not None: 480 return type(var).__name__, var.name 481 else: 482 return None 483 484 self.col_clustering_method = self.col_clustering.name 485 self.row_clustering_method = self.row_clustering.name 486 487 self.column_annotation_color_key = desc(self.column_annotation_color_var) 488 self.split_columns_key = desc(self.split_columns_var) 489 490 @property 491 def center_palette(self): 492 palette = self.color_map_widget.currentData() 493 return bool(palette.flags & palette.Diverging) 494 495 @property 496 def _column_label_pos(self) -> HeatmapGridWidget.Position: 497 return ColumnLabelsPosData[self.column_label_pos][Qt.UserRole] 498 499 def annotation_color_var_changed(self, value): 500 cbselect(self.row_side_color_cb, value, Qt.EditRole) 501 502 def annotation_var_changed(self, value): 503 cbselect(self.annotation_text_cb, value, Qt.EditRole) 504 505 def set_row_clustering(self, method: Clustering) -> None: 506 assert isinstance(method, Clustering) 507 if self.row_clustering != method: 508 self.row_clustering = method 509 cbselect(self.row_cluster_cb, method, ClusteringRole) 510 self.__update_row_clustering() 511 512 def set_col_clustering(self, method: Clustering) -> None: 513 assert isinstance(method, Clustering) 514 if self.col_clustering != method: 515 self.col_clustering = method 516 cbselect(self.col_cluster_cb, method, ClusteringRole) 517 
self.__update_column_clustering() 518 519 def sizeHint(self) -> QSize: 520 return super().sizeHint().expandedTo(QSize(900, 700)) 521 522 def color_palette(self): 523 return self.color_map_widget.currentData().lookup_table() 524 525 def color_map(self) -> GradientColorMap: 526 return GradientColorMap( 527 self.color_palette(), (self.threshold_low, self.threshold_high), 528 self.color_map_widget.center() if self.center_palette else None 529 ) 530 531 def clear(self): 532 self.data = None 533 self.input_data = None 534 self.effective_data = None 535 self.kmeans_model = None 536 self.merge_indices = None 537 self.annotation_model.set_domain(None) 538 self.annotation_var = None 539 self.row_side_color_model.set_domain(None) 540 self.col_side_color_model.set_domain(None) 541 self.annotation_color_var = None 542 self.column_annotation_color_var = None 543 self.row_split_model.set_domain(None) 544 self.col_split_model.set_domain(None) 545 self.split_by_var = None 546 self.split_columns_var = None 547 self.parts = None 548 self.clear_scene() 549 self.selected_rows = [] 550 self.__columns_cache.clear() 551 self.__rows_cache.clear() 552 self.__update_clustering_enable_state(None) 553 554 def clear_scene(self): 555 if self.scene.widget is not None: 556 self.scene.widget.layoutDidActivate.disconnect( 557 self.__on_layout_activate 558 ) 559 self.scene.widget.selectionFinished.disconnect( 560 self.on_selection_finished 561 ) 562 self.scene.widget = None 563 self.scene.clear() 564 565 self.view.setSceneRect(QRectF()) 566 self.view.setHeaderSceneRect(QRectF()) 567 self.view.setFooterSceneRect(QRectF()) 568 569 @Inputs.data 570 def set_dataset(self, data=None): 571 """Set the input dataset to display.""" 572 self.closeContext() 573 self.clear() 574 self.clear_messages() 575 576 if isinstance(data, SqlTable): 577 if data.approx_len() < 4000: 578 data = Table(data) 579 else: 580 self.Information.sampled() 581 data_sample = data.sample_time(1, no_cache=True) 582 
data_sample.download_data(2000, partial=True) 583 data = Table(data_sample) 584 585 if data is not None and not len(data): 586 data = None 587 588 if data is not None and sp.issparse(data.X): 589 try: 590 data = data.to_dense() 591 except MemoryError: 592 data = None 593 self.Error.not_enough_memory() 594 else: 595 self.Information.sparse_densified() 596 597 input_data = data 598 599 # Data contains no attributes or meta attributes only 600 if data is not None and len(data.domain.attributes) == 0: 601 self.Error.no_continuous() 602 input_data = data = None 603 604 # Data contains some discrete attributes which must be filtered 605 if data is not None and \ 606 any(var.is_discrete for var in data.domain.attributes): 607 ndisc = sum(var.is_discrete for var in data.domain.attributes) 608 data = data.transform( 609 Domain([var for var in data.domain.attributes 610 if var.is_continuous], 611 data.domain.class_vars, 612 data.domain.metas)) 613 if not data.domain.attributes: 614 self.Error.no_continuous() 615 input_data = data = None 616 else: 617 self.Information.discrete_ignored( 618 ndisc, "s" if ndisc > 1 else "") 619 620 self.data = data 621 self.input_data = input_data 622 623 if data is not None: 624 self.annotation_model.set_domain(self.input_data.domain) 625 self.row_side_color_model.set_domain(self.input_data.domain) 626 self.annotation_var = None 627 self.annotation_color_var = None 628 self.row_split_model.set_domain(data.domain) 629 self.col_annot_data = data.transpose(data[:0].transform(Domain(data.domain.attributes))) 630 self.col_split_model.set_domain(self.col_annot_data.domain) 631 self.col_side_color_model.set_domain(self.col_annot_data.domain) 632 if data.domain.has_discrete_class: 633 self.split_by_var = data.domain.class_var 634 else: 635 self.split_by_var = None 636 self.split_columns_var = None 637 self.column_annotation_color_var = None 638 self.openContext(self.input_data) 639 if self.split_by_var not in self.row_split_model: 640 
self.split_by_var = None 641 642 def match(desc: Tuple[str, str], source: Iterable[Variable]): 643 for v in source: 644 if desc == (type(v).__name__, v.name): 645 return v 646 return None 647 648 def is_variable(obj): 649 return isinstance(obj, Variable) 650 651 if self.split_columns_key is not None: 652 self.split_columns_var = match( 653 self.split_columns_key, 654 filter(is_variable, self.col_split_model) 655 ) 656 657 if self.column_annotation_color_key is not None: 658 self.column_annotation_color_var = match( 659 self.column_annotation_color_key, 660 filter(is_variable, self.col_side_color_model) 661 ) 662 663 self.update_heatmaps() 664 if data is not None and self.__pending_selection is not None: 665 assert self.scene.widget is not None 666 self.scene.widget.selectRows(self.__pending_selection) 667 self.selected_rows = self.__pending_selection 668 self.__pending_selection = None 669 670 self.unconditional_commit() 671 672 def __on_split_rows_activated(self): 673 self.set_split_variable(self.row_split_cb.currentData(Qt.EditRole)) 674 675 def set_split_variable(self, var): 676 if var is not self.split_by_var: 677 self.split_by_var = var 678 self.update_heatmaps() 679 680 def __on_split_cols_activated(self): 681 self.set_column_split_var(self.col_split_cb.currentData(Qt.EditRole)) 682 683 def set_column_split_var(self, var: Optional[Variable]): 684 if var is not self.split_columns_var: 685 self.split_columns_var = var 686 self.update_heatmaps() 687 688 def update_heatmaps(self): 689 if self.data is not None: 690 self.clear_scene() 691 self.clear_messages() 692 if self.col_clustering != Clustering.None_ and \ 693 len(self.data.domain.attributes) < 2: 694 self.Error.not_enough_features() 695 elif (self.col_clustering != Clustering.None_ or 696 self.row_clustering != Clustering.None_) and \ 697 len(self.data) < 2: 698 self.Error.not_enough_instances() 699 elif self.merge_kmeans and len(self.data) < 3: 700 self.Error.not_enough_instances_k_means() 701 else: 702 
parts = self.construct_heatmaps(self.data, self.split_by_var, self.split_columns_var) 703 self.construct_heatmaps_scene(parts, self.effective_data) 704 self.selected_rows = [] 705 else: 706 self.clear() 707 708 def update_merge(self): 709 self.kmeans_model = None 710 self.merge_indices = None 711 if self.data is not None and self.merge_kmeans: 712 self.update_heatmaps() 713 self.commit() 714 715 def _make_parts(self, data, group_var=None, column_split_key=None): 716 """ 717 Make initial `Parts` for data, split by group_var, group_key 718 """ 719 if group_var is not None: 720 assert group_var.is_discrete 721 _col_data = table_column_data(data, group_var) 722 row_indices = [np.flatnonzero(_col_data == i) 723 for i in range(len(group_var.values))] 724 725 row_groups = [RowPart(title=name, indices=ind, 726 cluster=None, cluster_ordered=None) 727 for name, ind in zip(group_var.values, row_indices)] 728 if np.any(_col_data.mask): 729 row_groups.append(RowPart( 730 title="N/A", indices=np.flatnonzero(_col_data.mask), 731 cluster=None, cluster_ordered=None 732 )) 733 else: 734 row_groups = [RowPart(title=None, indices=range(0, len(data)), 735 cluster=None, cluster_ordered=None)] 736 737 if column_split_key is not None: 738 col_groups = split_domain(data.domain, column_split_key) 739 assert len(col_groups) > 0 740 col_indices = [np.array([data.domain.index(var) for var in group]) 741 for _, group in col_groups] 742 col_groups = [ColumnPart(title=str(name), domain=d, indices=ind, 743 cluster=None, cluster_ordered=None) 744 for (name, d), ind in zip(col_groups, col_indices)] 745 else: 746 col_groups = [ 747 ColumnPart( 748 title=None, indices=range(0, len(data.domain.attributes)), 749 domain=data.domain.attributes, cluster=None, cluster_ordered=None) 750 ] 751 752 minv, maxv = np.nanmin(data.X), np.nanmax(data.X) 753 return Parts(row_groups, col_groups, span=(minv, maxv)) 754 755 def cluster_rows(self, data: Table, parts: 'Parts', ordered=False) -> 'Parts': 756 row_groups = 
[] 757 for row in parts.rows: 758 if row.cluster is not None: 759 cluster = row.cluster 760 else: 761 cluster = None 762 if row.cluster_ordered is not None: 763 cluster_ord = row.cluster_ordered 764 else: 765 cluster_ord = None 766 767 if row.can_cluster: 768 matrix = None 769 need_dist = cluster is None or (ordered and cluster_ord is None) 770 if need_dist: 771 subset = data[row.indices] 772 matrix = Orange.distance.Euclidean(subset) 773 774 if cluster is None: 775 cluster = hierarchical.dist_matrix_clustering( 776 matrix, linkage=hierarchical.WARD 777 ) 778 if ordered and cluster_ord is None: 779 cluster_ord = hierarchical.optimal_leaf_ordering( 780 cluster, matrix, 781 ) 782 row_groups.append(row._replace(cluster=cluster, cluster_ordered=cluster_ord)) 783 784 return parts._replace(rows=row_groups) 785 786 def cluster_columns(self, data, parts: 'Parts', ordered=False): 787 assert all(var.is_continuous for var in data.domain.attributes) 788 col_groups = [] 789 for col in parts.columns: 790 if col.cluster is not None: 791 cluster = col.cluster 792 else: 793 cluster = None 794 if col.cluster_ordered is not None: 795 cluster_ord = col.cluster_ordered 796 else: 797 cluster_ord = None 798 if col.can_cluster: 799 need_dist = cluster is None or (ordered and cluster_ord is None) 800 matrix = None 801 if need_dist: 802 subset = data.transform(Domain(col.domain)) 803 subset = Orange.distance._preprocess(subset) 804 matrix = np.asarray(Orange.distance.PearsonR(subset, axis=0)) 805 # nan values break clustering below 806 matrix = np.nan_to_num(matrix) 807 808 if cluster is None: 809 assert matrix is not None 810 cluster = hierarchical.dist_matrix_clustering( 811 matrix, linkage=hierarchical.WARD 812 ) 813 if ordered and cluster_ord is None: 814 cluster_ord = hierarchical.optimal_leaf_ordering(cluster, matrix) 815 816 col_groups.append(col._replace(cluster=cluster, cluster_ordered=cluster_ord)) 817 return parts._replace(columns=col_groups) 818 819 def construct_heatmaps(self, 
data, group_var=None, column_split_key=None) -> 'Parts': 820 if self.merge_kmeans: 821 if self.kmeans_model is None: 822 effective_data = self.input_data.transform( 823 Orange.data.Domain( 824 [var for var in self.input_data.domain.attributes 825 if var.is_continuous], 826 self.input_data.domain.class_vars, 827 self.input_data.domain.metas)) 828 nclust = min(self.merge_kmeans_k, len(effective_data) - 1) 829 self.kmeans_model = kmeans_compress(effective_data, k=nclust) 830 effective_data.domain = self.kmeans_model.domain 831 merge_indices = [np.flatnonzero(self.kmeans_model.labels == ind) 832 for ind in range(nclust)] 833 not_empty_indices = [i for i, x in enumerate(merge_indices) 834 if len(x) > 0] 835 self.merge_indices = \ 836 [merge_indices[i] for i in not_empty_indices] 837 if len(merge_indices) != len(self.merge_indices): 838 self.Warning.empty_clusters() 839 effective_data = Orange.data.Table( 840 Orange.data.Domain(effective_data.domain.attributes), 841 self.kmeans_model.centroids[not_empty_indices] 842 ) 843 else: 844 effective_data = self.effective_data 845 846 group_var = None 847 else: 848 self.kmeans_model = None 849 self.merge_indices = None 850 effective_data = data 851 852 self.effective_data = effective_data 853 854 parts = self._make_parts( 855 effective_data, group_var, 856 column_split_key.name if column_split_key is not None else None) 857 858 self.__update_clustering_enable_state(parts) 859 # Restore/update the row/columns items descriptions from cache if 860 # available 861 rows_cache_key = (group_var, 862 self.merge_kmeans_k if self.merge_kmeans else None) 863 if rows_cache_key in self.__rows_cache: 864 parts = parts._replace(rows=self.__rows_cache[rows_cache_key].rows) 865 866 if column_split_key in self.__columns_cache: 867 parts = parts._replace( 868 columns=self.__columns_cache[column_split_key].columns) 869 870 if self.row_clustering != Clustering.None_: 871 parts = self.cluster_rows( 872 effective_data, parts, 873 
ordered=self.row_clustering == Clustering.OrderedClustering 874 ) 875 if self.col_clustering != Clustering.None_: 876 parts = self.cluster_columns( 877 effective_data, parts, 878 ordered=self.col_clustering == Clustering.OrderedClustering 879 ) 880 881 # Cache the updated parts 882 self.__rows_cache[rows_cache_key] = parts 883 return parts 884 885 def construct_heatmaps_scene(self, parts: 'Parts', data: Table) -> None: 886 _T = TypeVar("_T", bound=Union[RowPart, ColumnPart]) 887 888 def select_cluster(clustering: Clustering, item: _T) -> _T: 889 if clustering == Clustering.None_: 890 return item._replace(cluster=None, cluster_ordered=None) 891 elif clustering == Clustering.Clustering: 892 return item._replace(cluster=item.cluster, cluster_ordered=None) 893 elif clustering == Clustering.OrderedClustering: 894 return item._replace(cluster=item.cluster_ordered, cluster_ordered=None) 895 else: # pragma: no cover 896 raise TypeError() 897 898 rows = [select_cluster(self.row_clustering, rowitem) 899 for rowitem in parts.rows] 900 cols = [select_cluster(self.col_clustering, colitem) 901 for colitem in parts.columns] 902 parts = Parts(columns=cols, rows=rows, span=parts.span) 903 904 self.setup_scene(parts, data) 905 906 def setup_scene(self, parts, data): 907 # type: (Parts, Table) -> None 908 widget = HeatmapGridWidget() 909 widget.setColorMap(self.color_map()) 910 self.scene.addItem(widget) 911 self.scene.widget = widget 912 columns = [v.name for v in data.domain.attributes] 913 parts = HeatmapGridWidget.Parts( 914 rows=[ 915 HeatmapGridWidget.RowItem(r.title, r.indices, r.cluster) 916 for r in parts.rows 917 ], 918 columns=[ 919 HeatmapGridWidget.ColumnItem(c.title, c.indices, c.cluster) 920 for c in parts.columns 921 ], 922 data=data.X, 923 span=parts.span, 924 row_names=None, 925 col_names=columns, 926 ) 927 widget.setHeatmaps(parts) 928 929 side = self.row_side_colors() 930 if side is not None: 931 widget.setRowSideColorAnnotations(side[0], side[1], 
name=side[2].name) 932 933 side = self.column_side_colors() 934 if side is not None: 935 widget.setColumnSideColorAnnotations(side[0], side[1], name=side[2].name) 936 937 widget.setColumnLabelsPosition(self._column_label_pos) 938 widget.setAspectRatioMode( 939 Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio 940 ) 941 widget.setShowAverages(self.averages) 942 widget.setLegendVisible(self.legend) 943 944 widget.layoutDidActivate.connect(self.__on_layout_activate) 945 widget.selectionFinished.connect(self.on_selection_finished) 946 947 self.update_annotations() 948 self.view.setCentralWidget(widget) 949 self.parts = parts 950 951 def __update_scene_rects(self): 952 widget = self.scene.widget 953 if widget is None: 954 return 955 rect = widget.geometry() 956 self.scene.setSceneRect(rect) 957 self.view.setSceneRect(rect) 958 self.view.setHeaderSceneRect(widget.headerGeometry()) 959 self.view.setFooterSceneRect(widget.footerGeometry()) 960 961 def __on_layout_activate(self): 962 self.__update_scene_rects() 963 964 def __aspect_mode_changed(self): 965 widget = self.scene.widget 966 if widget is None: 967 return 968 widget.setAspectRatioMode( 969 Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio 970 ) 971 # when aspect fixed the vertical sh is fixex, when not, it can 972 # shrink vertically 973 sp = widget.sizePolicy() 974 if self.keep_aspect: 975 sp.setVerticalPolicy(QSizePolicy.Fixed) 976 else: 977 sp.setVerticalPolicy(QSizePolicy.Preferred) 978 widget.setSizePolicy(sp) 979 980 def __update_clustering_enable_state(self, parts: Optional['Parts']): 981 def c_cost(sizes: Iterable[int]) -> int: 982 """Estimated cost for clustering of `sizes`""" 983 return sum(n ** 2 for n in sizes) 984 985 def co_cost(sizes: Iterable[int]) -> int: 986 """Estimated cost for cluster ordering of `sizes`""" 987 # ~O(N ** 3) but O(N ** 4) worst case. 
988 return sum(n ** 4 for n in sizes) 989 990 if parts is not None: 991 Ns = [len(p.indices) for p in parts.rows] 992 Ms = [len(p.indices) for p in parts.columns] 993 else: 994 Ns = Ms = [0] 995 996 rc_enabled = c_cost(Ns) <= c_cost([self.MaxClustering]) 997 rco_enabled = co_cost(Ns) <= co_cost([self.MaxOrderedClustering]) 998 cc_enabled = c_cost(Ms) <= c_cost([self.MaxClustering]) 999 cco_enabled = co_cost(Ms) <= co_cost([self.MaxOrderedClustering]) 1000 row_clust, col_clust = self.row_clustering, self.col_clustering 1001 1002 row_clust_msg = "" 1003 col_clust_msg = "" 1004 1005 if not rco_enabled and row_clust == Clustering.OrderedClustering: 1006 row_clust = Clustering.Clustering 1007 row_clust_msg = "Row cluster ordering was disabled due to the " \ 1008 "estimated runtime cost" 1009 if not rc_enabled and row_clust == Clustering.Clustering: 1010 row_clust = Clustering.None_ 1011 row_clust_msg = "Row clustering was was disabled due to the " \ 1012 "estimated runtime cost" 1013 1014 if not cco_enabled and col_clust == Clustering.OrderedClustering: 1015 col_clust = Clustering.Clustering 1016 col_clust_msg = "Column cluster ordering was disabled due to " \ 1017 "estimated runtime cost" 1018 if not cc_enabled and col_clust == Clustering.Clustering: 1019 col_clust = Clustering.None_ 1020 col_clust_msg = "Column clustering was disabled due to the " \ 1021 "estimated runtime cost" 1022 1023 self.col_clustering = col_clust 1024 self.row_clustering = row_clust 1025 1026 self.Information.row_clust(row_clust_msg, shown=bool(row_clust_msg)) 1027 self.Information.col_clust(col_clust_msg, shown=bool(col_clust_msg)) 1028 1029 # Disable/enable the combobox items for the clustering methods 1030 def setenabled(cb: QComboBox, clu: bool, clu_op: bool): 1031 model = cb.model() 1032 assert isinstance(model, QStandardItemModel) 1033 idx = cb.findData(Clustering.OrderedClustering, ClusteringRole) 1034 assert idx != -1 1035 model.item(idx).setEnabled(clu_op) 1036 idx = 
cb.findData(Clustering.Clustering, ClusteringRole) 1037 assert idx != -1 1038 model.item(idx).setEnabled(clu) 1039 1040 setenabled(self.row_cluster_cb, rc_enabled, rco_enabled) 1041 setenabled(self.col_cluster_cb, cc_enabled, cco_enabled) 1042 1043 def update_averages_stripe(self): 1044 """Update the visibility of the averages stripe. 1045 """ 1046 widget = self.scene.widget 1047 if widget is not None: 1048 widget.setShowAverages(self.averages) 1049 1050 def update_color_schema(self): 1051 self.palette_name = self.color_map_widget.currentData().name 1052 w = self.scene.widget 1053 if w is not None: 1054 w.setColorMap(self.color_map()) 1055 1056 def __update_column_clustering(self): 1057 self.update_heatmaps() 1058 self.commit() 1059 1060 def __update_row_clustering(self): 1061 self.update_heatmaps() 1062 self.commit() 1063 1064 def update_legend(self): 1065 widget = self.scene.widget 1066 if widget is not None: 1067 widget.setLegendVisible(self.legend) 1068 1069 def row_annotation_var(self): 1070 return self.annotation_var 1071 1072 def row_annotation_data(self): 1073 var = self.row_annotation_var() 1074 if var is None: 1075 return None 1076 return column_str_from_table(self.input_data, var) 1077 1078 def _merge_row_indices(self): 1079 if self.merge_kmeans and self.kmeans_model is not None: 1080 return self.merge_indices 1081 else: 1082 return None 1083 1084 def set_annotation_var(self, var: Union[None, Variable, int]): 1085 if isinstance(var, int): 1086 var = self.annotation_model[var] 1087 if self.annotation_var is not var: 1088 self.annotation_var = var 1089 self.update_annotations() 1090 1091 def update_annotations(self): 1092 widget = self.scene.widget 1093 if widget is not None: 1094 annot_col = self.row_annotation_data() 1095 merge_indices = self._merge_row_indices() 1096 if merge_indices is not None and annot_col is not None: 1097 join = lambda _1: join_elided(", ", 42, _1, " ({} more)") 1098 annot_col = aggregate_apply(join, annot_col, merge_indices) 1099 
if annot_col is not None: 1100 widget.setRowLabels(annot_col) 1101 widget.setRowLabelsVisible(True) 1102 else: 1103 widget.setRowLabelsVisible(False) 1104 widget.setRowLabels(None) 1105 1106 def row_side_colors(self): 1107 var = self.annotation_color_var 1108 if var is None: 1109 return None 1110 column_data = column_data_from_table(self.input_data, var) 1111 merges = self._merge_row_indices() 1112 if merges is not None: 1113 column_data = aggregate(var, column_data, merges) 1114 data, colormap = colorize(var, column_data) 1115 if var.is_continuous: 1116 span = (np.nanmin(column_data), np.nanmax(column_data)) 1117 if np.any(np.isnan(span)): 1118 span = 0., 1. 1119 colormap.span = span 1120 return data, colormap, var 1121 1122 def set_annotation_color_var(self, var: Union[None, Variable, int]): 1123 """Set the current side color annotation variable.""" 1124 if isinstance(var, int): 1125 var = self.row_side_color_model[var] 1126 if self.annotation_color_var is not var: 1127 self.annotation_color_var = var 1128 self.update_row_side_colors() 1129 1130 def update_row_side_colors(self): 1131 widget = self.scene.widget 1132 if widget is None: 1133 return 1134 colors = self.row_side_colors() 1135 if colors is None: 1136 widget.setRowSideColorAnnotations(None) 1137 else: 1138 widget.setRowSideColorAnnotations(colors[0], colors[1], colors[2].name) 1139 1140 def __set_column_annotation_color_var_index(self, index: int): 1141 key = self.col_side_color_cb.itemData(index, Qt.EditRole) 1142 self.set_column_annotation_color_var(key) 1143 1144 def column_annotation_color_var_changed(self, value): 1145 cbselect(self.col_side_color_cb, value, Qt.EditRole) 1146 1147 def set_column_annotation_color_var(self, var): 1148 if self.column_annotation_color_var is not var: 1149 self.column_annotation_color_var = var 1150 colors = self.column_side_colors() 1151 if colors is not None: 1152 self.scene.widget.setColumnSideColorAnnotations( 1153 colors[0], colors[1], colors[2].name, 1154 ) 1155 
else: 1156 self.scene.widget.setColumnSideColorAnnotations(None) 1157 1158 def column_side_colors(self): 1159 var = self.column_annotation_color_var 1160 if var is None: 1161 return None 1162 table = self.col_annot_data 1163 return color_annotation_data(table, var) 1164 1165 def update_column_annotations(self): 1166 widget = self.scene.widget 1167 if self.data is not None and widget is not None: 1168 widget.setColumnLabelsPosition(self._column_label_pos) 1169 1170 def __adjust_font_size(self, diff): 1171 widget = self.scene.widget 1172 if widget is None: 1173 return 1174 curr = widget.font().pointSizeF() 1175 new = curr + diff 1176 1177 self.__font_dec.setEnabled(new > 1.0) 1178 self.__font_inc.setEnabled(new <= 32) 1179 if new > 1.0: 1180 font = QFont() 1181 font.setPointSizeF(new) 1182 widget.setFont(font) 1183 1184 def _on_view_context_menu(self, pos): 1185 widget = self.scene.widget 1186 if widget is None: 1187 return 1188 assert isinstance(widget, HeatmapGridWidget) 1189 menu = QMenu(self.view.viewport()) 1190 menu.setAttribute(Qt.WA_DeleteOnClose) 1191 menu.addActions(self.view.actions()) 1192 menu.addSeparator() 1193 menu.addActions([self.__font_inc, self.__font_dec]) 1194 menu.addSeparator() 1195 a = QAction("Keep aspect ratio", menu, checkable=True) 1196 a.setChecked(self.keep_aspect) 1197 1198 def ontoggled(state): 1199 self.keep_aspect = state 1200 self.__aspect_mode_changed() 1201 a.toggled.connect(ontoggled) 1202 menu.addAction(a) 1203 menu.popup(self.view.viewport().mapToGlobal(pos)) 1204 1205 def on_selection_finished(self): 1206 if self.scene.widget is not None: 1207 self.selected_rows = list(self.scene.widget.selectedRows()) 1208 else: 1209 self.selected_rows = [] 1210 self.commit() 1211 1212 def commit(self): 1213 data = None 1214 indices = None 1215 if self.merge_kmeans: 1216 merge_indices = self.merge_indices 1217 else: 1218 merge_indices = None 1219 1220 if self.input_data is not None and self.selected_rows: 1221 indices = self.selected_rows 
1222 if merge_indices is not None: 1223 # expand merged indices 1224 indices = np.hstack([merge_indices[i] for i in indices]) 1225 1226 data = self.input_data[indices] 1227 1228 self.Outputs.selected_data.send(data) 1229 self.Outputs.annotated_data.send(create_annotated_table(self.input_data, indices)) 1230 1231 def onDeleteWidget(self): 1232 self.clear() 1233 super().onDeleteWidget() 1234 1235 def send_report(self): 1236 self.report_items(( 1237 ("Columns:", "Clustering" if self.col_clustering else "No sorting"), 1238 ("Rows:", "Clustering" if self.row_clustering else "No sorting"), 1239 ("Split:", 1240 self.split_by_var is not None and self.split_by_var.name), 1241 ("Row annotation", 1242 self.annotation_var is not None and self.annotation_var.name), 1243 )) 1244 self.report_plot() 1245 1246 @classmethod 1247 def migrate_settings(cls, settings, version): 1248 if version is not None and version < 3: 1249 def st2cl(state: bool) -> Clustering: 1250 return Clustering.OrderedClustering if state else \ 1251 Clustering.None_ 1252 1253 rc = settings.pop("row_clustering", False) 1254 cc = settings.pop("col_clustering", False) 1255 settings["row_clustering_method"] = st2cl(rc).name 1256 settings["col_clustering_method"] = st2cl(cc).name 1257 1258 1259# If StickyGraphicsView ever defines qt signals/slots/properties this will 1260# break 1261class GraphicsView(GraphicsWidgetView, StickyGraphicsView): 1262 pass 1263 1264 1265class RowPart(NamedTuple): 1266 """ 1267 A row group 1268 1269 Attributes 1270 ---------- 1271 title: str 1272 Group title 1273 indices : (N, ) Sequence[int] 1274 Indices in the input data to retrieve the row subset for the group. 
1275 cluster : hierarchical.Tree optional 1276 cluster_ordered : hierarchical.Tree optional 1277 """ 1278 title: str 1279 indices: Sequence[int] 1280 cluster: Optional[hierarchical.Tree] = None 1281 cluster_ordered: Optional[hierarchical.Tree] = None 1282 1283 @property 1284 def can_cluster(self) -> bool: 1285 if isinstance(self.indices, slice): 1286 return (self.indices.stop - self.indices.start) > 1 1287 else: 1288 return len(self.indices) > 1 1289 1290 1291class ColumnPart(NamedTuple): 1292 """ 1293 A column group 1294 1295 Attributes 1296 ---------- 1297 title : str 1298 Column group title 1299 indices : (N, ) int ndarray 1300 Indexes the input data to retrieve the column subset for the group. 1301 domain : List[Variable] 1302 List of variables in the group. 1303 cluster : hierarchical.Tree optional 1304 cluster_ordered : hierarchical.Tree optional 1305 """ 1306 title: str 1307 indices: Sequence[int] 1308 domain: Sequence[int] 1309 cluster: Optional[hierarchical.Tree] = None 1310 cluster_ordered: Optional[hierarchical.Tree] = None 1311 1312 @property 1313 def can_cluster(self) -> bool: 1314 if isinstance(self.indices, slice): 1315 return (self.indices.stop - self.indices.start) > 1 1316 else: 1317 return len(self.indices) > 1 1318 1319 1320class Parts(NamedTuple): 1321 rows: Sequence[RowPart] 1322 columns: Sequence[ColumnPart] 1323 span: Tuple[float, float] 1324 1325 1326def join_elided(sep, maxlen, values, elidetemplate="..."): 1327 def generate(sep, ellidetemplate, values): 1328 count = len(values) 1329 length = 0 1330 parts = [] 1331 for i, val in enumerate(values): 1332 elide = ellidetemplate.format(count - i) if count - i > 1 else "" 1333 parts.append(val) 1334 length += len(val) + (len(sep) if parts else 0) 1335 yield i, islice(parts, i + 1), length, elide 1336 1337 best = None 1338 for _, parts, length, elide in generate(sep, elidetemplate, values): 1339 if length > maxlen: 1340 if best is None: 1341 best = sep.join(parts) + elide 1342 return best 1343 
fulllen = length + len(elide) 1344 if fulllen < maxlen or best is None: 1345 best = sep.join(parts) + elide 1346 return best 1347 1348 1349def column_str_from_table( 1350 table: Orange.data.Table, 1351 column: Union[int, Orange.data.Variable], 1352) -> np.ndarray: 1353 var = table.domain[column] 1354 data, _ = table.get_column_view(column) 1355 return np.asarray([var.str_val(v) for v in data], dtype=object) 1356 1357 1358def column_data_from_table( 1359 table: Orange.data.Table, 1360 column: Union[int, Orange.data.Variable], 1361) -> np.ndarray: 1362 var = table.domain[column] 1363 data, _ = table.get_column_view(column) 1364 if var.is_primitive() and data.dtype.kind != "f": 1365 data = data.astype(float) 1366 return data 1367 1368 1369def color_annotation_data( 1370 table: Table, var: Union[int, str, Variable] 1371) -> Tuple[np.ndarray, ColorMap, Variable]: 1372 var = table.domain[var] 1373 column_data = column_data_from_table(table, var) 1374 data, colormap = colorize(var, column_data) 1375 return data, colormap, var 1376 1377 1378def colorize(var: Variable, data: np.ndarray) -> Tuple[np.ndarray, ColorMap]: 1379 palette = var.palette # type: Palette 1380 colors = np.array( 1381 [[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan], 1382 dtype=np.uint8, 1383 ) 1384 if var.is_discrete: 1385 mask = np.isnan(data) 1386 data = data.astype(int) 1387 data[mask] = -1 1388 if mask.any(): 1389 values = (*var.values, "N/A") 1390 else: 1391 values = var.values 1392 colors = colors[: -1] 1393 return data, CategoricalColorMap(colors, values) 1394 elif var.is_continuous: 1395 span = np.nanmin(data), np.nanmax(data) 1396 if np.any(np.isnan(span)): 1397 span = 0, 1. 
1398 return data, GradientColorMap(colors[:-1], span=span) 1399 else: 1400 raise TypeError 1401 1402 1403def aggregate( 1404 var: Variable, data: np.ndarray, groupindices: Sequence[Sequence[int]], 1405) -> np.ndarray: 1406 if var.is_string: 1407 join = lambda values: (join_elided(", ", 42, values, " ({} more)")) 1408 # collect all original labels for every merged row 1409 values = [data[indices] for indices in groupindices] 1410 data = [join(list(map(var.str_val, vals))) for vals in values] 1411 return np.array(data, dtype=object) 1412 elif var.is_continuous: 1413 data = [np.nanmean(data[indices]) if len(indices) else np.nan 1414 for indices in groupindices] 1415 return np.array(data, dtype=float) 1416 elif var.is_discrete: 1417 from Orange.statistics.util import nanmode 1418 data = [nanmode(data[indices])[0] if len(indices) else np.nan 1419 for indices in groupindices] 1420 return np.asarray(data, dtype=float) 1421 else: 1422 raise TypeError(type(var)) 1423 1424 1425def agg_join_str(var, data, groupindices, maxlen=50, elidetemplate=" ({} more)"): 1426 join_s = lambda values: ( 1427 join_elided(", ", maxlen, values, elidetemplate=elidetemplate) 1428 ) 1429 join = lambda values: join_s(map(var.str_val, values)) 1430 return aggregate_apply(join, data, groupindices) 1431 1432 1433_T = TypeVar("_T") 1434 1435 1436def aggregate_apply( 1437 f: Callable[[Sequence], _T], 1438 data: np.ndarray, 1439 groupindices: Sequence[Sequence[int]] 1440) -> Sequence[_T]: 1441 return [f(data[indices]) for indices in groupindices] 1442 1443 1444if __name__ == "__main__": # pragma: no cover 1445 WidgetPreview(OWHeatMap).run(Table("brown-selected.tab")) 1446