1# pylint: disable=too-many-lines
2from collections import namedtuple
3from itertools import chain, count
4from typing import List, Optional, Tuple, Set, Sequence
5
6import numpy as np
7from scipy import stats
8from sklearn.neighbors import KernelDensity
9
10from AnyQt.QtCore import QItemSelection, QPointF, QRectF, QSize, Qt, Signal
11from AnyQt.QtGui import QBrush, QColor, QPainter, QPainterPath, QPolygonF
12from AnyQt.QtWidgets import QCheckBox, QSizePolicy, QGraphicsRectItem, \
13    QGraphicsSceneMouseEvent, QApplication, QWidget, QComboBox
14
15import pyqtgraph as pg
16
17from orangewidget.utils.listview import ListViewSearch
18from orangewidget.utils.visual_settings_dlg import KeyType, ValueType, \
19    VisualSettingsDialog
20
21from Orange.data import ContinuousVariable, DiscreteVariable, Table
22from Orange.widgets import gui
23from Orange.widgets.settings import ContextSetting, DomainContextHandler, \
24    Setting
25from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_SIGNAL_NAME, \
26    create_annotated_table
27from Orange.widgets.utils.itemmodels import VariableListModel
28from Orange.widgets.utils.sql import check_sql_input
29from Orange.widgets.visualize.owboxplot import SortProxyModel
30from Orange.widgets.visualize.utils.customizableplot import \
31    CommonParameterSetter, Updater
32from Orange.widgets.visualize.utils.plotutils import AxisItem
33from Orange.widgets.widget import OWWidget, Input, Output, Msg
34
# Scaling types of violins, interpreted by ``scale_density``:
#   AREA  - the estimated density is used unchanged,
#   COUNT - the density is divided by its maximum and multiplied by the
#           number of instances,
#   WIDTH - the density is divided by its maximum.
AREA, COUNT, WIDTH = range(3)
37
38
class ViolinPlotViewBox(pg.ViewBox):
    """View box that translates mouse gestures into selection signals.

    ``sigSelectionChanged`` carries the drag start point and the current
    point (both mapped to view coordinates) and a flag telling whether the
    drag has finished. ``sigDeselect`` is emitted with ``False`` on a mouse
    press and with ``True`` on a mouse click.
    """
    sigSelectionChanged = Signal(QPointF, QPointF, bool)
    sigDeselect = Signal(bool)

    def __init__(self, _):
        super().__init__()
        self.setMouseMode(self.RectMode)

    def mouseDragEvent(self, ev, axis=None):
        """Report left-button drags; drags that carry an axis are ignored."""
        if axis is not None:
            ev.ignore()
            return

        ev.accept()
        if ev.button() != Qt.LeftButton:
            return

        start = ev.buttonDownPos()
        current = ev.pos()
        self.sigSelectionChanged.emit(
            self.mapToView(start), self.mapToView(current), ev.isFinish()
        )

    def mousePressEvent(self, ev: QGraphicsSceneMouseEvent):
        """Announce a non-final deselection, then do the default handling."""
        self.sigDeselect.emit(False)
        super().mousePressEvent(ev)

    def mouseClickEvent(self, ev):
        """Accept the click and announce a final deselection."""
        ev.accept()
        self.sigDeselect.emit(True)
65
66
class ParameterSetter(CommonParameterSetter):
    """Visual-settings setter for the violin plot.

    Besides applying the settings to the plot, the setter remembers the
    last axis-title settings, axis-tick settings and the tick-rotation flag
    in ``titles_settings``, ``ticks_settings`` and ``is_vertical_setting``.
    """
    BOTTOM_AXIS_LABEL = "Bottom axis"
    IS_VERTICAL_LABEL = "Vertical tick text"

    def __init__(self, master):
        # The attributes are assigned before the base initializer is called.
        self.master: ViolinPlot = master
        self.is_vertical_setting = False
        self.ticks_settings = {}
        self.titles_settings = {}
        super().__init__()

    def update_setters(self):
        """Define the initial settings and register the setter callbacks."""
        self.initial_settings = {
            self.LABELS_BOX: {
                self.FONT_FAMILY_LABEL: self.FONT_FAMILY_SETTING,
                self.TITLE_LABEL: self.FONT_SETTING,
                self.AXIS_TITLE_LABEL: self.FONT_SETTING,
                self.AXIS_TICKS_LABEL: self.FONT_SETTING,
            },
            self.ANNOT_BOX: {
                self.TITLE_LABEL: {self.TITLE_LABEL: ("", "")},
            },
            self.PLOT_BOX: {
                self.BOTTOM_AXIS_LABEL: {
                    self.IS_VERTICAL_LABEL: (None, self.is_vertical_setting),
                },
            },
        }

        def set_axis_titles_font(**settings):
            # store the settings, then apply them to all axes
            self.titles_settings.update(**settings)
            Updater.update_axes_titles_font(self.axis_items, **settings)

        def set_axis_ticks_font(**settings):
            # store the settings, then apply them to all axes
            self.ticks_settings.update(**settings)
            Updater.update_axes_ticks_font(self.axis_items, **settings)

        def set_bottom_axis_rotation(**settings):
            vertical = settings[self.IS_VERTICAL_LABEL]
            self.is_vertical_setting = vertical
            self.bottom_axis.setRotateTicks(vertical)

        labels_setters = self._setters[self.LABELS_BOX]
        labels_setters[self.AXIS_TITLE_LABEL] = set_axis_titles_font
        labels_setters[self.AXIS_TICKS_LABEL] = set_axis_ticks_font
        self._setters[self.PLOT_BOX] = {
            self.BOTTOM_AXIS_LABEL: set_bottom_axis_rotation,
        }

    @property
    def title_item(self) -> pg.LabelItem:
        """Title label of the master's plot item."""
        plot_item = self.master.getPlotItem()
        return plot_item.titleLabel

    @property
    def axis_items(self) -> List[AxisItem]:
        """Axis items of the master's plot item."""
        axes = self.master.getPlotItem().axes
        return [axis["item"] for axis in axes.values()]

    @property
    def bottom_axis(self) -> AxisItem:
        """Bottom axis of the master plot."""
        return self.master.getAxis("bottom")
