1import sys
2import itertools
3import warnings
4from xml.sax.saxutils import escape
5from math import log10, floor, ceil
6from datetime import datetime, timezone
7
8import numpy as np
9from AnyQt.QtCore import Qt, QRectF, QSize, QTimer, pyqtSignal as Signal, \
10    QObject
11from AnyQt.QtGui import QColor, QPen, QBrush, QPainterPath, QTransform, \
12    QPainter
13from AnyQt.QtWidgets import QApplication, QToolTip, QGraphicsTextItem, \
14    QGraphicsRectItem, QGraphicsItemGroup
15
16import pyqtgraph as pg
17from pyqtgraph.graphicsItems.ScatterPlotItem import Symbols
18from pyqtgraph.graphicsItems.LegendItem import LegendItem as PgLegendItem
19from pyqtgraph.graphicsItems.TextItem import TextItem
20
21from Orange.preprocess.discretize import _time_binnings
22from Orange.util import utc_from_timestamp
23from Orange.widgets import gui
24from Orange.widgets.settings import Setting
25from Orange.widgets.utils import classdensity, colorpalettes
26from Orange.widgets.utils.plot import OWPalette
27from Orange.widgets.visualize.utils.customizableplot import Updater, \
28    CommonParameterSetter
29from Orange.widgets.visualize.utils.plotutils import (
30    HelpEventDelegate as EventDelegate, InteractiveViewBox as ViewBox,
31    PaletteItemSample, SymbolItemSample, AxisItem
32)
33
# NOTE(review): not used in this chunk; presumably the pen width of
# selection outlines — confirm.
SELECTION_WIDTH = 5
# NOTE(review): not used in this chunk; presumably the largest number of
# valid points for which size changes are animated — confirm.
MAX_N_VALID_SIZE_ANIMATE = 1000

# maximum number of colors (including Other)
MAX_COLORS = 11
39
40
class LegendItem(PgLegendItem):
    """
    Legend with a rounded, semi-transparent background and left-aligned
    labels.

    `pen` and `brush` (both optional) are used by `paint` to draw the
    background; they are kept by this class instead of being passed on to
    the pyqtgraph base class.
    """
    def __init__(self, size=None, offset=None, pen=None, brush=None):
        super().__init__(size, offset)

        layout = self.layout
        layout.setContentsMargins(5, 5, 5, 5)
        layout.setHorizontalSpacing(15)
        layout.setColumnAlignment(1, Qt.AlignLeft | Qt.AlignVCenter)

        if pen is None:
            # light grey frame whose width does not scale with the view
            pen = QPen(QColor(196, 197, 193, 200), 1)
            pen.setCosmetic(True)
        self.__pen = pen

        self.__brush = brush if brush is not None \
            else QBrush(QColor(232, 232, 232, 100))

    def restoreAnchor(self, anchors):
        """
        Restore the position relative to the parent from stored anchors.

        `anchors` is a pair (item anchor, parent anchor); the anchors are
        first bounded by `bound_anchor_pos` so the legend stays inside
        the parent.
        """
        item_anchor, parent_anchor = anchors
        bounded = bound_anchor_pos(item_anchor, parent_anchor)
        self.anchor(*bounded)

    # pylint: disable=arguments-differ
    def paint(self, painter, _option, _widget=None):
        """Draw the rounded background with the stored pen and brush."""
        painter.setPen(self.__pen)
        painter.setBrush(self.__brush)
        painter.drawRoundedRect(self.contentsRect(), 2, 2)

    def addItem(self, item, name):
        """Add a legend entry and left-justify its label."""
        super().addItem(item, name)
        # the base class creates the label; re-set its text to justify it
        label = self.items[-1][1]
        label.setText(name, justify="left")

    def clear(self):
        """
        Clear all legend items.

        Samples and labels are removed from the layout and hidden, then the
        legend size is recomputed.
        """
        removed, self.items = list(self.items), []
        for sample, label in removed:
            for graphics_item in (sample, label):
                self.layout.removeItem(graphics_item)
            for graphics_item in (sample, label):
                graphics_item.hide()

        self.updateSize()
93
94
def bound_anchor_pos(corner, parentpos):
    """
    Bound a pair of anchors so that the anchored item stays in its parent.

    Both anchors are first clipped to [0, 1] in each coordinate. Then,
    independently for x and y: if the item anchor is at the far side
    (> 0.9) while the parent anchor is at the near side (< 0.1), both are
    snapped to 0.0; in the mirrored situation both are snapped to 1.0.

    Args:
        corner: (x, y) anchor on the item
        parentpos: (x, y) anchor on the parent

    Returns:
        (tuple): `((item_x, item_y), (parent_x, parent_y))`
    """
    def snap(item_coord, parent_coord):
        if item_coord > 0.9 and parent_coord < 0.1:
            return 0.0, 0.0
        if item_coord < 0.1 and parent_coord > 0.9:
            return 1.0, 1.0
        return item_coord, parent_coord

    item_x, item_y = np.clip(corner, 0, 1)
    parent_x, parent_y = np.clip(parentpos, 0, 1)

    item_x, parent_x = snap(item_x, parent_x)
    item_y, parent_y = snap(item_y, parent_y)
    return (item_x, item_y), (parent_x, parent_y)
