1from functools import partial, reduce
2from itertools import count, groupby, repeat
3from xml.sax.saxutils import escape
4
5import numpy as np
6from scipy.stats import norm, rayleigh, beta, gamma, pareto, expon
7
8from AnyQt.QtWidgets import QGraphicsRectItem
9from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QPolygonF
10from AnyQt.QtCore import Qt, QRectF, QPointF, pyqtSignal as Signal
11from orangewidget.utils.listview import ListViewSearch
12import pyqtgraph as pg
13
14from Orange.data import Table, DiscreteVariable, ContinuousVariable, Domain
15from Orange.preprocess.discretize import decimal_binnings, time_binnings, \
16    short_time_units
17from Orange.statistics import distribution, contingency
18from Orange.widgets import gui, settings
19from Orange.widgets.utils.annotated_data import \
20    create_groups_table, create_annotated_table, ANNOTATED_DATA_SIGNAL_NAME
21from Orange.widgets.utils.itemmodels import DomainModel
22from Orange.widgets.utils.widgetpreview import WidgetPreview
23from Orange.widgets.visualize.utils.plotutils import ElidedLabelsAxis
24from Orange.widgets.widget import Input, Output, OWWidget, Msg
25
26from Orange.widgets.visualize.owscatterplotgraph import \
27    LegendItem as SPGLegendItem
28
29
class ScatterPlotItem(pg.ScatterPlotItem):
    """Scatter plot item that sets render hints from its options."""
    Symbols = pg.graphicsItems.ScatterPlotItem.Symbols

    # pylint: disable=arguments-differ
    def paint(self, painter, option, widget=None):
        """Enable the render hints requested in `self.opts`, then paint."""
        hints_by_option = (
            ("pxMode", QPainter.SmoothPixmapTransform),
            ("antialias", QPainter.Antialiasing),
        )
        for opt_name, hint in hints_by_option:
            if self.opts[opt_name]:
                painter.setRenderHint(hint, True)
        super().paint(painter, option, widget)
40
41
class LegendItem(SPGLegendItem):
    """Legend that re-anchors itself while moved with the left button."""

    @staticmethod
    def mousePressEvent(event):
        """Accept presses of the left button; ignore all other presses."""
        if event.button() != Qt.LeftButton:
            event.ignore()
            return
        event.accept()

    def mouseMoveEvent(self, event):
        """Shift the legend by half of the mouse movement and re-anchor it."""
        if not event.buttons() & Qt.LeftButton:
            event.ignore()
            return
        event.accept()
        if self.parentItem() is None:
            return
        shift = (event.pos() - event.lastPos()) / 2
        self.autoAnchor(self.pos() + shift)

    @staticmethod
    def mouseReleaseEvent(event):
        """Accept releases of the left button; ignore all other releases."""
        if event.button() != Qt.LeftButton:
            event.ignore()
            return
        event.accept()
65
66
class DistributionBarItem(pg.GraphicsObject):
    """
    Graphics item that paints one bar, or one group of bars, of a histogram.

    Args:
        x: left edge of the horizontal span reserved for the item
        width: width of that span
        padding: fraction of `width` left empty on each side (the resulting
            padding is capped at 20); if falsy, the padding equals half a
            device unit mapped to item coordinates
        freqs: heights of the rectangles
        colors: colors of the rectangles, matched with `freqs`
        stacked: if true, rectangles are drawn one above another;
            otherwise they are drawn side by side
        expanded: if true, `freqs` are divided by their sum before drawing
        tooltip: tooltip shown while the item is not hidden
        desc: description of the bar, stored in attribute `desc`
        hidden: if true, the item is not painted and has no tooltip
    """
    def __init__(self, x, width, padding, freqs, colors, stacked, expanded,
                 tooltip, desc, hidden):
        super().__init__()
        self.x = x
        self.width = width
        self.freqs = freqs
        self.colors = colors
        self.padding = padding
        self.stacked = stacked
        self.expanded = expanded
        self.__picture = None
        self.polygon = None  # outline of the bar(s); set in paint
        self.hovered = False
        self._tooltip = tooltip
        self.desc = desc
        self.hidden = False
        self.setHidden(hidden)
        self.setAcceptHoverEvents(True)

    def hoverEnterEvent(self, event):
        """Mark the item as hovered and schedule a repaint."""
        super().hoverEnterEvent(event)
        self.hovered = True
        self.update()

    def hoverLeaveEvent(self, event):
        """Remove the hovered mark and schedule a repaint."""
        super().hoverLeaveEvent(event)
        self.hovered = False
        self.update()

    def setHidden(self, hidden):
        """
        Set whether the item is hidden.

        A hidden item is not painted, so its tooltip is removed as well;
        the tooltip is restored when the item is shown again.
        """
        self.hidden = hidden
        self.setToolTip("" if hidden else self._tooltip)

    def paint(self, painter, _options, _widget):
        """Draw the rectangles and, if hovered, a dashed outline."""
        if self.hidden:
            return

        if self.expanded:
            tot = np.sum(self.freqs)
            if tot == 0:
                # nothing to normalize and nothing to draw
                return
            freqs = self.freqs / tot
        else:
            freqs = self.freqs

        if not self.padding:
            padding = self.mapRectFromDevice(QRectF(0, 0, 0.5, 0)).width()
        else:
            padding = min(20, self.width * self.padding)
        sx = self.x + padding
        padded_width = self.width - 2 * padding

        if self.stacked:
            # rectangles on top of each other; outline is a single rectangle
            painter.setPen(Qt.NoPen)
            y = 0
            for freq, color in zip(freqs, self.colors):
                painter.setBrush(QBrush(color))
                painter.drawRect(QRectF(sx, y, padded_width, freq))
                y += freq
            self.polygon = QPolygonF(QRectF(sx, 0, padded_width, y))
        else:
            # rectangles side by side; outline follows their tops
            polypoints = [QPointF(sx, 0)]
            pen = QPen(QBrush(Qt.white), 0.5)
            pen.setCosmetic(True)
            painter.setPen(pen)
            wsingle = padded_width / len(self.freqs)
            for i, freq, color in zip(count(), freqs, self.colors):
                painter.setBrush(QBrush(color))
                x = sx + wsingle * i
                painter.drawRect(
                    QRectF(x, 0, wsingle, freq))
                polypoints += [QPointF(x, freq),
                               QPointF(x + wsingle, freq)]
            # close the outline along the bottom edge
            polypoints += [QPointF(polypoints[-1].x(), 0), QPointF(sx, 0)]
            self.polygon = QPolygonF(polypoints)

        if self.hovered:
            pen = QPen(QBrush(Qt.blue), 2, Qt.DashLine)
            pen.setCosmetic(True)
            painter.setPen(pen)
            painter.setBrush(Qt.NoBrush)
            painter.drawPolygon(self.polygon)

    @property
    def x0(self):
        """Left edge of the span reserved for the item."""
        return self.x

    @property
    def x1(self):
        """Right edge of the span reserved for the item."""
        return self.x + self.width

    def boundingRect(self):
        """
        Return a rectangle spanning the item's width.

        Its height is 1 for expanded items, the sum of frequencies for
        stacked items and the largest frequency otherwise.
        """
        if self.expanded:
            height = 1
        elif self.stacked:
            height = sum(self.freqs)
        else:
            height = max(self.freqs)
        return QRectF(self.x, 0, self.width, height)