125
126
def fit_kernel(data: np.ndarray, kernel: str) -> \
        Tuple[Optional[KernelDensity], float]:
    """Fit a kernel density estimator with the given kernel to ``data``.

    The bandwidth is taken from ``scipy.stats.gaussian_kde`` (its factor
    multiplied by the sample standard deviation); if that raises
    ``LinAlgError``, bandwidth 1 is used instead.

    Returns a tuple ``(estimator, bandwidth)``. When ``data`` has fewer
    than two distinct values, the result is ``(None, 1)``.
    """
    assert np.all(np.isfinite(data))

    # nothing to estimate from less than two distinct values
    if np.unique(data).size <= 1:
        return None, 1

    try:
        bandwidth = stats.gaussian_kde(data).factor * data.std(ddof=1)
    except np.linalg.LinAlgError:
        bandwidth = 1

    estimator = KernelDensity(kernel=kernel, bandwidth=bandwidth)
    estimator.fit(data.reshape(-1, 1))
    return estimator, bandwidth
145
146
def scale_density(scale_type: int, density: np.ndarray, n_data: int,
                  max_density: float) -> np.ndarray:
    """Scale ``density`` according to ``scale_type``.

    ``AREA`` returns the density unchanged, ``WIDTH`` divides it by
    ``max_density`` and ``COUNT`` additionally multiplies it by ``n_data``.

    Raises ``NotImplementedError`` for any other scale type.
    """
    if scale_type == AREA:
        return density
    if scale_type == WIDTH:
        return density / max_density
    if scale_type == COUNT:
        return density * n_data / max_density
    raise NotImplementedError
157
158
class ViolinItem(pg.GraphicsObject):
    """Graphics item that paints a single violin and, optionally, its rug.

    The violin is a closed polygon mirrored around the zero line of the
    density axis; the rug is a set of lines, one per distinct data value,
    each as long as the violin is wide at that value.
    """
    RugPlot = namedtuple("RugPlot", "support, density")

    def __init__(self, data: np.ndarray, color: QColor, kernel: str,
                 scale: int, show_rug: bool, orientation: Qt.Orientations):
        self.__scale = scale
        self.__show_rug_plot = show_rug
        self.__orientation = orientation

        # estimator is None when data has fewer than two distinct values
        self.__kde, self.__bandwidth = fit_kernel(data, kernel)

        self.__violin_path, peak = self._create_violin(data)
        self.__violin_brush = QBrush(color)
        self.__rug_plot_data = self._create_rug_plot(data, peak)

        super().__init__()

    @property
    def density(self) -> np.ndarray:
        """Scaled density computed for the rug plot."""
        return self.__rug_plot_data.density

    @property
    def violin_width(self) -> float:
        """Extent of the violin along the density axis; 1 if it is zero."""
        rect = self.boundingRect()
        if self.__orientation == Qt.Vertical:
            width = rect.width()
        else:
            width = rect.height()
        return width or 1

    def set_show_rug_plot(self, show: bool):
        """Show or hide the rug plot and schedule a repaint."""
        self.__show_rug_plot = show
        self.update()

    def boundingRect(self) -> QRectF:
        """Bounding rectangle of the violin path."""
        return self.__violin_path.boundingRect()

    def paint(self, painter: QPainter, *_):
        """Draw the violin and, when enabled, the rug lines."""
        painter.save()
        painter.setPen(pg.mkPen(QColor(Qt.black)))
        painter.setBrush(self.__violin_brush)
        painter.drawPath(self.__violin_path)

        if self.__show_rug_plot:
            painter.setPen(pg.mkPen(QColor(Qt.black), width=1))
            is_vertical = self.__orientation == Qt.Vertical
            support, density = self.__rug_plot_data
            for half_len, value in zip(density, support):
                if is_vertical:
                    start = QPointF(-half_len, value)
                    end = QPointF(half_len, value)
                else:
                    start = QPointF(value, -half_len)
                    end = QPointF(value, half_len)
                painter.drawLine(start, end)

        painter.restore()

    def _create_violin(self, data: np.ndarray) -> Tuple[QPainterPath, float]:
        """Return the violin outline and the maximal (unscaled) density.

        Without an estimator the outline collapses to a single point at
        the origin and the maximal density is 0.
        """
        if self.__kde is None:
            x, p, max_density = np.zeros(1), np.zeros(1), 0
        else:
            # evaluate the density two bandwidths beyond the data range
            margin = self.__bandwidth * 2
            x = np.linspace(data.min() - margin, data.max() + margin, 1000)
            p = np.exp(self.__kde.score_samples(x.reshape(-1, 1)))
            max_density = p.max()
            p = scale_density(self.__scale, p, len(data), max_density)

        pairs = list(zip(x, p))
        if self.__orientation == Qt.Vertical:
            outline = [QPointF(pi, xi) for xi, pi in pairs]
            outline.extend(QPointF(-pi, xi) for xi, pi in reversed(pairs))
        else:
            outline = [QPointF(xi, pi) for xi, pi in pairs]
            outline.extend(QPointF(xi, -pi) for xi, pi in reversed(pairs))
        # close the outline by returning to its first point
        outline.append(outline[0])

        path = QPainterPath()
        path.addPolygon(QPolygonF(outline))
        return path, max_density

    def _create_rug_plot(self, data: np.ndarray, max_density: float) -> Tuple:
        """Return rug-plot values together with their scaled densities.

        Without an estimator, the data is returned as is, with zero
        densities; otherwise densities are computed on distinct values.
        """
        if self.__kde is None:
            return self.RugPlot(data, np.zeros(data.size))

        n_instances = len(data)
        support = np.unique(data)  # score every distinct value only once
        scores = self.__kde.score_samples(support.reshape(-1, 1))
        density = scale_density(self.__scale, np.exp(scores),
                                n_instances, max_density)
        return self.RugPlot(support, density)