111
112
class DiscretizedScale:
    """
    Compute suitable bins for continuous value from its minimal and
    maximal value.

    The width of the bin is a power of 10 (including negative powers).
    The first bin starts at the minimal value rounded down to a multiple
    of the width. If the span is covered by fewer than 3 bins, the width
    is divided by four; if by fewer than 6, it is halved.

    .. attribute:: offset
        The start of the first bin.

    .. attribute:: width
        The width of the bins

    .. attribute:: bins
        The number of bins

    .. attribute:: decimals
        The number of decimals needed to print the boundaries exactly
    """
    def __init__(self, min_v, max_v):
        """
        :param min_v: Minimal value
        :type min_v: float
        :param max_v: Maximal value
        :type max_v: float

        If the span or the minimum is not finite (nan or infinite), the
        scale falls back to ten bins of width 0.1 starting at 0.
        """
        super().__init__()
        dif = max_v - min_v if max_v != min_v else 1
        if not (np.isfinite(dif) and np.isfinite(min_v)):
            # log10 / floor would fail on a nan or infinite span or offset
            min_v = 0
            dif = decimals = 1
        else:
            decimals = -floor(log10(dif))
        resolution = 10 ** -decimals
        bins = ceil(dif / resolution)
        if bins < 6:
            if bins < 3:
                # quarter steps (e.g. 0.25) need two more decimal places
                decimals += 2
                resolution /= 4
            else:
                # half steps (e.g. 0.5) need one more decimal place
                decimals += 1
                resolution /= 2
            bins = ceil(dif / resolution)
        # NOTE(review): `bins` is derived from the span, not from
        # `max_v - offset`, so when `min_v` is not a multiple of the width
        # (e.g. 0.3..1.3) the last boundary can fall below `max_v`;
        # confirm that callers clip.
        self.offset = resolution * floor(min_v // resolution)
        self.bins = bins
        self.decimals = max(decimals, 0)
        self.width = resolution

    def get_bins(self):
        """Return the `bins + 1` bin boundaries, starting at `offset`."""
        return self.offset + self.width * np.arange(self.bins + 1)
165
166
class ScatterPlotItem(pg.ScatterPlotItem):
    """
    Modifies the behaviour of ScatterPlotItem as follows:

    - Add z-index. ScatterPlotItem paints points in order of appearance in
      self.data. Plotting by z-index is achieved by sorting before calling
      super().paint() and re-sorting afterwards. Re-sorting (instead of
      storing the original data) is needed because the inherited paint
      may modify the data.

    - Prevent multiple calls to updateSpots. ScatterPlotItem calls updateSpots
      at any change of sizes/colors/symbols, which then rebuilds the stored
      pixmaps for each symbol. Orange calls set* functions in succession,
      so we postpone updateSpots() to paint()."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The base constructor may already have requested an update through
        # the overridden `updateSpots`; keep that request instead of
        # silently dropping it.
        self._update_spots_in_paint = getattr(
            self, "_update_spots_in_paint", False)
        self._z_mapping = None
        self._inv_mapping = None

    def setZ(self, z):
        """
        Set z values for all points.

        Points with higher values are plotted on top of those with lower.

        Args:
            z (np.ndarray or None): a vector of z values
        """
        if z is None:
            self._z_mapping = self._inv_mapping = None
        else:
            assert len(z) == len(self.data)
            self._z_mapping = np.argsort(z)
            self._inv_mapping = np.argsort(self._z_mapping)

    def setCoordinates(self, x, y):
        """
        Change the coordinates of points while keeping other properties.

        Asserts that the number of points stays the same.

        Note. Pyqtgraph does not offer a method for this: setting coordinates
        invalidates other data. We therefore retrieve the data to set it
        together with the coordinates. Pyqtgraph also does not offer a
        (documented) method for retrieving the data, yet using
        `data[prop]` looks reasonably safe.

        The alternative, updating the whole scatterplot from the Orange Table,
        is too slow.
        """
        assert len(self.data) == len(x) == len(y)
        data = dict(x=x, y=y)
        for prop in ('pen', 'brush', 'size', 'symbol', 'data'):
            data[prop] = self.data[prop]
        self.setData(**data)

    def updateSpots(self, dataSet=None):  # pylint: disable=unused-argument
        """Postpone the (expensive) spot update to the next `paint`."""
        self._update_spots_in_paint = True
        self.update()

    # pylint: disable=arguments-differ
    def paint(self, painter, option, widget=None):
        """
        Paint the points in z order and run any postponed `updateSpots`.

        The data is permuted by the z mapping before painting and restored
        afterwards. It is restored only if it was actually permuted, so a
        failure before or during the permutation (e.g. a length mismatch)
        does not scramble the data.
        """
        reordered = False
        try:
            if self._z_mapping is not None:
                assert len(self._z_mapping) == len(self.data)
                self.data = self.data[self._z_mapping]
                reordered = True
            if self._update_spots_in_paint:
                self._update_spots_in_paint = False
                super().updateSpots()
            painter.setRenderHint(QPainter.SmoothPixmapTransform, True)
            super().paint(painter, option, widget)
        finally:
            if reordered:
                self.data = self.data[self._inv_mapping]
243
244
def _define_symbols():
    """
    Register custom marker shapes in pyqtgraph's `Symbols` table.

    - "?" is a circle crossed by both diagonals;
    - "+" is a plus-shaped closed outline;
    - "t", the existing triangle, is rotated by 180 degrees so that it
      points upwards;
    - "x" is replaced by the new "+" rotated by 45 degrees.

    The function mutates the module-level `Symbols` dict and must therefore
    run only once; a second call would rotate "t" back.
    """
    unknown = QPainterPath()
    unknown.addEllipse(QRectF(-0.35, -0.35, 0.7, 0.7))
    diagonals = (((-0.5, 0.5), (0.5, -0.5)),
                 ((-0.5, -0.5), (0.5, 0.5)))
    for (x_from, y_from), (x_to, y_to) in diagonals:
        unknown.moveTo(x_from, y_from)
        unknown.lineTo(x_to, y_to)
    Symbols["?"] = unknown

    plus_outline = [
        (-0.5, -0.1), (-0.5, 0.1), (-0.1, 0.1), (-0.1, 0.5),
        (0.1, 0.5), (0.1, 0.1), (0.5, 0.1), (0.5, -0.1),
        (0.1, -0.1), (0.1, -0.5), (-0.1, -0.5), (-0.1, -0.1)
    ]
    plus = QPainterPath()
    first_point, *other_points = plus_outline
    plus.moveTo(*first_point)
    for point in other_points:
        plus.lineTo(*point)
    plus.closeSubpath()
    Symbols["+"] = plus

    half_turn = QTransform()
    half_turn.rotate(180)
    Symbols['t'] = half_turn.map(Symbols['t'])

    eighth_turn = QTransform()
    eighth_turn.rotate(45)
    Symbols['x'] = eighth_turn.map(Symbols["+"])


_define_symbols()
280
281
def _make_pen(color, width):
    """
    Return a cosmetic `QPen` with the given color and width.

    A cosmetic pen keeps its width in device pixels regardless of the
    painter's transformation (e.g. zooming).
    """
    cosmetic_pen = QPen(color, width)
    cosmetic_pen.setCosmetic(True)
    return cosmetic_pen
286
287
class AxisItem(AxisItem):
    """
    Axis that if needed displays ticks appropriate for time data.

    Time mode is switched on by `use_time`; the values are then interpreted
    as timestamps and formatted in UTC.
    """

    # approximate width of a tick label in pixels, used to thin out ticks
    _label_width = 80

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._use_time = False

    def use_time(self, enable):
        """Enables axes to display ticks for time data.

        SI prefixes are disabled while time ticks are used."""
        self._use_time = enable
        self.enableAutoSIPrefix(not enable)

    def tickValues(self, minVal, maxVal, size):
        """Find appropriate tick locations.

        For time axes, ticks are placed at the thresholds of the finest
        binning from `_time_binnings` and thinned so that labels do not
        overlap. When no suitable binning exists, the default (numeric)
        ticks are used.
        """
        if not self._use_time:
            return super().tickValues(minVal, maxVal, size)

        # if timezone is not set, then local is used which cause exceptions
        minVal = max(minVal,
                     datetime.min.replace(tzinfo=timezone.utc).timestamp() + 1)
        maxVal = min(maxVal,
                     datetime.max.replace(tzinfo=timezone.utc).timestamp() - 1)
        mn = utc_from_timestamp(minVal).timetuple()
        mx = utc_from_timestamp(maxVal).timetuple()
        try:
            bins = _time_binnings(mn, mx, 6, 30)[-1]
        except (IndexError, ValueError):
            # cannot handle very large and very small time intervals
            return super().tickValues(minVal, maxVal, size)

        thresholds = bins.thresholds
        if len(thresholds) < 2:
            # spacing cannot be derived from fewer than two thresholds
            return super().tickValues(minVal, maxVal, size)

        ticks = thresholds
        step = 1
        max_steps = max(int(size / self._label_width), 1)
        if len(ticks) > max_steps:
            # remove some of ticks so that they don't overlap
            step = int(np.ceil(len(ticks) / max_steps))
            ticks = ticks[::step]

        if len(ticks) > 1:
            spacing = min(b - a for a, b in zip(ticks[:-1], ticks[1:]))
        else:
            # The axis is too narrow for two labels (max_steps == 1), so
            # thinning left a single tick; estimate the spacing from the
            # unthinned thresholds instead of calling min() on nothing.
            spacing = step * min(
                b - a for a, b in zip(thresholds[:-1], thresholds[1:]))
        return [(spacing, ticks)]

    def tickStrings(self, values, scale, spacing):
        """Format tick values according to space between them.

        Coarser spacing gives coarser formats: years, months, days, hours
        (with the day shown if the ticks span more than one day), minutes,
        seconds and finally fractions of seconds.
        """
        if not self._use_time:
            return super().tickStrings(values, scale, spacing)

        if spacing >= 3600 * 24 * 365:
            fmt = "%Y"
        elif spacing >= 3600 * 24 * 28:
            fmt = "%Y %b"
        elif spacing >= 3600 * 24:
            fmt = "%Y %b %d"
        elif spacing >= 3600:
            min_day = max_day = 1
            if len(values) > 0:
                min_day = datetime.fromtimestamp(
                    min(values), tz=timezone.utc).day
                max_day = datetime.fromtimestamp(
                    max(values), tz=timezone.utc).day
            if min_day == max_day:
                fmt = "%Hh"
            else:
                fmt = "%d %Hh"
        elif spacing >= 60:
            fmt = "%H:%M"
        elif spacing >= 1:
            fmt = "%H:%M:%S"
        else:
            fmt = '%S.%f'

        return [utc_from_timestamp(x).strftime(fmt) for x in values]
363
364
class ScatterBaseParameterSetter(CommonParameterSetter):
    """
    Font and annotation settings for scatter-plot-like graphs.

    Besides the common title and label fonts, it handles the font of the
    categorical legend (the shape legend plus a discrete color legend) and
    of the numerical legend (a color legend showing a palette).
    """
    CAT_LEGEND_LABEL = "Categorical legend"
    NUM_LEGEND_LABEL = "Numerical legend"
    NUM_LEGEND_SETTING = {
        Updater.SIZE_LABEL: (range(4, 50), 11),
        Updater.IS_ITALIC_LABEL: (None, False),
    }

    def __init__(self, master):
        super().__init__()
        self.master = master
        self.cat_legend_settings = {}
        self.num_legend_settings = {}

    def update_setters(self):
        """Define initial settings and register the legend font setters."""
        labels_box = {
            self.FONT_FAMILY_LABEL: self.FONT_FAMILY_SETTING,
            self.TITLE_LABEL: self.FONT_SETTING,
            self.LABEL_LABEL: self.FONT_SETTING,
            self.CAT_LEGEND_LABEL: self.FONT_SETTING,
            self.NUM_LEGEND_LABEL: self.NUM_LEGEND_SETTING,
        }
        annot_box = {
            self.TITLE_LABEL: {self.TITLE_LABEL: ("", "")},
        }
        self.initial_settings = {
            self.LABELS_BOX: labels_box,
            self.ANNOT_BOX: annot_box,
        }

        def update_cat_legend(**settings):
            # remember the settings and apply them to the current items
            self.cat_legend_settings.update(**settings)
            Updater.update_legend_font(self.cat_legend_items, **settings)

        def update_num_legend(**settings):
            self.num_legend_settings.update(**settings)
            Updater.update_num_legend_font(self.num_legend, **settings)

        label_setters = self._setters[self.LABELS_BOX]
        label_setters[self.CAT_LEGEND_LABEL] = update_cat_legend
        label_setters[self.NUM_LEGEND_LABEL] = update_num_legend

    def _has_palette_color_legend(self):
        """Tell whether the color legend's first sample shows a palette."""
        color_items = self.master.color_legend.items
        return bool(color_items and color_items[0]
                    and isinstance(color_items[0][0], PaletteItemSample))

    @property
    def title_item(self):
        return self.master.plot_widget.getPlotItem().titleLabel

    @property
    def cat_legend_items(self):
        # a palette (numeric) color legend belongs to `num_legend` instead
        if self._has_palette_color_legend():
            color_items = []
        else:
            color_items = self.master.color_legend.items
        return itertools.chain(self.master.shape_legend.items, color_items)

    @property
    def num_legend(self):
        if not self._has_palette_color_legend():
            return None
        return self.master.color_legend

    @property
    def labels(self):
        return self.master.labels
426
427
428class OWScatterPlotBase(gui.OWComponent, QObject):
429    """
430    Provide a graph component for widgets that show any kind of point plot
431
432    The component plots a set of points with given coordinates, shapes,
433    sizes and colors. Its function is similar to that of a *view*, whereas
    the widget represents a *model* and a *controller*.
435
436    The model (widget) needs to provide methods:
437
438    - `get_coordinates_data`, `get_size_data`, `get_color_data`,
439      `get_shape_data`, `get_label_data`, which return a 1d array (or two
440      arrays, for `get_coordinates_data`) of `dtype` `float64`, except for
441      `get_label_data`, which returns formatted labels;
442    - `get_shape_labels` returns a list of strings for shape legend
443    - `get_color_labels` returns strings for color legend, or a function for
444       formatting numbers if the legend is continuous, or None for default
445       formatting
446    - `get_tooltip`, which gives a tooltip for a single data point
    - (optional) `impute_sizes`, `impute_shapes` get final sizes and
      shapes, and replace nans;
449    - `get_subset_mask` returns a bool array indicating whether a
450      data point is in the subset or not (e.g. in the 'Data Subset' signal
451      in the Scatter plot and similar widgets);
452    - `get_palette` returns a palette appropriate for visualizing the
453      current color data;
454    - `is_continuous_color` decides the type of the color legend;
455
456    The widget (in a role of controller) must also provide methods
457    - `selection_changed`
458
459    If `get_coordinates_data` returns `(None, None)`, the plot is cleared. If
460    `get_size_data`, `get_color_data` or `get_shape_data` return `None`,
461    all points will have the same size, color or shape, respectively.
462    If `get_label_data` returns `None`, there are no labels.
463
    The view (this component) provides methods `update_coordinates`,
    `update_sizes`, `update_colors`, `update_shapes` and `update_labels`
    that the widget (in a role of a controller) should call when any of
467    these properties are changed. If the widget calls, for instance, the
468    plot's `update_colors`, the plot will react by calling the widget's
469    `get_color_data` as well as the widget's methods needed to construct the
470    legend.
471
472    The view also provides a method `reset_graph`, which should be called only
473    when
474    - the widget gets entirely new data
475    - the number of points may have changed, for instance when selecting
476    a different attribute for x or y in the scatter plot, where the points
477    with missing x or y coordinates are hidden.
478
479    Every `update_something` calls the plot's `get_something`, which
480    calls the model's `get_something_data`, then it transforms this data
481    into whatever is needed (colors, shapes, scaled sizes) and changes the
482    plot. For the simplest example, here is `update_shapes`:
483
484    ```
485        def update_shapes(self):
486            if self.scatterplot_item:
487                shape_data = self.get_shapes()
488                self.scatterplot_item.setSymbol(shape_data)
489            self.update_legends()
490
491        def get_shapes(self):
492            shape_data = self.master.get_shape_data()
493            shape_data = self.master.impute_shapes(
494                shape_data, len(self.CurveSymbols) - 1)
495            return self.CurveSymbols[shape_data]
496    ```
497
498    On the widget's side, `get_something_data` is essentially just:
499
500    ```
501        def get_size_data(self):
502            return self.get_column(self.attr_size)
503    ```
504
505    where `get_column` retrieves a column while also filtering out the
506    points with missing x and y and so forth. (Here we present the simplest
507    two cases, "shapes" for the view and "sizes" for the model. The colors
508    for the view are more complicated since they deal with discrete and
509    continuous palettes, and the shapes for the view merge infrequent shapes.)
510
511    The plot can also show just a random sample of the data. The sample size is
    set by `set_sample_size`, and the rest is taken care of by the plot: the
513    widget keeps providing the data for all points, selection indices refer
514    to the entire set etc. Internally, sampling happens as early as possible
515    (in methods `get_<something>`).
516    """
    # Emitted with a bool; `clear` reports False through
    # `_signal_too_many_labels` (defined outside this chunk).
    too_many_labels = Signal(bool)
    # NOTE(review): not emitted in this chunk; presumably signal the start,
    # steps and end of (animated) point resizing — confirm.
    begin_resizing = Signal()
    step_resizing = Signal()
    end_resizing = Signal()

    # Orange settings of the plot
    label_only_selected = Setting(False)
    point_width = Setting(10)
    alpha_value = Setting(128)
    show_grid = Setting(False)
    show_legend = Setting(True)
    class_density = Setting(False)
    # Jitter radius in percent of the data span (see `_jitter_data`)
    jitter_size = Setting(0)

    # NOTE(review): not used in this chunk; presumably the resolution of
    # the class density image — confirm.
    resolution = 256

    # Marker shapes; per the class docstring's `get_shapes` example, shape
    # codes index this array and the last entry ("?") is the fallback
    # passed to `impute_shapes`.
    CurveSymbols = np.array("o x t + d star ?".split())
    MinShapeSize = 6
    DarkerValue = 120
    UnknownColor = (168, 50, 168)

    # NOTE(review): presumably (r, g, b, alpha) tuples — confirm.
    COLOR_NOT_SUBSET = (128, 128, 128, 0)
    COLOR_SUBSET = (128, 128, 128, 255)
    COLOR_DEFAULT = (128, 128, 128, 255)

    # NOTE(review): not used in this chunk; presumably the limit above which
    # `too_many_labels` is signalled — confirm.
    MAX_VISIBLE_LABELS = 500
542
    def __init__(self, scatter_widget, parent=None, view_box=ViewBox):
        """
        Create the plot component for `scatter_widget`.

        Args:
            scatter_widget: the widget that provides the data (see the class
                docstring); stored as `self.master`
            parent: Qt parent of the created `pg.PlotWidget`
            view_box: callable that is called with this component and
                returns the view box (`InteractiveViewBox` by default)
        """
        QObject.__init__(self)
        gui.OWComponent.__init__(self, scatter_widget)

        self.subset_is_shown = False
        self.jittering_suspended = False

        self.view_box = view_box(self)
        _axis = {"left": AxisItem("left"), "bottom": AxisItem("bottom")}
        self.plot_widget = pg.PlotWidget(viewBox=self.view_box, parent=parent,
                                         background="w", axisItems=_axis)
        # Both axes start hidden
        self.plot_widget.hideAxis("left")
        self.plot_widget.hideAxis("bottom")
        self.plot_widget.getPlotItem().buttonsHidden = True
        self.plot_widget.setAntialiasing(True)
        self.plot_widget.sizeHint = lambda: QSize(500, 500)

        self.density_img = None
        self.scatterplot_item = None
        self.scatterplot_item_sel = None
        self.labels = []

        self.master = scatter_widget
        tooltip = self._create_drag_tooltip()
        self.view_box.setDragTooltip(tooltip)

        self.selection = None  # np.ndarray

        # Numbers of valid and shown points; see `get_coordinates` and
        # `_create_sample`
        self.n_valid = 0
        self.n_shown = 0
        self.sample_size = None
        self.sample_indices = None

        self.palette = None

        # Anchors are (item anchor, parent anchor) pairs for
        # `LegendItem.restoreAnchor`
        self.shape_legend = self._create_legend(((1, 0), (1, 0)))
        self.color_legend = self._create_legend(((1, 1), (1, 1)))
        self.update_legend_visibility()

        self.scale = None  # DiscretizedScale
        self._too_many_labels = False

        # self.setMouseTracking(True)
        # self.grabGesture(QPinchGesture)
        # self.grabGesture(QPanGesture)

        self.update_grid_visibility()

        # Help (tooltip) events of the scene are routed to `help_event`
        self._tooltip_delegate = EventDelegate(self.help_event)
        self.plot_widget.scene().installEventFilter(self._tooltip_delegate)
        self.view_box.sigTransformChanged.connect(self.update_density)
        self.view_box.sigRangeChangedManually.connect(self.update_labels)

        self.timer = None

        self.parameter_setter = ScatterBaseParameterSetter(self)
599
600    def _create_legend(self, anchor):
601        legend = LegendItem()
602        legend.setParentItem(self.plot_widget.getViewBox())
603        legend.restoreAnchor(anchor)
604        return legend
605
606    def _create_drag_tooltip(self):
607        tip_parts = [
608            (Qt.ControlModifier,
609             "{}: Append to group".
610             format("Cmd" if sys.platform == "darwin" else "Ctrl")),
611            (Qt.ShiftModifier, "Shift: Add group"),
612            (Qt.AltModifier, "Alt: Remove")
613        ]
614        all_parts = "<center>" + \
615                    ", ".join(part for _, part in tip_parts) + \
616                    "</center>"
617        self.tiptexts = {
618            modifier: all_parts.replace(part, "<b>{}</b>".format(part))
619            for modifier, part in tip_parts
620        }
621        self.tiptexts[Qt.NoModifier] = all_parts
622
623        self.tip_textitem = text = QGraphicsTextItem()
624        # Set to the longest text
625        text.setHtml(self.tiptexts[Qt.ControlModifier])
626        text.setPos(4, 2)
627        r = text.boundingRect()
628        text.setTextWidth(r.width())
629        rect = QGraphicsRectItem(0, 0, r.width() + 8, r.height() + 4)
630        rect.setBrush(QColor(224, 224, 224, 212))
631        rect.setPen(QPen(Qt.NoPen))
632        self.update_tooltip()
633
634        tooltip_group = QGraphicsItemGroup()
635        tooltip_group.addToGroup(rect)
636        tooltip_group.addToGroup(text)
637        return tooltip_group
638
639    def update_tooltip(self, modifiers=Qt.NoModifier):
640        text = self.tiptexts[Qt.NoModifier]
641        for mod in [Qt.ControlModifier,
642                    Qt.ShiftModifier,
643                    Qt.AltModifier]:
644            if modifiers & mod:
645                text = self.tiptexts.get(mod)
646                break
647        self.tip_textitem.setHtml(text)
648
649    def suspend_jittering(self):
650        if self.jittering_suspended:
651            return
652        self.jittering_suspended = True
653        if self.jitter_size != 0:
654            self.update_jittering()
655
656    def unsuspend_jittering(self):
657        if not self.jittering_suspended:
658            return
659        self.jittering_suspended = False
660        if self.jitter_size != 0:
661            self.update_jittering()
662
663    def update_jittering(self):
664        x, y = self.get_coordinates()
665        if x is None or len(x) == 0 or self.scatterplot_item is None:
666            return
667        self.scatterplot_item.setCoordinates(x, y)
668        self.scatterplot_item_sel.setCoordinates(x, y)
669        self.update_labels()
670
    # TODO: Rename to remove_plot_items
    def clear(self):
        """
        Remove all graphical elements from the plot

        Calls the pyqtgraph's plot widget's clear, sets the handles to the
        density image and to both scatter plot items to `None`, and removes
        labels. It does not reset `self.selection`; `reset_graph` does that.

        This method should generally not be called by the widget. If the data
        is gone (*e.g.* upon receiving `None` as an input data signal), this
        should be handled by calling `reset_graph`, which will in turn call
        `clear`.

        Derived classes should override this method if they add more graphical
        elements. For instance, the regression line in the scatterplot adds
        `self.reg_line_item = None` (the line in the plot is already removed
        in this method).
        """
        self.plot_widget.clear()

        self.density_img = None
        # NOTE(review): the timer is stopped and dropped only while it is
        # active; an inactive timer object stays in `self.timer`.
        if self.timer is not None and self.timer.isActive():
            self.timer.stop()
            self.timer = None
        self.scatterplot_item = None
        self.scatterplot_item_sel = None
        self.labels = []
        self._signal_too_many_labels(False)
        # Restart the view box's history (presumably the zoom history —
        # confirm) and tag the current state
        self.view_box.init_history()
        self.view_box.tag_history()
701
    # TODO: Avoid `keep_something` and `reset_something` arguments.
    # __keep_selection is used exclusively by set_sample_size, which would
    # otherwise just repeat the code from reset_graph except for resetting
    # the selection. We may prefer to have a method _reset_graph which does
    # everything except resetting the selection, and reset_graph would
    # call it.
    def reset_graph(self, __keep_selection=False):
        """
        Reset the graph to new data (or no data)

        The method must be called when the plot receives new data, in
        particular when the number of points changes. If only their properties
        - like coordinates or shapes - change, an update method
        (`update_coordinates`, `update_shapes`...) should be called instead.

        The method must also be called when the data is gone.

        The method calls `clear`, resets the selection (unless asked to keep
        it), discards the current sample (a new one is drawn on demand by
        `_create_sample`), and then calls `update_coordinates` and
        `update_point_props`.

        NB. Argument `__keep_selection` is for internal use only. Its name is
        mangled, so callers outside the class cannot pass it by keyword;
        pass it positionally, as `set_sample_size` does.
        """
        self.clear()
        if not __keep_selection:
            self.selection = None
        self.sample_indices = None
        self.update_coordinates()
        self.update_point_props()
729
730    def set_sample_size(self, sample_size):
731        """
732        Set the sample size
733
734        Args:
735            sample_size (int or None): sample size or `None` to show all points
736        """
737        if self.sample_size != sample_size:
738            self.sample_size = sample_size
739            self.reset_graph(True)
740
741    def update_point_props(self):
742        """
743        Update the sizes, colors, shapes and labels
744
745        The method calls the appropriate update methods for individual
746        properties.
747        """
748        self.update_sizes()
749        self.update_colors()
750        self.update_selection_colors()
751        self.update_shapes()
752        self.update_labels()
753
754    # Coordinates
755    # TODO: It could be nice if this method was run on entire data, not just
756    # a sample. For this, however, it would need to either be called from
757    # `get_coordinates` before sampling (very ugly) or call
758    # `self.master.get_coordinates_data` (beyond ugly) or the widget would
759    # have to store the ranges of unsampled data (ugly).
760    # Maybe we leave it as it is.
761    def _reset_view(self, x_data, y_data):
762        """
763        Set the range of the view box
764
765        Args:
766            x_data (np.ndarray): x coordinates
767            y_data (np.ndarray) y coordinates
768        """
769        min_x, max_x = np.min(x_data), np.max(x_data)
770        min_y, max_y = np.min(y_data), np.max(y_data)
771        self.view_box.setRange(
772            QRectF(min_x, min_y, max_x - min_x or 1, max_y - min_y or 1),
773            padding=0.025)
774
775    def _filter_visible(self, data):
776        """Return the sample from the data using the stored sample_indices"""
777        if data is None or self.sample_indices is None:
778            return data
779        else:
780            return np.asarray(data[self.sample_indices])
781
782    def get_coordinates(self):
783        """
784        Prepare coordinates of the points in the plot
785
786        The method is called by `update_coordinates`. It gets the coordinates
787        from the widget, jitters them and return them.
788
789        The methods also initializes the sample indices if neededd and stores
790        the original and sampled number of points.
791
792        Returns:
793            (tuple): a pair of numpy arrays containing (sampled) coordinates,
794                or `(None, None)`.
795        """
796        x, y = self.master.get_coordinates_data()
797        if x is None:
798            self.n_valid = self.n_shown = 0
799            return None, None
800        self.n_valid = len(x)
801        self._create_sample()
802        x = self._filter_visible(x)
803        y = self._filter_visible(y)
804        # Jittering after sampling is OK if widgets do not change the sample
805        # semi-permanently, e.g. take a sample for the duration of some
806        # animation. If the sample size changes dynamically (like by adding
807        # a "sample size" slider), points would move around when the sample
808        # size changes. To prevent this, jittering should be done before
809        # sampling (i.e. two lines earlier). This would slow it down somewhat.
810        x, y = self.jitter_coordinates(x, y)
811        return x, y
812
813    def _create_sample(self):
814        """
815        Create a random sample if the data is larger than the set sample size
816        """
817        self.n_shown = min(self.n_valid, self.sample_size or self.n_valid)
818        if self.sample_size is not None \
819                and self.sample_indices is None \
820                and self.n_valid != self.n_shown:
821            random = np.random.RandomState(seed=0)
822            self.sample_indices = random.choice(
823                self.n_valid, self.n_shown, replace=False)
824            # TODO: Is this really needed?
825            np.sort(self.sample_indices)
826
827    def jitter_coordinates(self, x, y):
828        """
829        Display coordinates to random positions within ellipses with
830        radiuses of `self.jittter_size` percents of spans
831        """
832        if self.jitter_size == 0 or self.jittering_suspended:
833            return x, y
834        return self._jitter_data(x, y)
835
836    def _jitter_data(self, x, y, span_x=None, span_y=None):
837        if span_x is None:
838            span_x = np.max(x) - np.min(x)
839        if span_y is None:
840            span_y = np.max(y) - np.min(y)
841        random = np.random.RandomState(seed=0)
842        rs = random.uniform(0, 1, len(x))
843        phis = random.uniform(0, 2 * np.pi, len(x))
844        magnitude = self.jitter_size / 100
845        return (x + magnitude * span_x * rs * np.cos(phis),
846                y + magnitude * span_y * rs * np.sin(phis))
847
848    def update_coordinates(self):
849        """
850        Trigger the update of coordinates while keeping other features intact.
851
852        The method gets the coordinates by calling `self.get_coordinates`,
853        which in turn calls the widget's `get_coordinate_data`. The number of
854        coordinate pairs returned by the latter must match the current number
855        of points. If this is not the case, the widget should trigger
856        the complete update by calling `reset_graph` instead of this method.
857        """
858        x, y = self.get_coordinates()
859        if x is None or len(x) == 0:
860            return
861
862        self._reset_view(x, y)
863        if self.scatterplot_item is None:
864            if self.sample_indices is None:
865                indices = np.arange(self.n_valid)
866            else:
867                indices = self.sample_indices
868            kwargs = dict(x=x, y=y, data=indices)
869            self.scatterplot_item = ScatterPlotItem(**kwargs)
870            self.scatterplot_item.sigClicked.connect(self.select_by_click)
871            self.scatterplot_item_sel = ScatterPlotItem(**kwargs)
872            self.plot_widget.addItem(self.scatterplot_item_sel)
873            self.plot_widget.addItem(self.scatterplot_item)
874        else:
875            self.scatterplot_item.setCoordinates(x, y)
876            self.scatterplot_item_sel.setCoordinates(x, y)
877            self.update_labels()
878
879        self.update_density()  # Todo: doesn't work: try MDS with density on
880
881    # Sizes
882    def get_sizes(self):
883        """
884        Prepare data for sizes of points in the plot
885
886        The method is called by `update_sizes`. It gets the sizes
887        from the widget and performs the necessary scaling and sizing.
888        The output is rounded to half a pixel for faster drawing.
889
890        Returns:
891            (np.ndarray): sizes
892        """
893        size_column = self.master.get_size_data()
894        if size_column is None:
895            return np.full((self.n_shown,),
896                           self.MinShapeSize + (5 + self.point_width) * 0.5)
897        size_column = self._filter_visible(size_column)
898        size_column = size_column.copy()
899        with warnings.catch_warnings():
900            warnings.simplefilter("ignore", category=RuntimeWarning)
901            size_column -= np.nanmin(size_column)
902            mx = np.nanmax(size_column)
903        if mx > 0:
904            size_column /= mx
905        else:
906            size_column[:] = 0.5
907
908        sizes = self.MinShapeSize + (5 + self.point_width) * size_column
909        # round sizes to half pixel for smaller pyqtgraph's symbol pixmap atlas
910        sizes = (sizes * 2).round() / 2
911        return sizes
912
    def update_sizes(self):
        """
        Trigger an update of point sizes

        The method calls `self.get_sizes`, which in turn calls the widget's
        `get_size_data`. The scaled sizes are then passed to the widget's
        `impute_sizes` (or to `default_impute_sizes` if the widget has no
        such method), which replaces missing sizes in place.

        The transition is animated only if three conditions hold: there are
        at most `MAX_N_VALID_SIZE_ANIMATE` valid points, all current sizes
        are positive, and at least one size changes. The animation uses a
        `QTimer` (50 ms interval) that moves the sizes towards the new ones
        in `len(Timeout.factors)` steps. Otherwise the new sizes are set at
        once. `begin_resizing` is emitted at the start, `step_resizing`
        after every intermediate animation step, and `end_resizing` when the
        final sizes are set. Selection markers are always `SELECTION_WIDTH`
        larger than the points.

        Nothing happens if there is no scatter plot item.
        """
        if self.scatterplot_item:
            size_data = self.get_sizes()
            # the imputer modifies `size_data` in place
            size_imputer = getattr(
                self.master, "impute_sizes", self.default_impute_sizes)
            size_imputer(size_data)

            # Interrupt a running animation; the new one starts from the
            # sizes that are currently displayed.
            # NOTE(review): the interrupted animation's `begin_resizing` is
            # not matched by an `end_resizing` here
            if self.timer is not None and self.timer.isActive():
                self.timer.stop()
                self.timer = None

            current_size_data = self.scatterplot_item.data["size"].copy()
            diff = size_data - current_size_data
            widget = self

            class Timeout:
                """Timer callback; performs one animation step per call"""
                # 0.5 - np.cos(np.arange(0.17, 1, 0.17) * np.pi) / 2
                # i.e. a cosine easing curve, rounded; the final 1 makes the
                # last step land exactly on the target sizes
                factors = [0.07, 0.26, 0.52, 0.77, 0.95, 1]

                def __init__(self):
                    self._counter = 0

                def __call__(self):
                    factor = self.factors[self._counter]
                    self._counter += 1
                    size = current_size_data + diff * factor
                    if len(self.factors) == self._counter:
                        # last step: stop the timer and use the exact
                        # target sizes
                        widget.timer.stop()
                        widget.timer = None
                        size = size_data
                    widget.scatterplot_item.setSize(size)
                    widget.scatterplot_item_sel.setSize(size + SELECTION_WIDTH)
                    if widget.timer is None:
                        widget.end_resizing.emit()
                    else:
                        widget.step_resizing.emit()

            # Animate only if there are at most MAX_N_VALID_SIZE_ANIMATE
            # points, all starting sizes are positive and some size changes
            if self.n_valid <= MAX_N_VALID_SIZE_ANIMATE and \
                    np.all(current_size_data > 0) and np.any(diff != 0):
                # If encountered any strange behaviour when updating sizes,
                # implement it with threads
                self.begin_resizing.emit()
                # the timer's parent is the scatter plot item
                self.timer = QTimer(self.scatterplot_item, interval=50)
                self.timer.timeout.connect(Timeout())
                self.timer.start()
            else:
                self.begin_resizing.emit()
                self.scatterplot_item.setSize(size_data)
                self.scatterplot_item_sel.setSize(size_data + SELECTION_WIDTH)
                self.end_resizing.emit()

    # aliases of `update_sizes` kept for callers that use older names
    update_point_size = update_sizes  # backward compatibility (needed?!)
    update_size = update_sizes
973
974    @classmethod
975    def default_impute_sizes(cls, size_data):
976        """
977        Fallback imputation for sizes.
978
979        Set the size to two pixels smaller than the minimal size
980
981        Returns:
982            (bool): True if there was any missing data
983        """
984        nans = np.isnan(size_data)
985        if np.any(nans):
986            size_data[nans] = cls.MinShapeSize - 2
987            return True
988        else:
989            return False
990
991    # Colors
992    def get_colors(self):
993        """
994        Prepare data for colors of the points in the plot
995
996        The method is called by `update_colors`. It gets the colors and the
997        indices of the data subset from the widget (`get_color_data`,
998        `get_subset_mask`), and constructs lists of pens and brushes for
999        each data point.
1000
1001        The method uses different palettes for discrete and continuous data,
1002        as determined by calling the widget's method `is_continuous_color`.
1003
1004        If also marks the points that are in the subset as defined by, for
1005        instance the 'Data Subset' signal in the Scatter plot and similar
1006        widgets. (Do not confuse this with *selected points*, which are
1007        marked by circles around the points, which are colored by groups
1008        and thus independent of this method.)
1009
1010        Returns:
1011            (tuple): a list of pens and list of brushes
1012        """
1013        c_data = self.master.get_color_data()
1014        c_data = self._filter_visible(c_data)
1015        subset = self.master.get_subset_mask()
1016        subset = self._filter_visible(subset)
1017        self.subset_is_shown = subset is not None
1018        if c_data is None:  # same color
1019            self.palette = None
1020            return self._get_same_colors(subset)
1021        elif self.master.is_continuous_color():
1022            return self._get_continuous_colors(c_data, subset)
1023        else:
1024            return self._get_discrete_colors(c_data, subset)
1025
1026    def _get_same_colors(self, subset):
1027        """
1028        Return the same pen for all points while the brush color depends
1029        upon whether the point is in the subset or not
1030
1031        Args:
1032            subset (np.ndarray): a bool array indicating whether a data point
1033                is in the subset or not (e.g. in the 'Data Subset' signal
1034                in the Scatter plot and similar widgets);
1035
1036        Returns:
1037            (tuple): a list of pens and list of brushes
1038        """
1039        color = self.plot_widget.palette().color(OWPalette.Data)
1040        pen = [_make_pen(color, 1.5)] * self.n_shown  # use a single QPen instance
1041
1042        # Prepare all brushes; we use the first two or the last
1043        brushes = []
1044        for c in (self.COLOR_SUBSET, self.COLOR_NOT_SUBSET, self.COLOR_DEFAULT):
1045            color = QColor(*c)
1046            if color.alpha():
1047                color.setAlpha(self.alpha_value)
1048            brushes.append(QBrush(color))
1049
1050        if subset is not None:
1051            brush = np.where(subset, *brushes[:2])
1052        else:
1053            brush = brushes[-1:] * self.n_shown  # use a single QBrush instance
1054        return pen, brush
1055
1056    def _get_continuous_colors(self, c_data, subset):
1057        """
1058        Return the pens and colors whose color represent an index into
1059        a continuous palette. The same color is used for pen and brush,
1060        except the former is darker. If the data has a subset, the brush
1061        is transparent for points that are not in the subset.
1062        """
1063        palette = self.master.get_palette()
1064
1065        if np.isnan(c_data).all():
1066            self.palette = palette
1067            return self._get_continuous_nan_colors(len(c_data))
1068
1069        self.scale = DiscretizedScale(np.nanmin(c_data), np.nanmax(c_data))
1070        bins = self.scale.get_bins()
1071        self.palette = \
1072            colorpalettes.BinnedContinuousPalette.from_palette(palette, bins)
1073        colors = self.palette.values_to_colors(c_data)
1074        brush = np.hstack(
1075            (colors,
1076             np.full((len(c_data), 1), self.alpha_value, dtype=np.ubyte)))
1077        pen = (colors.astype(dtype=float) * 100 / self.DarkerValue
1078               ).astype(np.ubyte)
1079
1080        # Reuse pens and brushes with the same colors because PyQtGraph then
1081        # builds smaller pixmap atlas, which makes the drawing faster
1082
1083        def reuse(cache, fun, *args):
1084            if args not in cache:
1085                cache[args] = fun(args)
1086            return cache[args]
1087
1088        def create_pen(col):
1089            return _make_pen(QColor(*col), 1.5)
1090
1091        def create_brush(col):
1092            return QBrush(QColor(*col))
1093
1094        cached_pens = {}
1095        pen = [reuse(cached_pens, create_pen, *col) for col in pen.tolist()]
1096
1097        if subset is not None:
1098            brush[:, 3] = 0
1099            brush[subset, 3] = self.alpha_value
1100
1101        cached_brushes = {}
1102        brush = np.array([reuse(cached_brushes, create_brush, *col)
1103                          for col in brush.tolist()])
1104
1105        return pen, brush
1106
1107    def _get_continuous_nan_colors(self, n):
1108        nan_color = QColor(*self.palette.nan_color)
1109        nan_pen = _make_pen(nan_color.darker(1.2), 1.5)
1110        pen = np.full(n, nan_pen)
1111        nan_brush = QBrush(nan_color)
1112        brush = np.full(n, nan_brush)
1113        return pen, brush
1114
1115    def _get_discrete_colors(self, c_data, subset):
1116        """
1117        Return the pens and colors whose color represent an index into
1118        a discrete palette. The same color is used for pen and brush,
1119        except the former is darker. If the data has a subset, the brush
1120        is transparent for points that are not in the subset.
1121        """
1122        self.palette = self.master.get_palette()
1123        c_data = c_data.copy()
1124        c_data[np.isnan(c_data)] = len(self.palette)
1125        c_data = c_data.astype(int)
1126        colors = self.palette.qcolors_w_nan
1127        pens = np.array(
1128            [_make_pen(col.darker(self.DarkerValue), 1.5) for col in colors])
1129        pen = pens[c_data]
1130        if self.alpha_value < 255:
1131            for col in colors:
1132                col.setAlpha(self.alpha_value)
1133        brushes = np.array([QBrush(col) for col in colors])
1134        brush = brushes[c_data]
1135
1136        if subset is not None:
1137            black = np.full(len(brush), QBrush(QColor(0, 0, 0, 0)))
1138            brush = np.where(subset, brush, black)
1139        return pen, brush
1140
1141    def update_colors(self):
1142        """
1143        Trigger an update of point colors
1144
1145        The method calls `self.get_colors`, which in turn calls the widget's
1146        `get_color_data` to get the indices in the pallette. `get_colors`
1147        returns a list of pens and brushes to which this method uses to
1148        update the colors. Finally, the method triggers the update of the
1149        legend and the density plot.
1150        """
1151        if self.scatterplot_item is not None:
1152            pen_data, brush_data = self.get_colors()
1153            self.scatterplot_item.setPen(pen_data, update=False, mask=None)
1154            self.scatterplot_item.setBrush(brush_data, mask=None)
1155        self.update_z_values()
1156        self.update_legends()
1157        self.update_density()
1158
1159    update_alpha_value = update_colors
1160
1161    def update_density(self):
1162        """
1163        Remove the existing density plot (if there is one) and replace it
1164        with a new one (if enabled).
1165
1166        The method gets the colors from the pens of the currently plotted
1167        points.
1168        """
1169        if self.density_img:
1170            self.plot_widget.removeItem(self.density_img)
1171            self.density_img = None
1172        if self.class_density and self.scatterplot_item is not None:
1173            c_data = self.master.get_color_data()
1174            if c_data is None:
1175                return
1176            visible_c_data = self._filter_visible(c_data)
1177            mask = np.bitwise_and(np.isfinite(visible_c_data),
1178                                  visible_c_data < MAX_COLORS - 1)
1179            pens = self.scatterplot_item.data['pen']
1180            rgb_data = [
1181                pen.color().getRgb()[:3] if pen is not None else (255, 255, 255)
1182                for known, pen in zip(mask, pens)
1183                if known]
1184            if len(set(rgb_data)) <= 1:
1185                return
1186            [min_x, max_x], [min_y, max_y] = self.view_box.viewRange()
1187            x_data, y_data = self.scatterplot_item.getData()
1188            self.density_img = classdensity.class_density_image(
1189                min_x, max_x, min_y, max_y, self.resolution,
1190                x_data[mask], y_data[mask], rgb_data)
1191            self.plot_widget.addItem(self.density_img, ignoreBounds=True)
1192
1193    def update_selection_colors(self):
1194        """
1195        Trigger an update of selection markers
1196
1197        This update method is usually not called by the widget but by the
1198        plot, since it is the plot that handles the selections.
1199
1200        Like other update methods, it calls the corresponding get method
1201        (`get_colors_sel`) which returns a list of pens and brushes.
1202        """
1203        if self.scatterplot_item_sel is None:
1204            return
1205        pen, brush = self.get_colors_sel()
1206        self.scatterplot_item_sel.setPen(pen, update=False, mask=None)
1207        self.scatterplot_item_sel.setBrush(brush, mask=None)
1208        self.update_z_values()
1209
1210    def get_colors_sel(self):
1211        """
1212        Return pens and brushes for selection markers.
1213
1214        A pen can is set to `Qt.NoPen` if a point is not selected.
1215
1216        All brushes are completely transparent whites.
1217
1218        Returns:
1219            (tuple): a list of pens and a list of brushes
1220        """
1221        nopen = QPen(Qt.NoPen)
1222        if self.selection is None:
1223            pen = [nopen] * self.n_shown
1224        else:
1225            sels = np.max(self.selection)
1226            if sels == 1:
1227                pen = np.where(
1228                    self._filter_visible(self.selection),
1229                    _make_pen(QColor(255, 190, 0, 255), SELECTION_WIDTH),
1230                    nopen)
1231            else:
1232                palette = colorpalettes.LimitedDiscretePalette(
1233                    number_of_colors=sels + 1)
1234                pen = np.choose(
1235                    self._filter_visible(self.selection),
1236                    [nopen] + [_make_pen(palette[i], SELECTION_WIDTH)
1237                               for i in range(sels)])
1238        return pen, [QBrush(QColor(255, 255, 255, 0))] * self.n_shown
1239
1240    # Labels
1241    def get_labels(self):
1242        """
1243        Prepare data for labels for points
1244
1245        The method returns the results of the widget's `get_label_data`
1246
1247        Returns:
1248            (labels): a sequence of labels
1249        """
1250        return self._filter_visible(self.master.get_label_data())
1251
1252    def update_labels(self):
1253        """
1254        Trigger an update of labels
1255
1256        The method calls `get_labels` which in turn calls the widget's
1257        `get_label_data`. The obtained labels are shown if the corresponding
1258        points are selected or if `label_only_selected` is `false`.
1259        """
1260        for label in self.labels:
1261            self.plot_widget.removeItem(label)
1262        self.labels = []
1263
1264        mask = None
1265        if self.scatterplot_item is not None:
1266            x, y = self.scatterplot_item.getData()
1267            mask = self._label_mask(x, y)
1268
1269        if mask is not None:
1270            labels = self.get_labels()
1271            if labels is None:
1272                mask = None
1273
1274        self._signal_too_many_labels(
1275            mask is not None and mask.sum() > self.MAX_VISIBLE_LABELS)
1276        if self._too_many_labels or mask is None or not np.any(mask):
1277            return
1278
1279        black = pg.mkColor(0, 0, 0)
1280        labels = labels[mask]
1281        x = x[mask]
1282        y = y[mask]
1283        for label, xp, yp in zip(labels, x, y):
1284            ti = TextItem(label, black)
1285            ti.setPos(xp, yp)
1286            self.plot_widget.addItem(ti)
1287            self.labels.append(ti)
1288            ti.setFont(self.parameter_setter.label_font)
1289
1290    def _signal_too_many_labels(self, too_many):
1291        if self._too_many_labels != too_many:
1292            self._too_many_labels = too_many
1293            self.too_many_labels.emit(too_many)
1294
1295    def _label_mask(self, x, y):
1296        (x0, x1), (y0, y1) = self.view_box.viewRange()
1297        mask = np.logical_and(
1298            np.logical_and(x >= x0, x <= x1),
1299            np.logical_and(y >= y0, y <= y1))
1300        if self.label_only_selected:
1301            sub_mask = self._filter_visible(self.master.get_subset_mask())
1302            if self.selection is None:
1303                if sub_mask is None:
1304                    return None
1305                else:
1306                    sel_mask = sub_mask
1307            else:
1308                sel_mask = self._filter_visible(self.selection) != 0
1309                if sub_mask is not None:
1310                    sel_mask = np.logical_or(sel_mask, sub_mask)
1311            mask = np.logical_and(mask, sel_mask)
1312        return mask
1313
1314    # Shapes
1315    def get_shapes(self):
1316        """
1317        Prepare data for shapes of points in the plot
1318
1319        The method is called by `update_shapes`. It gets the data from
1320        the widget's `get_shape_data`, and then calls its `impute_shapes`
1321        to impute the missing shape (usually with some default shape).
1322
1323        Returns:
1324            (np.ndarray): an array of symbols (e.g. o, x, + ...)
1325        """
1326        shape_data = self.master.get_shape_data()
1327        shape_data = self._filter_visible(shape_data)
1328        # Data has to be copied so the imputation can change it in-place
1329        # TODO: Try avoiding this when we move imputation to the widget
1330        if shape_data is not None:
1331            shape_data = np.copy(shape_data)
1332        shape_imputer = getattr(
1333            self.master, "impute_shapes", self.default_impute_shapes)
1334        shape_imputer(shape_data, len(self.CurveSymbols) - 1)
1335        if isinstance(shape_data, np.ndarray):
1336            shape_data = shape_data.astype(int)
1337        else:
1338            shape_data = np.zeros(self.n_shown, dtype=int)
1339        return self.CurveSymbols[shape_data]
1340
1341    @staticmethod
1342    def default_impute_shapes(shape_data, default_symbol):
1343        """
1344        Fallback imputation for shapes.
1345
1346        Use the default symbol, usually the last symbol in the list.
1347
1348        Returns:
1349            (bool): True if there was any missing data
1350        """
1351        if shape_data is None:
1352            return False
1353        nans = np.isnan(shape_data)
1354        if np.any(nans):
1355            shape_data[nans] = default_symbol
1356            return True
1357        else:
1358            return False
1359
1360    def update_shapes(self):
1361        """
1362        Trigger an update of point symbols
1363
1364        The method calls `get_shapes` to obtain an array with a symbol
1365        for each point and uses it to update the symbols.
1366
1367        Finally, the method updates the legend.
1368        """
1369        if self.scatterplot_item:
1370            shape_data = self.get_shapes()
1371            self.scatterplot_item.setSymbol(shape_data)
1372        self.update_legends()
1373
1374    def update_z_values(self):
1375        """
1376        Set z-values for point in the plot
1377
1378        The order is as follows:
1379        - selected points that are also in the subset on top,
1380        - followed by selected points,
1381        - followed by points from the subset,
1382        - followed by the rest.
1383        Within each of these four groups, points are ordered by their colors.
1384
1385        Points with less frequent colors are above those with more frequent.
1386        The points for which the value for the color is missing are at the
1387        bottom of their respective group.
1388        """
1389        if not self.scatterplot_item:
1390            return
1391
1392        subset = self.master.get_subset_mask()
1393        c_data = self.master.get_color_data()
1394        if subset is None and self.selection is None and c_data is None:
1395            self.scatterplot_item.setZ(None)
1396            return
1397
1398        z = np.zeros(self.n_shown)
1399
1400        if subset is not None:
1401            subset = self._filter_visible(subset)
1402            z[subset] += 1000
1403
1404        if self.selection is not None:
1405            z[self._filter_visible(self.selection) != 0] += 2000
1406
1407        if c_data is not None:
1408            c_nan = np.isnan(c_data)
1409            vis_data = self._filter_visible(c_data)
1410            vis_nan = np.isnan(vis_data)
1411            z[vis_nan] -= 999
1412            if not self.master.is_continuous_color():
1413                dist = np.bincount(c_data[~c_nan].astype(int))
1414                vis_knowns = vis_data[~vis_nan].astype(int)
1415                argdist = np.argsort(dist)
1416                z[~vis_nan] -= argdist[vis_knowns]
1417
1418        self.scatterplot_item.setZ(z)
1419
1420    def update_grid_visibility(self):
1421        """Show or hide the grid"""
1422        self.plot_widget.showGrid(x=self.show_grid, y=self.show_grid)
1423
1424    def update_legend_visibility(self):
1425        """
1426        Show or hide legends based on whether they are enabled and non-empty
1427        """
1428        self.shape_legend.setVisible(
1429            self.show_legend and bool(self.shape_legend.items))
1430        self.color_legend.setVisible(
1431            self.show_legend and bool(self.color_legend.items))
1432
1433    def update_legends(self):
1434        """Update content of legends and their visibility"""
1435        cont_color = self.master.is_continuous_color()
1436        shape_labels = self.master.get_shape_labels()
1437        color_labels = self.master.get_color_labels()
1438        if not cont_color and shape_labels is not None \
1439                and shape_labels == color_labels:
1440            colors = self.master.get_color_data()
1441            shapes = self.master.get_shape_data()
1442            mask = np.isfinite(colors) * np.isfinite(shapes)
1443            combined = (colors == shapes)[mask].all()
1444        else:
1445            combined = False
1446        if combined:
1447            self._update_combined_legend(shape_labels)
1448        else:
1449            self._update_shape_legend(shape_labels)
1450            if cont_color:
1451                self._update_continuous_color_legend(color_labels)
1452            else:
1453                self._update_color_legend(color_labels)
1454        self.update_legend_visibility()
1455        Updater.update_legend_font(self.parameter_setter.cat_legend_items,
1456                                   **self.parameter_setter.cat_legend_settings)
1457        Updater.update_num_legend_font(self.parameter_setter.num_legend,
1458                                       **self.parameter_setter.num_legend_settings)
1459
1460    def _update_shape_legend(self, labels):
1461        self.shape_legend.clear()
1462        if labels is None or self.scatterplot_item is None:
1463            return
1464        color = QColor(0, 0, 0)
1465        color.setAlpha(self.alpha_value)
1466        for label, symbol in zip(labels, self.CurveSymbols):
1467            self.shape_legend.addItem(
1468                SymbolItemSample(pen=color, brush=color, size=10, symbol=symbol),
1469                escape(label))
1470
1471    def _update_continuous_color_legend(self, label_formatter):
1472        self.color_legend.clear()
1473        if self.scale is None or self.scatterplot_item is None:
1474            return
1475        label = PaletteItemSample(self.palette, self.scale, label_formatter)
1476        self.color_legend.addItem(label, "")
1477        self.color_legend.setGeometry(label.boundingRect())
1478
1479    def _update_color_legend(self, labels):
1480        self.color_legend.clear()
1481        if labels is None:
1482            return
1483        self._update_colored_legend(self.color_legend, labels, 'o')
1484
1485    def _update_combined_legend(self, labels):
1486        # update_colored_legend will already clear the shape legend
1487        # so we remove colors here
1488        use_legend = \
1489            self.shape_legend if self.shape_legend.items else self.color_legend
1490        self.color_legend.clear()
1491        self.shape_legend.clear()
1492        self._update_colored_legend(use_legend, labels, self.CurveSymbols)
1493
1494    def _update_colored_legend(self, legend, labels, symbols):
1495        if self.scatterplot_item is None or not self.palette:
1496            return
1497        if isinstance(symbols, str):
1498            symbols = itertools.repeat(symbols, times=len(labels))
1499        colors = self.palette.values_to_colors(np.arange(len(labels)))
1500        for color, label, symbol in zip(colors, labels, symbols):
1501            color = QColor(*color)
1502            pen = _make_pen(color.darker(self.DarkerValue), 1.5)
1503            color.setAlpha(self.alpha_value)
1504            brush = QBrush(color)
1505            legend.addItem(
1506                SymbolItemSample(pen=pen, brush=brush, size=10, symbol=symbol),
1507                escape(label))
1508
1509    def zoom_button_clicked(self):
1510        self.plot_widget.getViewBox().setMouseMode(
1511            self.plot_widget.getViewBox().RectMode)
1512
1513    def pan_button_clicked(self):
1514        self.plot_widget.getViewBox().setMouseMode(
1515            self.plot_widget.getViewBox().PanMode)
1516
1517    def select_button_clicked(self):
1518        self.plot_widget.getViewBox().setMouseMode(
1519            self.plot_widget.getViewBox().RectMode)
1520
1521    def reset_button_clicked(self):
1522        self.plot_widget.getViewBox().autoRange()
1523        self.update_labels()
1524
1525    def select_by_click(self, _, points):
1526        if self.scatterplot_item is not None:
1527            self.select(points)
1528
1529    def select_by_rectangle(self, rect):
1530        if self.scatterplot_item is not None:
1531            x0, x1 = sorted((rect.topLeft().x(), rect.bottomRight().x()))
1532            y0, y1 = sorted((rect.topLeft().y(), rect.bottomRight().y()))
1533            x, y = self.master.get_coordinates_data()
1534            indices = np.flatnonzero(
1535                (x0 <= x) & (x <= x1) & (y0 <= y) & (y <= y1))
1536            self.select_by_indices(indices.astype(int))
1537
1538    def unselect_all(self):
1539        if self.selection is not None:
1540            self.selection = None
1541            self.update_selection_colors()
1542            if self.label_only_selected:
1543                self.update_labels()
1544            self.master.selection_changed()
1545
1546    def select(self, points):
1547        # noinspection PyArgumentList
1548        if self.scatterplot_item is None:
1549            return
1550        indices = [p.data() for p in points]
1551        self.select_by_indices(indices)
1552
1553    def select_by_indices(self, indices):
1554        if self.selection is None:
1555            self.selection = np.zeros(self.n_valid, dtype=np.uint8)
1556        keys = QApplication.keyboardModifiers()
1557        if keys & Qt.ControlModifier:
1558            self.selection_append(indices)
1559        elif keys & Qt.ShiftModifier:
1560            self.selection_new_group(indices)
1561        elif keys & Qt.AltModifier:
1562            self.selection_remove(indices)
1563        else:
1564            self.selection_select(indices)
1565
1566    def selection_select(self, indices):
1567        self.selection = np.zeros(self.n_valid, dtype=np.uint8)
1568        self.selection[indices] = 1
1569        self._update_after_selection()
1570
1571    def selection_append(self, indices):
1572        self.selection[indices] = max(np.max(self.selection), 1)
1573        self._update_after_selection()
1574
1575    def selection_new_group(self, indices):
1576        self.selection[indices] = np.max(self.selection) + 1
1577        self._update_after_selection()
1578
1579    def selection_remove(self, indices):
1580        self.selection[indices] = 0
1581        self._update_after_selection()
1582
1583    def _update_after_selection(self):
1584        self._compress_indices()
1585        self.update_selection_colors()
1586        if self.label_only_selected:
1587            self.update_labels()
1588        self.master.selection_changed()
1589
1590    def _compress_indices(self):
1591        indices = sorted(set(self.selection) | {0})
1592        if len(indices) == max(indices) + 1:
1593            return
1594        mapping = np.zeros((max(indices) + 1,), dtype=int)
1595        for i, ind in enumerate(indices):
1596            mapping[ind] = i
1597        self.selection = mapping[self.selection]
1598
1599    def get_selection(self):
1600        if self.selection is None:
1601            return np.array([], dtype=np.uint8)
1602        else:
1603            return np.flatnonzero(self.selection)
1604
1605    def help_event(self, event):
1606        """
1607        Create a `QToolTip` for the point hovered by the mouse
1608        """
1609        if self.scatterplot_item is None:
1610            return False
1611        act_pos = self.scatterplot_item.mapFromScene(event.scenePos())
1612        point_data = [p.data() for p in self.scatterplot_item.pointsAt(act_pos)]
1613        text = self.master.get_tooltip(point_data)
1614        if text:
1615            QToolTip.showText(event.screenPos(), text, widget=self.plot_widget)
1616            return True
1617        else:
1618            return False
1619