168
169
class DistributionWidget(pg.PlotWidget):
    """Plot widget that reports clicks and drags over distribution bars."""
    item_clicked = Signal(DistributionBarItem, Qt.KeyboardModifiers, bool)
    blank_clicked = Signal()
    mouse_released = Signal()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # the bar under the cursor at the last press or drag step
        self.last_item = None

    def _get_bar_item(self, pos):
        """Return the first bar item at `pos`, or `None` if there is none."""
        bars = (item for item in self.items(pos)
                if isinstance(item, DistributionBarItem))
        return next(bars, None)

    def mousePressEvent(self, ev):
        """
        Emit `item_clicked` or `blank_clicked` for unhandled left presses.

        Events already accepted by the base class are left alone, and
        presses of other buttons are ignored.
        """
        super().mousePressEvent(ev)
        if ev.isAccepted():
            return
        if ev.button() != Qt.LeftButton:
            ev.ignore()
            return

        ev.accept()
        pressed = self.last_item = self._get_bar_item(ev.pos())
        if pressed:
            self.item_clicked.emit(pressed, ev.modifiers(), False)
        else:
            self.blank_clicked.emit()

    def mouseReleaseEvent(self, ev):
        """Forget the last bar and emit `mouse_released`."""
        self.last_item = None
        self.mouse_released.emit()

    def mouseMoveEvent(self, ev):
        """Emit `item_clicked` when the cursor is dragged onto another bar."""
        super().mouseMoveEvent(ev)
        if self.last_item is None:
            return
        entered = self._get_bar_item(ev.pos())
        if entered is None or entered is self.last_item:
            return
        self.item_clicked.emit(entered, ev.modifiers(), True)
        self.last_item = entered
211
212
class AshCurve:
    """
    Smoothed-histogram curve with the `fit`/`pdf` interface of the fitters.

    `fit` returns the data itself as the only "parameter"; `pdf` bins the
    data and smooths the counts with a Gaussian-shaped kernel.
    """
    @staticmethod
    def fit(a):
        """Return a one-element tuple containing the data `a`."""
        return (a, )

    @staticmethod
    def pdf(x, a, sigma=1, weights=None):
        """
        Return a smoothed, normalized histogram of `a`.

        Args:
            x: bin edges for the histogram
            a: data
            sigma: factor that multiplies kernel positions; larger values
                give a narrower kernel
            weights: optional weights passed to `np.histogram`

        Returns:
            array of length `len(x)` that sums to 1, or an array of zeros
            if the smoothed histogram has no mass (e.g. no data in `x`)
        """
        hist, _ = np.histogram(a, x, weights=weights)
        kernel_x = np.arange(len(x)) - len(hist) / 2
        kernel = 1 / (np.sqrt(2 * np.pi)) * np.exp(-(kernel_x * sigma) ** 2 / 2)
        ash = np.convolve(hist, kernel, mode="same")
        total = ash.sum()
        # Normalizing an all-zero curve would give NaNs; keep zeros instead
        if total:
            ash /= total
        return ash
226
227
class ElidedAxisNoUnits(ElidedLabelsAxis):
    """Elided-labels axis whose label can be shown without units."""

    def __init__(self, orientation, pen=None, linkView=None, parent=None,
                 maxTickLength=-5, showValues=True):
        # Attributes are set before the base constructor runs
        # NOTE(review): presumably the base constructor already uses them
        # (e.g. through labelString) -- confirm
        self.show_unit = False
        self.tick_dict = {}
        super().__init__(
            orientation, pen=pen, linkView=linkView, parent=parent,
            maxTickLength=maxTickLength, showValues=showValues)

    def setShowUnit(self, show_unit):
        """Set whether `labelString` defers to the base implementation."""
        self.show_unit = show_unit

    def labelString(self):
        """
        Return the HTML for the axis label.

        With `show_unit` set, this is the base class' label; otherwise it
        is only the label text wrapped in a span with the label style.
        """
        if not self.show_unit:
            css = ';'.join(
                f"{k}: {v}" for k, v in self.labelStyle.items())
            return f"<span style='{css}'>{self.labelText}</span>"
        return super().labelString()
245
246
class OWDistributions(OWWidget):
    # Widget metadata
    name = "Distributions"
    description = "Display value distributions of a data feature in a graph."
    icon = "icons/Distribution.svg"
    priority = 120
    keywords = ["histogram"]

    class Inputs:
        data = Input("Data", Table, doc="Set the input dataset")

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
        histogram_data = Output("Histogram Data", Table)

    class Error(OWWidget.Error):
        no_defined_values_var = \
            Msg("Variable '{}' does not have any defined values")
        no_defined_values_pair = \
            Msg("No data instances with '{}' and '{}' defined")

    class Warning(OWWidget.Warning):
        ignored_nans = Msg("Data instances with missing values are ignored")

    # Settings that are stored per domain
    settingsHandler = settings.DomainContextHandler()
    var = settings.ContextSetting(None)
    cvar = settings.ContextSetting(None)
    selection = settings.ContextSetting(set(), schema_only=True)
    # number_of_bins must be a context setting because selection depends on it
    number_of_bins = settings.ContextSetting(5, schema_only=True)

    # Settings that are independent of the domain
    fitted_distribution = settings.Setting(0)  # index into Fitters
    hide_bars = settings.Setting(False)
    show_probs = settings.Setting(False)
    stacked_columns = settings.Setting(False)
    cumulative_distr = settings.Setting(False)
    sort_by_freq = settings.Setting(False)
    kde_smoothing = settings.Setting(10)

    auto_apply = settings.Setting(True)

    graph_name = "plot"

    # Each entry is a tuple with
    # - the name shown in the combo box,
    # - an object providing `fit` and `pdf`,
    # - names of keyword arguments of `pdf`, in the order in which the
    #   corresponding values are returned by `fit`,
    # - labels used for these values in the curve description; they must
    #   match the values one-to-one. Values with labels that start with "-"
    #   are shown in parentheses (without the "-"); values with empty labels
    #   are not shown.
    Fitters = (
        ("None", None, (), ()),
        ("Normal", norm, ("loc", "scale"), ("μ", "σ")),
        ("Beta", beta, ("a", "b", "loc", "scale"),
         ("α", "β", "-loc", "-scale")),
        # gamma.fit gives three values (a, loc, scale), hence three labels;
        # a fourth label would shift the labels of loc and scale
        ("Gamma", gamma, ("a", "loc", "scale"), ("α", "-loc", "-scale")),
        ("Rayleigh", rayleigh, ("loc", "scale"), ("-loc", "σ")),
        ("Pareto", pareto, ("b", "loc", "scale"), ("α", "-loc", "-scale")),
        ("Exponential", expon, ("loc", "scale"), ("-loc", "λ")),
        ("Kernel density", AshCurve, ("a",), ("",))
    )

    # Values of attribute `drag_operation`
    DragNone, DragAdd, DragRemove = range(3)