249
250
class BoxItem(pg.GraphicsObject):
    """Graphics item that paints a box plot as two lines.

    A thin line spans the whiskers and a thick line spans the range
    between the first and the third quartile.
    """

    def __init__(self, data: np.ndarray, rect: QRectF,
                 orientation: Qt.Orientations):
        self.__bounding_rect = rect
        self.__orientation = orientation
        # (lower whisker, first quartile, third quartile, upper whisker)
        self.__box_plot_data: Tuple = self._create_box_plot(data)

        super().__init__()

    def boundingRect(self) -> QRectF:
        """Rectangle given to the constructor."""
        return self.__bounding_rect

    def paint(self, painter: QPainter, _, widget: Optional[QWidget]):
        """Draw the whisker line and, over it, the quartile line."""
        painter.save()

        low, q25, q75, high = self.__box_plot_data
        if self.__orientation == Qt.Vertical:
            whiskers = QPointF(0, low), QPointF(0, high)
            quartiles = QPointF(0, q25), QPointF(0, q75)
        else:
            whiskers = QPointF(low, 0), QPointF(high, 0)
            quartiles = QPointF(q25, 0), QPointF(q75, 0)

        # line widths are multiplied by the widget's device pixel ratio
        if widget is None:
            factor = 1
        else:
            factor = widget.devicePixelRatio()

        for (start, end), width in ((whiskers, 2), (quartiles, 6)):
            painter.setPen(pg.mkPen(QColor(Qt.black), width=width * factor))
            painter.drawLine(start, end)

        painter.restore()

    @staticmethod
    def _create_box_plot(data: np.ndarray) -> Tuple:
        """Return (lower whisker, Q1, Q3, upper whisker) for ``data``.

        Whiskers are the extreme data values within 1.5 times the
        interquartile range from the quartiles. Empty data gives zeros.
        """
        if data.size == 0:
            return (0,) * 4

        q25, q75 = np.percentile(data, [25, 75])
        reach = 1.5 * stats.iqr(data)
        lower = np.min(data[data >= (q25 - reach)])
        upper = np.max(data[data <= (q75 + reach)])
        return lower, q25, q75, upper
293
294
class MedianItem(pg.ScatterPlotItem):
    """Single white point marking the median of the data (0 if empty)."""

    def __init__(self, data: np.ndarray, orientation: Qt.Orientations):
        median = np.median(data) if data.size != 0 else 0
        self.__value = median
        if orientation == Qt.Vertical:
            x, y = 0, median
        else:
            x, y = median, 0
        super().__init__(x=[x], y=[y], size=5,
                         pen=pg.mkPen(QColor(Qt.white)),
                         brush=pg.mkBrush(QColor(Qt.white)))

    @property
    def value(self) -> float:
        """The median shown by this item."""
        return self.__value

    def setX(self, x: float):
        """Place the point at ``x``; the median is its y coordinate."""
        self.setData(x=[x], y=[self.value])

    def setY(self, y: float):
        """Place the point at ``y``; the median is its x coordinate."""
        self.setData(x=[self.value], y=[y])
312
313
class StripItem(pg.ScatterPlotItem):
    """Scatter of data points jittered within the violin outline.

    ``density`` holds one value per distinct data value; each point gets a
    uniformly random offset (fixed seed 0) between minus and plus the
    density belonging to its value.
    """

    def __init__(self, data: np.ndarray, density: np.ndarray,
                 color: QColor, orientation: Qt.Orientations):
        # map every instance to the density of its (distinct) value
        _, inverse = np.unique(data, return_inverse=True)
        limits = density[inverse]
        jitter = np.random.RandomState(0).uniform(-limits, limits)
        self.__xdata = jitter
        self.__ydata = data

        if orientation == Qt.Vertical:
            x, y = jitter, data
        else:
            x, y = data, jitter
        brush = pg.mkBrush(color.lighter(150))
        super().__init__(x=x, y=y, size=5, brush=brush)

    def setX(self, x: float):
        """Shift the jittered coordinates by ``x`` and use them as x."""
        self.setData(x=self.__xdata + x, y=self.__ydata)

    def setY(self, y: float):
        """Shift the jittered coordinates by ``y`` and use them as y."""
        self.setData(x=self.__ydata, y=self.__xdata + y)
330
331
class SelectionRect(pg.GraphicsObject):
    """Semi-transparent rectangle marking the selected range of values."""

    def __init__(self, rect: QRectF, orientation: Qt.Orientations):
        self.__rect: QRectF = rect
        self.__orientation: Qt.Orientations = orientation
        # (min, max) of the selected values; None when nothing is selected
        self.__selection_range: Optional[Tuple[float, float]] = None
        super().__init__()

    @property
    def selection_range(self) -> Optional[Tuple[float, float]]:
        """Selected (min, max) range or None."""
        return self.__selection_range

    @selection_range.setter
    def selection_range(self, selection_range: Optional[Tuple[float, float]]):
        self.__selection_range = selection_range
        self.update()

    @property
    def selection_rect(self) -> QRectF:
        """Stored rectangle, stretched to the selected range (if any).

        NOTE: the stored rectangle itself is modified, hence the result of
        ``boundingRect`` changes as well.
        """
        rect = self.__rect
        selected = self.__selection_range
        if selected is None:
            return rect

        lower, upper = selected[0], selected[1]
        if self.__orientation == Qt.Vertical:
            rect.setTop(lower)
            rect.setBottom(upper)
        else:
            rect.setLeft(lower)
            rect.setRight(upper)
        return rect

    def boundingRect(self) -> QRectF:
        """The stored rectangle."""
        return self.__rect

    def paint(self, painter: QPainter, *_):
        """Draw the selection rectangle if a range is selected."""
        painter.save()
        painter.setPen(pg.mkPen((255, 255, 100), width=1))
        painter.setBrush(pg.mkBrush(255, 255, 0, 100))
        if self.__selection_range is not None:
            painter.drawRect(self.selection_rect)
        painter.restore()
370
371
class ViolinPlot(pg.PlotWidget):
    """Plot widget that draws one violin per group of values.

    Each violin comes with a box item, a median marker, a strip item and a
    selection rectangle. The per-group item lists are parallel: position
    ``i`` of every list belongs to the group with index ``i``.

    ``selection_changed`` is emitted with the sorted indices of selected
    instances and with the list of selection ranges (both lists are empty
    when the selection is removed).
    """
    VIOLIN_PADDING_FACTOR = 1.25
    SELECTION_PADDING_FACTOR = 1.20
    selection_changed = Signal(list, list)

    def __init__(self, parent: OWWidget, kernel: str, scale: int,
                 orientation: Qt.Orientations, show_box_plot: bool,
                 show_strip_plot: bool, show_rug_plot: bool, sort_items: bool):

        # data
        self.__values: Optional[np.ndarray] = None
        self.__value_var: Optional[ContinuousVariable] = None
        self.__group_values: Optional[np.ndarray] = None
        self.__group_var: Optional[DiscreteVariable] = None

        # settings
        self.__kernel = kernel
        self.__scale = scale
        self.__orientation = orientation
        self.__show_box_plot = show_box_plot
        self.__show_strip_plot = show_strip_plot
        self.__show_rug_plot = show_rug_plot
        self.__sort_items = sort_items

        # items
        self.__violin_items: List[ViolinItem] = []
        self.__box_items: List[BoxItem] = []
        self.__median_items: List[MedianItem] = []
        self.__strip_items: List[pg.ScatterPlotItem] = []

        # selection
        self.__selection: Set[int] = set()
        self.__selection_rects: List[SelectionRect] = []

        view_box = ViolinPlotViewBox(self)
        super().__init__(parent, viewBox=view_box,
                         background="w", enableMenu=False,
                         axisItems={"bottom": AxisItem("bottom"),
                                    "left": AxisItem("left")})
        self.setAntialiasing(True)
        self.hideButtons()
        self.getPlotItem().setContentsMargins(10, 10, 10, 10)
        self.setMouseEnabled(False, False)
        view_box.sigSelectionChanged.connect(self._update_selection)
        view_box.sigDeselect.connect(self._deselect)

        self.parameter_setter = ParameterSetter(self)

    @property
    def _selection_ranges(self) -> List[Optional[Tuple[float, float]]]:
        """Selection ranges of all selection rectangles, in group order."""
        return [rect.selection_range for rect in self.__selection_rects]

    @_selection_ranges.setter
    def _selection_ranges(self, ranges: List[Optional[Tuple[float, float]]]):
        # zip stops at the shorter sequence; surplus ranges are ignored
        for min_max, sel_rect in zip(ranges, self.__selection_rects):
            sel_rect.selection_range = min_max

    @property
    def _sorted_group_indices(self) -> Sequence[int]:
        """Group indices in display order (by median when sorting is on)."""
        medians = [item.value for item in self.__median_items]
        return np.argsort(medians) if self.__sort_items \
            else range(len(medians))

    @property
    def _max_item_width(self) -> float:
        """Width of the widest violin, padded; 0 when there are none."""
        if not self.__violin_items:
            return 0
        return max(item.violin_width * self.VIOLIN_PADDING_FACTOR
                   for item in self.__violin_items)

    def set_data(self, values: np.ndarray, value_var: ContinuousVariable,
                 group_values: Optional[np.ndarray],
                 group_var: Optional[DiscreteVariable]):
        """Store the data, label the axes and draw the plot."""
        self.__values = values
        self.__value_var = value_var
        self.__group_values = group_values
        self.__group_var = group_var
        self._set_axes()
        self._plot_data()

    def set_kernel(self, kernel: str):
        """Change the kernel and redraw if it differs from the current."""
        if self.__kernel != kernel:
            self.__kernel = kernel
            self._plot_data()

    def set_scale(self, scale: int):
        """Change the scaling type and redraw if it differs."""
        if self.__scale != scale:
            self.__scale = scale
            self._plot_data()

    def set_orientation(self, orientation: Qt.Orientations):
        """Change the orientation; axes are recreated and data redrawn."""
        if self.__orientation != orientation:
            self.__orientation = orientation
            self._clear_axes()
            self._set_axes()
            self._plot_data()

    def set_show_box_plot(self, show: bool):
        """Show or hide box items together with median markers."""
        if self.__show_box_plot != show:
            self.__show_box_plot = show
            for item in self.__box_items:
                item.setVisible(show)
            for item in self.__median_items:
                item.setVisible(show)

    def set_show_strip_plot(self, show: bool):
        """Show or hide strip items."""
        if self.__show_strip_plot != show:
            self.__show_strip_plot = show
            for item in self.__strip_items:
                item.setVisible(show)

    def set_show_rug_plot(self, show: bool):
        """Show or hide rug plots of all violins."""
        if self.__show_rug_plot != show:
            self.__show_rug_plot = show
            for item in self.__violin_items:
                item.set_show_rug_plot(show)

    def set_sort_items(self, sort_items: bool):
        """Switch sorting of groups; reorder items when data is grouped."""
        if self.__sort_items != sort_items:
            self.__sort_items = sort_items
            if self.__group_var is not None:
                self.order_items()

    def order_items(self):
        """Position the items of every group and label the group axis.

        Groups are placed one item width apart, in the order given by
        ``_sorted_group_indices``; they run along x in vertical and along
        negative y in horizontal orientation.
        """
        assert self.__group_var is not None

        indices = self._sorted_group_indices
        vertical = self.__orientation == Qt.Vertical
        # The width depends on violin paths only, not on item positions,
        # so it is computed once instead of once per use.
        item_width = self._max_item_width

        for i, index in enumerate(indices):
            violin: ViolinItem = self.__violin_items[index]
            box: BoxItem = self.__box_items[index]
            median: MedianItem = self.__median_items[index]
            strip: StripItem = self.__strip_items[index]
            sel_rect: SelectionRect = self.__selection_rects[index]

            if vertical:
                x = i * item_width
                violin.setX(x)
                box.setX(x)
                median.setX(x)
                strip.setX(x)
                sel_rect.setX(x)
            else:
                y = - i * item_width
                violin.setY(y)
                box.setY(y)
                median.setY(y)
                strip.setY(y)
                sel_rect.setY(y)

        sign = 1 if vertical else -1
        side = "bottom" if vertical else "left"
        ticks = [[(i * item_width * sign, self.__group_var.values[index])
                  for i, index in enumerate(indices)]]
        self.getAxis(side).setTicks(ticks)

    def set_selection(self, ranges: List[Optional[Tuple[float, float]]]):
        """Select instances within the given per-group ranges.

        ``ranges`` holds a ``(min, max)`` tuple or ``None`` for each group.
        Does nothing without data; otherwise ``selection_changed`` is
        emitted.
        """
        if self.__values is None:
            return

        self._selection_ranges = ranges

        self.__selection = set()
        for index, min_max in enumerate(ranges):
            if min_max is None:
                continue
            mask = np.bitwise_and(self.__values >= min_max[0],
                                  self.__values <= min_max[1])
            if self.__group_values is not None:
                mask = np.bitwise_and(mask, self.__group_values == index)
            self.__selection |= set(np.flatnonzero(mask))

        self.selection_changed.emit(sorted(self.__selection),
                                    self._selection_ranges)

    def _set_axes(self):
        """Set axis titles; remove group-axis ticks if data is ungrouped."""
        if self.__value_var is None:
            return
        value_title = self.__value_var.name
        group_title = self.__group_var.name if self.__group_var else ""
        vertical = self.__orientation == Qt.Vertical
        self.getAxis("left" if vertical else "bottom").setLabel(value_title)
        self.getAxis("bottom" if vertical else "left").setLabel(group_title)

        if self.__group_var is None:
            self.getAxis("bottom" if vertical else "left").setTicks([])

    def _plot_data(self):
        """Recreate all items from the stored data, keeping the ranges."""
        # save selection ranges
        ranges = self._selection_ranges

        self._clear_data_items()
        if self.__values is None:
            return

        if not self.__group_var:
            self._set_violin_item(self.__values, QColor(Qt.lightGray))
        else:
            assert self.__group_values is not None
            for index in range(len(self.__group_var.values)):
                mask = self.__group_values == index
                color = QColor(*self.__group_var.colors[index])
                self._set_violin_item(self.__values[mask], color)

            self.order_items()

        # apply selection ranges
        self._selection_ranges = ranges

    def _set_violin_item(self, values: np.ndarray, color: QColor):
        """Create and add the items for one group; NaN values are dropped."""
        values = values[~np.isnan(values)]

        violin = ViolinItem(values, color, self.__kernel, self.__scale,
                            self.__show_rug_plot, self.__orientation)
        self.addItem(violin)
        self.__violin_items.append(violin)

        box = BoxItem(values, violin.boundingRect(), self.__orientation)
        box.setVisible(self.__show_box_plot)
        self.addItem(box)
        self.__box_items.append(box)

        median = MedianItem(values, self.__orientation)
        median.setVisible(self.__show_box_plot)
        self.addItem(median)
        self.__median_items.append(median)

        strip = StripItem(values, violin.density, color, self.__orientation)
        strip.setVisible(self.__show_strip_plot)
        self.addItem(strip)
        self.__strip_items.append(strip)

        # the selection rectangle starts as a zero-length line at the
        # median; its width is based on the violins created so far
        width = self._max_item_width * self.SELECTION_PADDING_FACTOR / \
                self.VIOLIN_PADDING_FACTOR
        if self.__orientation == Qt.Vertical:
            rect = QRectF(-width / 2, median.value, width, 0)
        else:
            rect = QRectF(median.value, -width / 2, 0, width)
        sel_rect = SelectionRect(rect, self.__orientation)
        self.addItem(sel_rect)
        self.__selection_rects.append(sel_rect)

    def clear_plot(self):
        """Remove items, data and selection, and recreate the axes."""
        self.clear()
        self._clear_data()
        self._clear_data_items()
        self._clear_axes()
        self._clear_selection()

    def _clear_data(self):
        self.__values = None
        self.__value_var = None
        self.__group_values = None
        self.__group_var = None

    def _clear_data_items(self):
        """Remove the items of all groups from the plot and forget them."""
        per_group_items = zip(self.__violin_items, self.__box_items,
                              self.__median_items, self.__strip_items,
                              self.__selection_rects)
        for group_items in per_group_items:
            for item in group_items:
                self.removeItem(item)
        self.__violin_items.clear()
        self.__box_items.clear()
        self.__median_items.clear()
        self.__strip_items.clear()
        self.__selection_rects.clear()

    def _clear_axes(self):
        """Replace both axes with new ones and re-apply visual settings."""
        self.setAxisItems({"bottom": AxisItem(orientation="bottom"),
                           "left": AxisItem(orientation="left")})
        Updater.update_axes_titles_font(
            self.parameter_setter.axis_items,
            **self.parameter_setter.titles_settings
        )
        Updater.update_axes_ticks_font(
            self.parameter_setter.axis_items,
            **self.parameter_setter.ticks_settings
        )
        self.getAxis("bottom").setRotateTicks(
            self.parameter_setter.is_vertical_setting
        )

    def _clear_selection(self):
        self.__selection = set()

    def _update_selection(self, p1: QPointF, p2: QPointF, finished: bool):
        """Update the selection range of the violin where the drag began.

        ``p1`` is the drag start and ``p2`` the current position, in view
        coordinates. The instance selection is recomputed, and
        ``selection_changed`` emitted if it changed, only when ``finished``
        is true. With Shift pressed, selections of other groups are kept.
        """
        if not self.__selection_rects:
            return
        item_width = self._max_item_width
        assert item_width > 0

        rect = QRectF(p1, p2).normalized()
        if self.__orientation == Qt.Vertical:
            min_max = rect.y(), rect.y() + rect.height()
            position = p1.x()
        else:
            min_max = rect.x(), rect.x() + rect.width()
            position = -p1.y()
        index = int((position + item_width / 2) / item_width)

        # Clamp from both sides: a drag that starts before the first
        # violin would otherwise give a negative index, which wraps around
        # and addresses the last group.
        index = max(0, min(index, len(self.__selection_rects) - 1))
        index = self._sorted_group_indices[index]

        self.__selection_rects[index].selection_range = min_max

        if not finished:
            return

        mask = np.bitwise_and(self.__values >= min_max[0],
                              self.__values <= min_max[1])
        if self.__group_values is not None:
            mask = np.bitwise_and(mask, self.__group_values == index)

        selection = set(np.flatnonzero(mask))
        keys = QApplication.keyboardModifiers()
        # Without groups there is a single violin whose range has just been
        # replaced, so there is no selection of other groups to keep;
        # merging the old selection would contradict the emitted ranges.
        if (keys & Qt.ShiftModifier) and self.__group_values is not None:
            in_group = np.flatnonzero(self.__group_values == index)
            selection |= self.__selection - set(in_group)
        if self.__selection != selection:
            self.__selection = selection
            self.selection_changed.emit(sorted(self.__selection),
                                        self._selection_ranges)

    def _deselect(self, finished: bool):
        """Remove selection ranges unless Shift is pressed.

        The instance selection is cleared, and ``selection_changed``
        emitted, only when ``finished`` is true and something was selected.
        """
        keys = QApplication.keyboardModifiers()
        if keys & Qt.ShiftModifier:
            return

        for sel_rect in self.__selection_rects:
            sel_rect.selection_range = None
        if self.__selection and finished:
            self.__selection = set()
            self.selection_changed.emit([], [])

    @staticmethod
    def sizeHint() -> QSize:
        return QSize(800, 600)