303
    def __init__(self):
        """Initialize the state and build the controls, plots and legend."""
        super().__init__()
        # Input data and the values derived from it
        self.data = None
        self.valid_data = self.valid_group_data = None
        self.bar_items = []
        self.curve_items = []
        self.curve_descriptions = None
        self.binnings = []

        # State used while selecting bars
        self.last_click_idx = None
        self.drag_operation = self.DragNone
        self.key_operation = None
        # Bin settings chosen by the user, keyed by variable
        self._user_var_bins = {}

        # Box "Variable": list of variables and sorting of categories
        varview = gui.listView(
            self.controlArea, self, "var", box="Variable",
            model=DomainModel(valid_types=DomainModel.PRIMITIVE,
                              separators=False),
            callback=self._on_var_changed,
            viewType=ListViewSearch
        )
        gui.checkBox(
            varview.box, self, "sort_by_freq", "Sort categories by frequency",
            callback=self._on_sort_by_freq, stateWhenDisabled=False)

        # Box "Distribution": fitted curve, bin width, smoothing, hiding bars
        box = self.continuous_box = gui.vBox(self.controlArea, "Distribution")
        gui.comboBox(
            box, self, "fitted_distribution", label="Fitted distribution",
            orientation=Qt.Horizontal, items=(name[0] for name in self.Fitters),
            callback=self._on_fitted_dist_changed)
        # self.binnings is still empty here, so the initial maximum is 1
        slider = gui.hSlider(
            box, self, "number_of_bins",
            label="Bin width", orientation=Qt.Horizontal,
            minValue=0, maxValue=max(1, len(self.binnings) - 1),
            createLabel=False, callback=self._on_bins_changed)
        self.bin_width_label = gui.widgetLabel(slider.box)
        self.bin_width_label.setFixedWidth(35)
        self.bin_width_label.setAlignment(Qt.AlignRight)
        # the callback above fires while dragging; this one on release
        slider.sliderReleased.connect(self._on_bin_slider_released)
        self.smoothing_box = gui.hSlider(
            box, self, "kde_smoothing",
            label="Smoothing", orientation=Qt.Horizontal,
            minValue=2, maxValue=20, callback=self.replot, disabled=True)
        gui.checkBox(
            box, self, "hide_bars", "Hide bars", stateWhenDisabled=False,
            callback=self._on_hide_bars_changed,
            disabled=not self.fitted_distribution)

        # Box "Columns": splitting by a discrete variable and related options
        box = gui.vBox(self.controlArea, "Columns")
        gui.comboBox(
            box, self, "cvar", label="Split by", orientation=Qt.Horizontal,
            searchable=True,
            model=DomainModel(placeholder="(None)",
                              valid_types=(DiscreteVariable), ),
            callback=self._on_cvar_changed, contentsLength=18)
        gui.checkBox(
            box, self, "stacked_columns", "Stack columns",
            callback=self.replot)
        gui.checkBox(
            box, self, "show_probs", "Show probabilities",
            callback=self._on_show_probabilities_changed)
        gui.checkBox(
            box, self, "cumulative_distr", "Show cumulative distribution",
            callback=self._on_show_cumulative)

        gui.auto_apply(self.buttonsArea, self, commit=self.apply)

        self._set_smoothing_visibility()
        self._setup_plots()
        self._setup_legend()
374
    def _setup_plots(self):
        """
        Create the plot widget, the main plot and two additional view boxes.

        The additional view boxes (`plot_pdf` above and `plot_mark` below
        the main plot in z-order) share the x axis with the main plot.
        """
        def add_new_plot(zvalue):
            # Add a view box with the given z-value to the scene and link
            # its x axis to the main plot
            plot = pg.ViewBox(enableMouse=False, enableMenu=False)
            self.ploti.scene().addItem(plot)
            pg.AxisItem("right").linkToView(plot)
            plot.setXLink(self.ploti)
            plot.setZValue(zvalue)
            return plot

        self.plotview = DistributionWidget()
        self.plotview.item_clicked.connect(self._on_item_clicked)
        self.plotview.blank_clicked.connect(self._on_blank_clicked)
        self.plotview.mouse_released.connect(self._on_end_selecting)
        self.plotview.setRenderHint(QPainter.Antialiasing)
        box = gui.vBox(self.mainArea, box=True, margin=0)
        box.layout().addWidget(self.plotview)
        self.ploti = pg.PlotItem(
            enableMenu=False, enableMouse=False,
            axisItems={"bottom": ElidedAxisNoUnits("bottom")})
        # self.plot is the view box of the main plot item
        self.plot = self.ploti.vb
        self.plot.setMouseEnabled(False, False)
        self.ploti.hideButtons()
        self.plotview.setCentralItem(self.ploti)

        self.plot_pdf = add_new_plot(10)
        self.plot_mark = add_new_plot(-10)
        self.plot_mark.setYRange(0, 1)
        # keep the additional view boxes aligned with the main one
        self.ploti.vb.sigResized.connect(self.update_views)
        self.update_views()

        # draw the axes in the palette's text color
        pen = QPen(self.palette().color(QPalette.Text))
        self.ploti.getAxis("bottom").setPen(pen)
        left = self.ploti.getAxis("left")
        left.setPen(pen)
        left.setStyle(stopAxisAtTick=(True, True))