713
714
715class OWViolinPlot(OWWidget):
716    name = "Violin Plot"
717    description = "Visualize the distribution of feature" \
718                  " values in a violin plot."
719    icon = "icons/ViolinPlot.svg"
720    priority = 110
721    keywords = ["kernel", "density"]
722
723    class Inputs:
724        data = Input("Data", Table)
725
726    class Outputs:
727        selected_data = Output("Selected Data", Table, default=True)
728        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
729
730    class Error(OWWidget.Error):
731        no_cont_features = Msg("Plotting requires a numeric feature.")
732        not_enough_instances = Msg("Plotting requires at least two instances.")
733
734    KERNELS = ["gaussian", "epanechnikov", "linear"]
735    KERNEL_LABELS = ["Normal", "Epanechnikov", "Linear"]
736    SCALE_LABELS = ["Area", "Count", "Width"]
737
738    settingsHandler = DomainContextHandler()
739    value_var = ContextSetting(None)
740    order_by_importance = Setting(False)
741    group_var = ContextSetting(None)
742    order_grouping_by_importance = Setting(False)
743    show_box_plot = Setting(True)
744    show_strip_plot = Setting(False)
745    show_rug_plot = Setting(False)
746    order_violins = Setting(False)
747    orientation_index = Setting(1)  # Vertical
748    kernel_index = Setting(0)  # Normal kernel
749    scale_index = Setting(AREA)
750    selection_ranges = Setting([], schema_only=True)
751    visual_settings = Setting({}, schema_only=True)
752
753    graph_name = "graph.plotItem"
754    buttons_area_orientation = None
755
    def __init__(self):
        """Initialise the widget state and build its user interface."""
        super().__init__()
        self.data: Optional[Table] = None
        self.orig_data: Optional[Table] = None
        # GUI elements; they are None until the GUI is set up below
        self.graph: ViolinPlot = None
        self._value_var_model: VariableListModel = None
        self._group_var_model: VariableListModel = None
        self._value_var_view: ListViewSearch = None
        self._group_var_view: ListViewSearch = None
        self._order_violins_cb: QCheckBox = None
        self._scale_combo: QComboBox = None
        self.selection = []
        # selection ranges taken from the (schema-only) setting
        self.__pending_selection: List = self.selection_ranges

        self.setup_gui()
        # the dialog is given the initial visual settings of the graph,
        # hence it is created after the graph exists
        VisualSettingsDialog(
            self, self.graph.parameter_setter.initial_settings
        )