410
    def _setup_legend(self):
        """Create a hidden legend parented to `plot_pdf` and anchor it."""
        self._legend = LegendItem()
        self._legend.setParentItem(self.plot_pdf)
        self._legend.hide()
        self._legend.anchor((1, 0), (1, 0))
416
417    # -----------------------------
418    # Event and signal handlers
419
420    def update_views(self):
421        for plot in (self.plot_pdf, self.plot_mark):
422            plot.setGeometry(self.plot.sceneBoundingRect())
423            plot.linkedViewChanged(self.plot, plot.XAxis)
424
425    def onDeleteWidget(self):
426        self.plot.clear()
427        self.plot_pdf.clear()
428        self.plot_mark.clear()
429        super().onDeleteWidget()
430
431    @Inputs.data
432    def set_data(self, data):
433        self.closeContext()
434        self.var = self.cvar = None
435        self.data = data
436        domain = self.data.domain if self.data else None
437        varmodel = self.controls.var.model()
438        cvarmodel = self.controls.cvar.model()
439        varmodel.set_domain(domain)
440        cvarmodel.set_domain(domain)
441        if varmodel:
442            self.var = varmodel[min(len(domain.class_vars), len(varmodel) - 1)]
443        if domain is not None and domain.has_discrete_class:
444            self.cvar = domain.class_var
445        self.reset_select()
446        self._user_var_bins.clear()
447        self.openContext(domain)
448        self.set_valid_data()
449        self.recompute_binnings()
450        self.replot()
451        self.apply()
452
    def _on_var_changed(self):
        """Reset the selection and recompute everything for the new variable."""
        self.reset_select()
        self.set_valid_data()
        self.recompute_binnings()
        self.replot()
        self.apply()
459
    def _on_cvar_changed(self):
        """Recompute valid data, replot and apply after changing the split."""
        self.set_valid_data()
        self.replot()
        self.apply()
464
    def _on_show_cumulative(self):
        """Replot and apply after toggling the cumulative distribution."""
        self.replot()
        self.apply()
468
    def _on_sort_by_freq(self):
        """Replot and apply after toggling sorting by frequency."""
        self.replot()
        self.apply()
472
    def _on_bins_changed(self):
        """Reset the selection, update the slider label and replot."""
        self.reset_select()
        self._set_bin_width_slider_label()
        self.replot()
        # this is triggered when dragging, so don't call apply here;
        # apply is called on sliderReleased
479
    def _on_bin_slider_released(self):
        """Remember the chosen bins for the current variable and apply."""
        self._user_var_bins[self.var] = self.number_of_bins
        self.apply()
483
    def _on_fitted_dist_changed(self):
        """Update the dependent controls and replot."""
        # bars can be hidden only when some distribution is fitted
        self.controls.hide_bars.setDisabled(not self.fitted_distribution)
        self._set_smoothing_visibility()
        self.replot()
488
489    def _on_hide_bars_changed(self):
490        for bar in self.bar_items:  # pylint: disable=blacklisted-name
491            bar.setHidden(self.hide_bars)
492        self._set_curve_brushes()
493        self.plot.update()
494
495    def _set_smoothing_visibility(self):
496        self.smoothing_box.setDisabled(
497            self.Fitters[self.fitted_distribution][1] is not AshCurve)
498
499    def _set_bin_width_slider_label(self):
500        if self.number_of_bins < len(self.binnings):
501            text = reduce(
502                lambda s, rep: s.replace(*rep),
503                short_time_units.items(),
504                self.binnings[self.number_of_bins].width_label)
505        else:
506            text = ""
507        self.bin_width_label.setText(text)
508
509    def _on_show_probabilities_changed(self):
510        label = self.controls.fitted_distribution.label
511        if self.show_probs:
512            label.setText("Fitted probability")
513            label.setToolTip(
514                "Chosen distribution is used to compute Bayesian probabilities")
515        else:
516            label.setText("Fitted distribution")
517            label.setToolTip("")
518        self.replot()
519
    @property
    def is_valid(self):
        """`True` if `set_valid_data` has found data that can be plotted."""
        return self.valid_data is not None
523
524    def set_valid_data(self):
525        err_def_var = self.Error.no_defined_values_var
526        err_def_pair = self.Error.no_defined_values_pair
527        err_def_var.clear()
528        err_def_pair.clear()
529        self.Warning.ignored_nans.clear()
530
531        self.valid_data = self.valid_group_data = None
532        if self.var is None:
533            return
534
535        column = self.data.get_column_view(self.var)[0].astype(float)
536        valid_mask = np.isfinite(column)
537        if not np.any(valid_mask):
538            self.Error.no_defined_values_var(self.var.name)
539            return
540        if self.cvar:
541            ccolumn = self.data.get_column_view(self.cvar)[0].astype(float)
542            valid_mask *= np.isfinite(ccolumn)
543            if not np.any(valid_mask):
544                self.Error.no_defined_values_pair(self.var.name, self.cvar.name)
545                return
546            self.valid_group_data = ccolumn[valid_mask]
547        if not np.all(valid_mask):
548            self.Warning.ignored_nans()
549        self.valid_data = column[valid_mask]
550
551    # -----------------------------
552    # Plotting
553
    def replot(self):
        """Clear the plot and, if the data is valid, draw it anew."""
        self._clear_plot()
        if self.is_valid:
            self._set_axis_names()
            self._update_controls_state()
            self._call_plotting()
            self._display_legend()
        # the selection is shown after the bars are (re)created
        self.show_selection()