774
    def setup_gui(self):
        """Create the graph first and the controls after it."""
        self._add_graph()
        self._add_controls()
778
779    def _add_graph(self):
780        box = gui.vBox(self.mainArea)
781        self.graph = ViolinPlot(self, self.kernel,
782                                self.scale_index, self.orientation,
783                                self.show_box_plot, self.show_strip_plot,
784                                self.show_rug_plot, self.order_violins)
785        self.graph.selection_changed.connect(self.__selection_changed)
786        box.layout().addWidget(self.graph)
787
788    def __selection_changed(self, indices: List, ranges: List):
789        self.selection_ranges = ranges
790        if self.selection != indices:
791            self.selection = indices
792            self.commit()
793
794    def _add_controls(self):
795        self._value_var_model = VariableListModel()
796        sorted_model = SortProxyModel(sortRole=Qt.UserRole)
797        sorted_model.setSourceModel(self._value_var_model)
798        sorted_model.sort(0)
799
800        view = self._value_var_view = ListViewSearch()
801        view.setModel(sorted_model)
802        view.setMinimumSize(QSize(30, 100))
803        view.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Ignored)
804        view.selectionModel().selectionChanged.connect(
805            self.__value_var_changed
806        )
807
808        self._group_var_model = VariableListModel(placeholder="None")
809        sorted_model = SortProxyModel(sortRole=Qt.UserRole)
810        sorted_model.setSourceModel(self._group_var_model)
811        sorted_model.sort(0)
812
813        view = self._group_var_view = ListViewSearch()
814        view.setModel(sorted_model)
815        view.setMinimumSize(QSize(30, 100))
816        view.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Ignored)
817        view.selectionModel().selectionChanged.connect(
818            self.__group_var_changed
819        )
820
821        box = gui.vBox(self.controlArea, "Variable")
822        box.layout().addWidget(self._value_var_view)
823        gui.checkBox(box, self, "order_by_importance",
824                     "Order by relevance to subgroups",
825                     tooltip="Order by ��² or ANOVA over the subgroups",
826                     callback=self.apply_value_var_sorting)
827
828        box = gui.vBox(self.controlArea, "Subgroups")
829        box.layout().addWidget(self._group_var_view)
830        gui.checkBox(box, self, "order_grouping_by_importance",
831                     "Order by relevance to variable",
832                     tooltip="Order by ��² or ANOVA over the variable values",
833                     callback=self.apply_group_var_sorting)
834
835        box = gui.vBox(self.controlArea, "Display",
836                       sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum))
837        gui.checkBox(box, self, "show_box_plot", "Box plot",
838                     callback=self.__show_box_plot_changed)
839        gui.checkBox(box, self, "show_strip_plot", "Strip plot",
840                     callback=self.__show_strip_plot_changed)
841        gui.checkBox(box, self, "show_rug_plot", "Rug plot",
842                     callback=self.__show_rug_plot_changed)
843        self._order_violins_cb = gui.checkBox(
844            box, self, "order_violins", "Order subgroups",
845            callback=self.__order_violins_changed,
846        )
847        gui.radioButtons(box, self, "orientation_index",
848                         ["Horizontal", "Vertical"], label="Orientation: ",
849                         orientation=Qt.Horizontal,
850                         callback=self.__orientation_changed)
851
852        box = gui.vBox(self.controlArea, "Density Estimation",
853                       sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum))
854        gui.comboBox(box, self, "kernel_index", items=self.KERNEL_LABELS,
855                     label="Kernel:", labelWidth=60, orientation=Qt.Horizontal,
856                     callback=self.__kernel_changed)
857        self._scale_combo = gui.comboBox(
858            box, self, "scale_index", items=self.SCALE_LABELS,
859            label="Scale:", labelWidth=60, orientation=Qt.Horizontal,
860            callback=self.__scale_changed
861        )
862
863    def __value_var_changed(self, selection: QItemSelection):
864        if not selection:
865            return
866        self.value_var = selection.indexes()[0].data(gui.TableVariable)
867        self.apply_group_var_sorting()
868        self.setup_plot()
869        self.__selection_changed([], [])
870
871    def __group_var_changed(self, selection: QItemSelection):
872        if not selection:
873            return
874        self.group_var = selection.indexes()[0].data(gui.TableVariable)
875        self.apply_value_var_sorting()
876        self.enable_controls()
877        self.setup_plot()
878        self.__selection_changed([], [])
879
880    def __show_box_plot_changed(self):
881        self.graph.set_show_box_plot(self.show_box_plot)
882
883    def __show_strip_plot_changed(self):
884        self.graph.set_show_strip_plot(self.show_strip_plot)
885
886    def __show_rug_plot_changed(self):
887        self.graph.set_show_rug_plot(self.show_rug_plot)
888
889    def __order_violins_changed(self):
890        self.graph.set_sort_items(self.order_violins)
891
892    def __orientation_changed(self):
893        self.graph.set_orientation(self.orientation)
894
895    def __kernel_changed(self):
896        self.graph.set_kernel(self.kernel)
897
898    def __scale_changed(self):
899        self.graph.set_scale(self.scale_index)
900
901    @property
902    def kernel(self) -> str:
903        # pylint: disable=invalid-sequence-index
904        return self.KERNELS[self.kernel_index]
905
906    @property
907    def orientation(self) -> Qt.Orientations:
908        # pylint: disable=invalid-sequence-index
909        return [Qt.Horizontal, Qt.Vertical][self.orientation_index]
910
    @Inputs.data
    @check_sql_input
    def set_data(self, data: Optional[Table]):
        """Handle a new table on the data input.

        The steps below run in a fixed order: the widget state is reset,
        the input is validated, the variable lists are rebuilt, and then
        the plot and the outputs are refreshed.

        Parameters
        ----------
        data : Table or None
            The new input data; ``None`` when the input is removed.
        """
        # NOTE(review): closeContext/openContext come from the base widget;
        # presumably they store/restore the context settings - confirm.
        self.closeContext()
        # Empty both variable models, the selection and the plot.
        self.clear()
        # `orig_data` keeps the raw input; `data` may be reset to None by
        # check_data() when the input is not usable.
        self.orig_data = self.data = data
        self.check_data()
        # Fill the variable models and pick default variables.
        self.init_list_view()
        self.openContext(self.data)
        # Mirror value_var/group_var in the list views.
        self.set_list_view_selection()
        self.apply_value_var_sorting()
        self.apply_group_var_sorting()
        self.enable_controls()
        self.setup_plot()
        # Restores a pending selection if there is one, otherwise commits.
        self.apply_selection()
926
927    def check_data(self):
928        self.clear_messages()
929        if self.data is not None:
930            if self.data.domain.has_continuous_attributes(True, True) == 0:
931                self.Error.no_cont_features()
932                self.data = None
933            elif len(self.data) < 2:
934                self.Error.not_enough_instances()
935                self.data = None
936
937    def init_list_view(self):
938        if not self.data:
939            return
940
941        domain = self.data.domain
942        self._value_var_model[:] = [
943            var for var in chain(
944                domain.class_vars, domain.metas, domain.attributes)
945            if var.is_continuous and not var.attributes.get("hidden", False)]
946        self._group_var_model[:] = [None] + [
947            var for var in chain(
948                domain.class_vars, domain.metas, domain.attributes)
949            if var.is_discrete and not var.attributes.get("hidden", False)]
950
951        if len(self._value_var_model) > 0:
952            self.value_var = self._value_var_model[0]
953
954        self.group_var = self._group_var_model[0]
955        if domain.class_var and domain.class_var.is_discrete:
956            self.group_var = domain.class_var
957
958    def set_list_view_selection(self):
959        for view, var, callback in ((self._value_var_view, self.value_var,
960                                     self.__value_var_changed),
961                                    (self._group_var_view, self.group_var,
962                                     self.__group_var_changed)):
963            src_model = view.model().sourceModel()
964            if var not in src_model:
965                continue
966            sel_model = view.selectionModel()
967            sel_model.selectionChanged.disconnect(callback)
968            row = src_model.indexOf(var)
969            index = view.model().index(row, 0)
970            sel_model.select(index, sel_model.ClearAndSelect)
971            self._ensure_selection_visible(view)
972            sel_model.selectionChanged.connect(callback)
973
974    def apply_value_var_sorting(self):
975        def compute_score(attr):
976            if attr is group_var:
977                return 3
978            col = self.data.get_column_view(attr)[0].astype(float)
979            groups = (col[group_col == i] for i in range(n_groups))
980            groups = (col[~np.isnan(col)] for col in groups)
981            groups = [group for group in groups if len(group)]
982            p = stats.f_oneway(*groups)[1] if len(groups) > 1 else 2
983            if np.isnan(p):
984                return 2
985            return p
986
987        if self.data is None:
988            return
989        group_var = self.group_var
990        if self.order_by_importance and group_var is not None:
991            n_groups = len(group_var.values)
992            group_col = self.data.get_column_view(group_var)[0].astype(float)
993            self._sort_list(self._value_var_model, self._value_var_view,
994                            compute_score)
995        else:
996            self._sort_list(self._value_var_model, self._value_var_view, None)
997
998    def apply_group_var_sorting(self):
999        def compute_stat(group):
1000            if group is value_var:
1001                return 3
1002            if group is None:
1003                return -1
1004            col = self.data.get_column_view(group)[0].astype(float)
1005            groups = (value_col[col == i] for i in range(len(group.values)))
1006            groups = (col[~np.isnan(col)] for col in groups)
1007            groups = [group for group in groups if len(group)]
1008            p = stats.f_oneway(*groups)[1] if len(groups) > 1 else 2
1009            if np.isnan(p):
1010                return 2
1011            return p
1012
1013        if self.data is None:
1014            return
1015        value_var = self.value_var
1016        if self.order_grouping_by_importance:
1017            value_col = self.data.get_column_view(value_var)[0].astype(float)
1018            self._sort_list(self._group_var_model, self._group_var_view,
1019                            compute_stat)
1020        else:
1021            self._sort_list(self._group_var_model, self._group_var_view, None)
1022
1023    def _sort_list(self, source_model, view, key=None):
1024        if key is None:
1025            c = count()
1026
1027            def key(_):  # pylint: disable=function-redefined
1028                return next(c)
1029
1030        for i, attr in enumerate(source_model):
1031            source_model.setData(source_model.index(i), key(attr), Qt.UserRole)
1032        self._ensure_selection_visible(view)
1033
1034    @staticmethod
1035    def _ensure_selection_visible(view):
1036        selection = view.selectedIndexes()
1037        if len(selection) == 1:
1038            view.scrollTo(selection[0])
1039
1040    def enable_controls(self):
1041        enable = self.group_var is not None or not self.data
1042        self._order_violins_cb.setEnabled(enable)
1043        self._scale_combo.setEnabled(enable)
1044
1045    def setup_plot(self):
1046        self.graph.clear_plot()
1047        if not self.data:
1048            return
1049
1050        y = self.data.get_column_view(self.value_var)[0].astype(float)
1051        x = None
1052        if self.group_var:
1053            x = self.data.get_column_view(self.group_var)[0].astype(float)
1054        self.graph.set_data(y, self.value_var, x, self.group_var)
1055
1056    def apply_selection(self):
1057        if self.__pending_selection:
1058            # commit is invoked on selection_changed
1059            self.selection_ranges = self.__pending_selection
1060            self.__pending_selection = []
1061            self.graph.set_selection(self.selection_ranges)
1062        else:
1063            self.commit()
1064
1065    def commit(self):
1066        selected = None
1067        if self.data is not None and bool(self.selection):
1068            selected = self.data[self.selection]
1069        annotated = create_annotated_table(self.orig_data, self.selection)
1070        self.Outputs.selected_data.send(selected)
1071        self.Outputs.annotated_data.send(annotated)
1072
1073    def clear(self):
1074        self._value_var_model[:] = []
1075        self._group_var_model[:] = []
1076        self.selection = []
1077        self.selection_ranges = []
1078        self.graph.clear_plot()
1079
1080    def send_report(self):
1081        if self.data is None:
1082            return
1083        self.report_plot()
1084
1085    def set_visual_settings(self, key: KeyType, value: ValueType):
1086        self.graph.parameter_setter.set_parameter(key, value)
1087        # pylint: disable=unsupported-assignment-operation
1088        self.visual_settings[key] = value
1089
1090
if __name__ == "__main__":
    # Preview the widget on the "heart_disease" dataset when the module is
    # run as a script.
    from Orange.widgets.utils.widgetpreview import WidgetPreview

    preview = WidgetPreview(OWViolinPlot)
    preview.run(set_data=Table("heart_disease"))
1095