562
563    def _clear_plot(self):
564        self.plot.clear()
565        self.plot_pdf.clear()
566        self.plot_mark.clear()
567        self.bar_items = []
568        self.curve_items = []
569        self._legend.clear()
570        self._legend.hide()
571
572    def _set_axis_names(self):
573        assert self.is_valid  # called only from replot, so assumes data is OK
574        bottomaxis = self.ploti.getAxis("bottom")
575        bottomaxis.setLabel(self.var and self.var.name)
576        bottomaxis.setShowUnit(not (self.var and self.var.is_time))
577
578        leftaxis = self.ploti.getAxis("left")
579        if self.show_probs and self.cvar:
580            leftaxis.setLabel(
581                f"Probability of '{self.cvar.name}' at given '{self.var.name}'")
582        else:
583            leftaxis.setLabel("Frequency")
584        leftaxis.resizeEvent()
585
586    def _update_controls_state(self):
587        assert self.is_valid  # called only from replot, so assumes data is OK
588        self.controls.sort_by_freq.setDisabled(self.var.is_continuous)
589        self.continuous_box.setDisabled(self.var.is_discrete)
590        self.controls.show_probs.setDisabled(self.cvar is None)
591        self.controls.stacked_columns.setDisabled(self.cvar is None)
592
593    def _call_plotting(self):
594        assert self.is_valid  # called only from replot, so assumes data is OK
595        self.curve_descriptions = None
596        if self.var.is_discrete:
597            if self.cvar:
598                self._disc_split_plot()
599            else:
600                self._disc_plot()
601        else:
602            if self.cvar:
603                self._cont_split_plot()
604            else:
605                self._cont_plot()
606        self.plot.autoRange()
607
608    def _add_bar(self, x, width, padding, freqs, colors, stacked, expanded,
609                 tooltip, desc, hidden=False):
610        item = DistributionBarItem(
611            x, width, padding, freqs, colors, stacked, expanded, tooltip,
612            desc, hidden)
613        self.plot.addItem(item)
614        self.bar_items.append(item)
615
616    def _disc_plot(self):
617        var = self.var
618        dist = distribution.get_distribution(self.data, self.var)
619        dist = np.array(dist)  # Distribution misbehaves in further operations
620        if self.sort_by_freq:
621            order = np.argsort(dist)[::-1]
622        else:
623            order = np.arange(len(dist))
624
625        ordered_values = np.array(var.values)[order]
626        self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])
627
628        colors = [QColor(0, 128, 255)]
629        for i, freq, desc in zip(count(), dist[order], ordered_values):
630            tooltip = \
631                "<p style='white-space:pre;'>" \
632                f"<b>{escape(desc)}</b>: {int(freq)} " \
633                f"({100 * freq / len(self.valid_data):.2f} %) "
634            self._add_bar(
635                i - 0.5, 1, 0.1, [freq], colors,
636                stacked=False, expanded=False, tooltip=tooltip, desc=desc)
637
638    def _disc_split_plot(self):
639        var = self.var
640        conts = contingency.get_contingency(self.data, self.cvar, self.var)
641        conts = np.array(conts)  # Contingency misbehaves in further operations
642        if self.sort_by_freq:
643            order = np.argsort(conts.sum(axis=1))[::-1]
644        else:
645            order = np.arange(len(conts))
646
647        ordered_values = np.array(var.values)[order]
648        self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])
649
650        gcolors = [QColor(*col) for col in self.cvar.colors]
651        gvalues = self.cvar.values
652        total = len(self.data)
653        for i, freqs, desc in zip(count(), conts[order], ordered_values):
654            self._add_bar(
655                i - 0.5, 1, 0.1, freqs, gcolors,
656                stacked=self.stacked_columns, expanded=self.show_probs,
657                tooltip=self._split_tooltip(
658                    desc, np.sum(freqs), total, gvalues, freqs),
659                desc=desc)
660
661    def _cont_plot(self):
662        self._set_cont_ticks()
663        data = self.valid_data
664        binning = self.binnings[self.number_of_bins]
665        y, x = np.histogram(data, bins=binning.thresholds)
666        total = len(data)
667        colors = [QColor(0, 128, 255)]
668        if self.fitted_distribution:
669            colors[0] = colors[0].lighter(130)
670
671        tot_freq = 0
672        lasti = len(y) - 1
673        width = np.min(x[1:] - x[:-1])
674        unique = self.number_of_bins == 0 and binning.width is None
675        xoff = -width / 2 if unique else 0
676        for i, (x0, x1), freq in zip(count(), zip(x, x[1:]), y):
677            tot_freq += freq
678            desc = self.str_int(x0, x1, not i, i == lasti, unique)
679            tooltip = \
680                "<p style='white-space:pre;'>" \
681                f"<b>{escape(desc)}</b>: " \
682                f"{freq} ({100 * freq / total:.2f} %)</p>"
683            bar_width = width if unique else x1 - x0
684            self._add_bar(
685                x0 + xoff, bar_width, 0,
686                [tot_freq if self.cumulative_distr else freq],
687                colors, stacked=False, expanded=False, tooltip=tooltip,
688                desc=desc, hidden=self.hide_bars)
689
690        if self.fitted_distribution:
691            self._plot_approximations(
692                x[0], x[-1], [self._fit_approximation(data)],
693                [QColor(0, 0, 0)], (1,))
694
695    def _cont_split_plot(self):
696        self._set_cont_ticks()
697        data = self.valid_data
698        binning = self.binnings[self.number_of_bins]
699        _, bins = np.histogram(data, bins=binning.thresholds)
700        gvalues = self.cvar.values
701        varcolors = [QColor(*col) for col in self.cvar.colors]
702        if self.fitted_distribution:
703            gcolors = [c.lighter(130) for c in varcolors]
704        else:
705            gcolors = varcolors
706        nvalues = len(gvalues)
707        ys = []
708        fitters = []
709        prior_sizes = []
710        for val_idx in range(nvalues):
711            group_data = data[self.valid_group_data == val_idx]
712            prior_sizes.append(len(group_data))
713            ys.append(np.histogram(group_data, bins)[0])
714            if self.fitted_distribution:
715                fitters.append(self._fit_approximation(group_data))
716        total = len(data)
717        prior_sizes = np.array(prior_sizes)
718        tot_freqs = np.zeros(len(ys))
719
720        lasti = len(ys[0]) - 1
721        width = np.min(bins[1:] - bins[:-1])
722        unique = self.number_of_bins == 0 and binning.width is None
723        xoff = -width / 2 if unique else 0
724        for i, x0, x1, freqs in zip(count(), bins, bins[1:], zip(*ys)):
725            tot_freqs += freqs
726            plotfreqs = tot_freqs.copy() if self.cumulative_distr else freqs
727            desc = self.str_int(x0, x1, not i, i == lasti, unique)
728            bar_width = width if unique else x1 - x0
729            self._add_bar(
730                x0 + xoff, bar_width, 0 if self.stacked_columns else 0.1,
731                plotfreqs,
732                gcolors, stacked=self.stacked_columns, expanded=self.show_probs,
733                hidden=self.hide_bars,
734                tooltip=self._split_tooltip(
735                    desc, np.sum(plotfreqs), total, gvalues, plotfreqs),
736                desc=desc)
737
738        if fitters:
739            self._plot_approximations(bins[0], bins[-1], fitters, varcolors,
740                                      prior_sizes / len(data))
741
742    def _set_cont_ticks(self):
743        axis = self.ploti.getAxis("bottom")
744        if self.var and self.var.is_time:
745            binning = self.binnings[self.number_of_bins]
746            labels = np.array(binning.short_labels)
747            thresholds = np.array(binning.thresholds)
748            lengths = np.array([len(lab) for lab in labels])
749            slengths = set(lengths)
750            if len(slengths) == 1:
751                ticks = [list(zip(thresholds[::2], labels[::2])),
752                         list(zip(thresholds[1::2], labels[1::2]))]
753            else:
754                ticks = []
755                for length in sorted(slengths, reverse=True):
756                    idxs = lengths == length
757                    ticks.append(list(zip(thresholds[idxs], labels[idxs])))
758            axis.setTicks(ticks)
759        else:
760            axis.setTicks(None)
761
762    def _fit_approximation(self, y):
763        def join_pars(pairs):
764            strv = self.var.str_val
765            return ", ".join(f"{sname}={strv(val)}" for sname, val in pairs)
766
767        def str_params():
768            s = join_pars(
769                (sname, val) for sname, val in zip(str_names, fitted)
770                if sname and sname[0] != "-")
771            par = join_pars(
772                (sname[1:], val) for sname, val in zip(str_names, fitted)
773                if sname and sname[0] == "-")
774            if par:
775                s += f" ({par})"
776            return s
777
778        if not y.size:
779            return None, None
780        _, dist, names, str_names = self.Fitters[self.fitted_distribution]
781        fitted = dist.fit(y)
782        params = dict(zip(names, fitted))
783        return partial(dist.pdf, **params), str_params()
784
785    def _plot_approximations(self, x0, x1, fitters, colors, prior_probs):
786        x = np.linspace(x0, x1, 100)
787        ys = np.zeros((len(fitters), 100))
788        self.curve_descriptions = [s for _, s in fitters]
789        for y, (fitter, _) in zip(ys, fitters):
790            if fitter is None:
791                continue
792            if self.Fitters[self.fitted_distribution][1] is AshCurve:
793                y[:] = fitter(x, sigma=(22 - self.kde_smoothing) / 40)
794            else:
795                y[:] = fitter(x)
796            if self.cumulative_distr:
797                y[:] = np.cumsum(y)
798        tots = np.sum(ys, axis=0)
799
800        show_probs = self.show_probs and self.cvar is not None
801        plot = self.ploti if show_probs else self.plot_pdf
802
803        for y, prior_prob, color in zip(ys, prior_probs, colors):
804            if not prior_prob:
805                continue
806            if show_probs:
807                y_p = y * prior_prob
808                tot = (y_p + (tots - y) * (1 - prior_prob))
809                tot[tot == 0] = 1
810                y = y_p / tot
811            curve = pg.PlotCurveItem(
812                x=x, y=y, fillLevel=0,
813                pen=pg.mkPen(width=5, color=color),
814                shadowPen=pg.mkPen(width=8, color=color.darker(120)))
815            plot.addItem(curve)
816            self.curve_items.append(curve)
817        if not show_probs:
818            self.plot_pdf.autoRange()
819        self._set_curve_brushes()
820
821    def _set_curve_brushes(self):
822        for curve in self.curve_items:
823            if self.hide_bars:
824                color = curve.opts['pen'].color().lighter(160)
825                color.setAlpha(128)
826                curve.setBrush(pg.mkBrush(color))
827            else:
828                curve.setBrush(None)
829
830    @staticmethod
831    def _split_tooltip(valname, tot_group, total, gvalues, freqs):
832        div_group = tot_group or 1
833        cs = "white-space:pre; text-align: right;"
834        s = f"style='{cs} padding-left: 1em'"
835        snp = f"style='{cs}'"
836        return f"<table style='border-collapse: collapse'>" \
837               f"<tr><th {s}>{escape(valname)}:</th>" \
838               f"<td {snp}><b>{int(tot_group)}</b></td>" \
839               "<td/>" \
840               f"<td {s}><b>{100 * tot_group / total:.2f} %</b></td></tr>" + \
841               f"<tr><td/><td/><td {s}>(in group)</td><td {s}>(overall)</td>" \
842               "</tr>" + \
843               "".join(
844                   "<tr>"
845                   f"<th {s}>{value}:</th>"
846                   f"<td {snp}><b>{int(freq)}</b></td>"
847                   f"<td {s}>{100 * freq / div_group:.2f} %</td>"
848                   f"<td {s}>{100 * freq / total:.2f} %</td>"
849                   "</tr>"
850                   for value, freq in zip(gvalues, freqs)) + \
851               "</table>"
852
853    def _display_legend(self):
854        assert self.is_valid  # called only from replot, so assumes data is OK
855        if self.cvar is None:
856            if not self.curve_descriptions or not self.curve_descriptions[0]:
857                self._legend.hide()
858                return
859            self._legend.addItem(
860                pg.PlotCurveItem(pen=pg.mkPen(width=5, color=0.0)),
861                self.curve_descriptions[0])
862        else:
863            cvar_values = self.cvar.values
864            colors = [QColor(*col) for col in self.cvar.colors]
865            descriptions = self.curve_descriptions or repeat(None)
866            for color, name, desc in zip(colors, cvar_values, descriptions):
867                self._legend.addItem(
868                    ScatterPlotItem(pen=color, brush=color, size=10, shape="s"),
869                    escape(name + (f" ({desc})" if desc else "")))
870        self._legend.show()
871
872    # -----------------------------
873    # Bins
874
875    def recompute_binnings(self):
876        if self.is_valid and self.var.is_continuous:
877            # binning is computed on valid var data, ignoring any cvar nans
878            column = self.data.get_column_view(self.var)[0].astype(float)
879            if np.any(np.isfinite(column)):
880                if self.var.is_time:
881                    self.binnings = time_binnings(column, min_unique=5)
882                    self.bin_width_label.setFixedWidth(45)
883                else:
884                    self.binnings = decimal_binnings(
885                        column, min_width=self.min_var_resolution(self.var),
886                        add_unique=10, min_unique=5)
887                    self.bin_width_label.setFixedWidth(35)
888                max_bins = len(self.binnings) - 1
889        else:
890            self.binnings = []
891            max_bins = 0
892
893        self.controls.number_of_bins.setMaximum(max_bins)
894        self.number_of_bins = min(
895            max_bins, self._user_var_bins.get(self.var, self.number_of_bins))
896        self._set_bin_width_slider_label()
897
898    @staticmethod
899    def min_var_resolution(var):
900        # pylint: disable=unidiomatic-typecheck
901        if type(var) is not ContinuousVariable:
902            return 0
903        return 10 ** -var.number_of_decimals
904
905    def str_int(self, x0, x1, first, last, unique=False):
906        var = self.var
907        sx0, sx1 = var.repr_val(x0), var.repr_val(x1)
908        if self.cumulative_distr:
909            return f"{var.name} < {sx1}"
910        elif first and last or unique:
911            return f"{var.name} = {sx0}"
912        elif first:
913            return f"{var.name} < {sx1}"
914        elif last:
915            return f"{var.name} ≥ {sx0}"
916        elif sx0 == sx1 or x1 - x0 <= self.min_var_resolution(var):
917            return f"{var.name} = {sx0}"
918        else:
919            return f"{sx0} ≤ {var.name} < {sx1}"
920
921    # -----------------------------
922    # Selection
923
924    def _on_item_clicked(self, item, modifiers, drag):
925        def add_or_remove(idx, add):
926            self.drag_operation = [self.DragRemove, self.DragAdd][add]
927            if add:
928                self.selection.add(idx)
929            else:
930                if idx in self.selection:
931                    # This can be False when removing with dragging and the
932                    # mouse crosses unselected items
933                    self.selection.remove(idx)
934
935        def add_range(add):
936            if self.last_click_idx is None:
937                add = True
938                idx_range = {idx}
939            else:
940                from_idx, to_idx = sorted((self.last_click_idx, idx))
941                idx_range = set(range(from_idx, to_idx + 1))
942            self.drag_operation = [self.DragRemove, self.DragAdd][add]
943            if add:
944                self.selection |= idx_range
945            else:
946                self.selection -= idx_range
947
948        self.key_operation = None
949        if item is None:
950            self.reset_select()
951            return
952
953        idx = self.bar_items.index(item)
954        if drag:
955            # Dragging has to add a range, otherwise fast dragging skips bars
956            add_range(self.drag_operation == self.DragAdd)
957        else:
958            if modifiers & Qt.ShiftModifier:
959                add_range(self.drag_operation == self.DragAdd)
960            elif modifiers & Qt.ControlModifier:
961                add_or_remove(idx, add=idx not in self.selection)
962            else:
963                if self.selection == {idx}:
964                    # Clicking on a single selected bar  deselects it,
965                    # but dragging from here will select
966                    add_or_remove(idx, add=False)
967                    self.drag_operation = self.DragAdd
968                else:
969                    self.selection.clear()
970                    add_or_remove(idx, add=True)
971        self.last_click_idx = idx
972
973        self.show_selection()
974
    def _on_blank_clicked(self):
        """Reset the selection; delegates to `reset_select`."""
        self.reset_select()
977
978    def reset_select(self):
979        self.selection.clear()
980        self.last_click_idx = None
981        self.drag_operation = None
982        self.key_operation = None
983        self.show_selection()
984
    def _on_end_selecting(self):
        """Send the outputs by calling `apply`."""
        self.apply()
987
988    def show_selection(self):
989        self.plot_mark.clear()
990        if not self.is_valid:  # though if it's not, selection is empty anyway
991            return
992
993        blue = QColor(Qt.blue)
994        pen = QPen(QBrush(blue), 3)
995        pen.setCosmetic(True)
996        brush = QBrush(blue.lighter(190))
997
998        for group in self.grouped_selection():
999            group = list(group)
1000            left_idx, right_idx = group[0], group[-1]
1001            left_pad, right_pad = self._determine_padding(left_idx, right_idx)
1002            x0 = self.bar_items[left_idx].x0 - left_pad
1003            x1 = self.bar_items[right_idx].x1 + right_pad
1004            item = QGraphicsRectItem(x0, 0, x1 - x0, 1)
1005            item.setPen(pen)
1006            item.setBrush(brush)
1007            if self.var.is_continuous:
1008                valname = self.str_int(
1009                    x0, x1, not left_idx, right_idx == len(self.bar_items) - 1)
1010                inside = sum(np.sum(self.bar_items[i].freqs) for i in group)
1011                total = len(self.valid_data)
1012                item.setToolTip(
1013                    "<p style='white-space:pre;'>"
1014                    f"<b>{escape(valname)}</b>: "
1015                    f"{inside} ({100 * inside / total:.2f} %)")
1016            self.plot_mark.addItem(item)
1017
1018    def _determine_padding(self, left_idx, right_idx):
1019        def _padding(i):
1020            return (self.bar_items[i + 1].x0 - self.bar_items[i].x1) / 2
1021
1022        if len(self.bar_items) == 1:
1023            return 6, 6
1024        if left_idx == 0 and right_idx == len(self.bar_items) - 1:
1025            return (_padding(0), ) * 2
1026
1027        if left_idx > 0:
1028            left_pad = _padding(left_idx - 1)
1029        if right_idx < len(self.bar_items) - 1:
1030            right_pad = _padding(right_idx)
1031        else:
1032            right_pad = left_pad
1033        if left_idx == 0:
1034            left_pad = right_pad
1035        return left_pad, right_pad
1036
1037    def grouped_selection(self):
1038        return [[g[1] for g in group]
1039                for _, group in groupby(enumerate(sorted(self.selection)),
1040                                        key=lambda x: x[1] - x[0])]
1041
    def keyPressEvent(self, e):
        """
        Change the selection with the left and right arrow keys.

        Other keys, and any key pressed when the data is not valid or there
        are no bars, are passed to the inherited handler. If the selection
        changes, it is redrawn and the outputs are sent.
        """
        def on_nothing_selected():
            # Left arrow starts at the last bar, right arrow at the first
            if e.key() == Qt.Key_Left:
                self.last_click_idx = len(self.bar_items) - 1
            else:
                self.last_click_idx = 0
            self.selection.add(self.last_click_idx)

        # The two handlers below read `first` and `last`, which are assigned
        # in the enclosing function before the handlers are called
        def on_key_left():
            if e.modifiers() & Qt.ShiftModifier:
                if self.key_operation == Qt.Key_Right and first != last:
                    # selection was being extended to the right: shrink it
                    self.selection.remove(last)
                    self.last_click_idx = last - 1
                elif first:
                    # extend to the left unless already at the first bar
                    self.key_operation = Qt.Key_Left
                    self.selection.add(first - 1)
                    self.last_click_idx = first - 1
            else:
                # without Shift, select only the bar left of the selection
                self.selection.clear()
                self.last_click_idx = max(first - 1, 0)
                self.selection.add(self.last_click_idx)

        def on_key_right():
            if e.modifiers() & Qt.ShiftModifier:
                if self.key_operation == Qt.Key_Left and first != last:
                    # selection was being extended to the left: shrink it
                    self.selection.remove(first)
                    self.last_click_idx = first + 1
                elif not self._is_last_bar(last):
                    # extend to the right unless already at the last bar
                    self.key_operation = Qt.Key_Right
                    self.selection.add(last + 1)
                    self.last_click_idx = last + 1
            else:
                # without Shift, select only the bar right of the selection
                self.selection.clear()
                self.last_click_idx = min(last + 1, len(self.bar_items) - 1)
                self.selection.add(self.last_click_idx)

        if not self.is_valid or not self.bar_items \
                or e.key() not in (Qt.Key_Left, Qt.Key_Right):
            super().keyPressEvent(e)
            return

        # A copy is needed because handlers change the selection in place
        prev_selection = self.selection.copy()
        if not self.selection:
            on_nothing_selected()
        else:
            first, last = min(self.selection), max(self.selection)
            if e.key() == Qt.Key_Left:
                on_key_left()
            else:
                on_key_right()

        if self.selection != prev_selection:
            self.drag_operation = self.DragAdd
            self.show_selection()
            self.apply()
1097
    def keyReleaseEvent(self, ev):
        """
        Forget the direction of keyboard selection when Shift is released.

        The event is then passed to the inherited handler.
        """
        if ev.key() == Qt.Key_Shift:
            self.key_operation = None
        super().keyReleaseEvent(ev)
1102
1103
1104    # -----------------------------
1105    # Output
1106
1107    def apply(self):
1108        data = self.data
1109        selected_data = annotated_data = histogram_data = None
1110        if self.is_valid:
1111            if self.var.is_discrete:
1112                group_indices, values = self._get_output_indices_disc()
1113            else:
1114                group_indices, values = self._get_output_indices_cont()
1115            selected = np.nonzero(group_indices)[0]
1116            if selected.size:
1117                selected_data = create_groups_table(
1118                    data, group_indices,
1119                    include_unselected=False, values=values)
1120            annotated_data = create_annotated_table(data, selected)
1121            if self.var.is_continuous:  # annotate with bins
1122                hist_indices, hist_values = self._get_histogram_indices()
1123                annotated_data = create_groups_table(
1124                    annotated_data, hist_indices, var_name="Bin", values=hist_values)
1125            histogram_data = self._get_histogram_table()
1126
1127        self.Outputs.selected_data.send(selected_data)
1128        self.Outputs.annotated_data.send(annotated_data)
1129        self.Outputs.histogram_data.send(histogram_data)
1130
1131    def _get_output_indices_disc(self):
1132        group_indices = np.zeros(len(self.data), dtype=np.int32)
1133        col = self.data.get_column_view(self.var)[0].astype(float)
1134        for group_idx, val_idx in enumerate(self.selection, start=1):
1135            group_indices[col == val_idx] = group_idx
1136        values = [self.var.values[i] for i in self.selection]
1137        return group_indices, values
1138
1139    def _get_output_indices_cont(self):
1140        group_indices = np.zeros(len(self.data), dtype=np.int32)
1141        col = self.data.get_column_view(self.var)[0].astype(float)
1142        values = []
1143        for group_idx, group in enumerate(self.grouped_selection(), start=1):
1144            x0 = x1 = None
1145            for bar_idx in group:
1146                minx, maxx, mask = self._get_cont_baritem_indices(col, bar_idx)
1147                if x0 is None:
1148                    x0 = minx
1149                x1 = maxx
1150                group_indices[mask] = group_idx
1151            # pylint: disable=undefined-loop-variable
1152            values.append(
1153                self.str_int(x0, x1, not bar_idx, self._is_last_bar(bar_idx)))
1154        return group_indices, values
1155
1156    def _get_histogram_table(self):
1157        var_bin = DiscreteVariable("Bin", [bar.desc for bar in self.bar_items])
1158        var_freq = ContinuousVariable("Count")
1159        X = []
1160        if self.cvar:
1161            domain = Domain([var_bin, self.cvar, var_freq])
1162            for i, bar in enumerate(self.bar_items):
1163                for j, freq in enumerate(bar.freqs):
1164                    X.append([i, j, freq])
1165        else:
1166            domain = Domain([var_bin, var_freq])
1167            for i, bar in enumerate(self.bar_items):
1168                X.append([i, bar.freqs[0]])
1169        return Table.from_numpy(domain, X)
1170
1171    def _get_histogram_indices(self):
1172        group_indices = np.zeros(len(self.data), dtype=np.int32)
1173        col = self.data.get_column_view(self.var)[0].astype(float)
1174        values = []
1175        for bar_idx in range(len(self.bar_items)):
1176            x0, x1, mask = self._get_cont_baritem_indices(col, bar_idx)
1177            group_indices[mask] = bar_idx + 1
1178            values.append(
1179                self.str_int(x0, x1, not bar_idx, self._is_last_bar(bar_idx)))
1180        return group_indices, values
1181
1182    def _get_cont_baritem_indices(self, col, bar_idx):
1183        bar_item = self.bar_items[bar_idx]
1184        minx = bar_item.x0
1185        maxx = bar_item.x1 + (bar_idx == len(self.bar_items) - 1)
1186        with np.errstate(invalid="ignore"):
1187            return minx, maxx, (col >= minx) * (col < maxx)
1188
1189    def _is_last_bar(self, idx):
1190        return idx == len(self.bar_items) - 1
1191
1192    # -----------------------------
1193    # Report
1194
    def get_widget_name_extension(self):
        """Return the currently shown variable (`self.var`)."""
        return self.var
1197
1198    def send_report(self):
1199        self.plotview.scene().setSceneRect(self.plotview.sceneRect())
1200        if not self.is_valid:
1201            return
1202        self.report_plot()
1203        if self.cumulative_distr:
1204            text = f"Cummulative distribution of '{self.var.name}'"
1205        else:
1206            text = f"Distribution of '{self.var.name}'"
1207        if self.cvar:
1208            text += f" with columns split by '{self.cvar.name}'"
1209        self.report_caption(text)
1210
1211
if __name__ == "__main__":  # pragma: no cover
    # Preview the widget on the table loaded from "heart_disease.tab"
    WidgetPreview(OWDistributions).run(Table("heart_disease.tab"))
1214