1import operator
2from functools import reduce, wraps
3from collections import namedtuple, deque, OrderedDict
4
5import numpy as np
6import sklearn.metrics as skl_metrics
7
8from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction, \
9    QToolTip
10from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont, \
11    QCursor, QFontMetrics
12from AnyQt.QtCore import Qt, QSize
13import pyqtgraph as pg
14
15import Orange
16from Orange.widgets import widget, gui, settings
17from Orange.widgets.evaluate.contexthandlers import \
18    EvaluationResultsContextHandler
19from Orange.widgets.evaluate.utils import check_results_adequacy
20from Orange.widgets.utils import colorpalettes
21from Orange.widgets.utils.widgetpreview import WidgetPreview
22from Orange.widgets.widget import Input
23from Orange.widgets import report
24
25from Orange.widgets.evaluate.utils import results_for_preview
26from Orange.evaluation.testing import Results
27
28
class ROCPoints(namedtuple("ROCPoints", ["fpr", "tpr", "thresholds"])):
    """
    Points on a ROC curve.

    Fields:
        fpr: (N,) array of false positive rate coordinates (ascending)
        tpr: (N,) array of true positive rate coordinates
        thresholds: (N,) array of thresholds (in descending order)
    """
    __slots__ = ()

    @property
    def is_valid(self):
        """Whether the curve contains at least one point."""
        return self.fpr.size > 0
38
class ROCCurve(namedtuple("ROCCurve", ["points", "hull"])):
    """
    A ROC curve and its convex hull.

    Fields:
        points: ROCPoints of the curve itself
        hull: ROCPoints of the convex hull
    """
    __slots__ = ()

    @property
    def is_valid(self):
        """Whether the curve's own points are valid."""
        return self.points.is_valid
47
class ROCAveragedVert(
        namedtuple("ROCAveragedVert", ["points", "hull", "tpr_std"])):
    """
    A ROC curve averaged vertically.

    Fields:
        points: ROCPoints sampled by fpr
        hull: ROCPoints of the convex hull
        tpr_std: array of standard deviations of tpr at each fpr point
    """
    __slots__ = ()

    @property
    def is_valid(self):
        """Whether the averaged points are valid."""
        return self.points.is_valid
57
class ROCAveragedThresh(
        namedtuple("ROCAveragedThresh",
                   ["points", "hull", "tpr_std", "fpr_std"])):
    """
    A ROC curve averaged by thresholds.

    Fields:
        points: ROCPoints sampled by threshold
        hull: ROCPoints of the convex hull
        tpr_std: array of standard deviations of tpr at each threshold
        fpr_std: array of standard deviations of fpr at each threshold
    """
    __slots__ = ()

    @property
    def is_valid(self):
        """Whether the averaged points are valid."""
        return self.points.is_valid
68
#: All ROC curves computed for a single algorithm:
#:   merged        -- ROCCurve merged over all folds
#:   folds         -- list of ROCCurve, one for each fold
#:   avg_vertical  -- ROCAveragedVert
#:   avg_threshold -- ROCAveragedThresh
ROCData = namedtuple(
    "ROCData", ("merged", "folds", "avg_vertical", "avg_threshold")
)
78
79
def roc_data_from_results(results, clf_index, target):
    """
    Compute ROC curve(s) from evaluation results.

    Builds the curve merged over all predictions, one curve per fold, and
    the vertically and threshold averaged curves over the valid folds.

    :param Orange.evaluation.Results results:
        Evaluation results.
    :param int clf_index:
        Learner index in the `results`.
    :param int target:
        Target class index (i.e. positive class).
    :rtype: ROCData
    :return:
        An instance holding the computed curves. The averaged curves are
        invalid (empty) when no fold yields a valid curve.
    """
    def build_curve(fold):
        # A ROCCurve (points + convex hull) for one fold selector.
        pts = roc_curve_for_fold(results, fold, clf_index, target)
        return ROCCurve(ROCPoints(*pts),
                        ROCPoints(*roc_curve_convex_hull(pts)))

    def no_points():
        # A fresh, invalid (zero length) ROCPoints.
        return ROCPoints(np.array([]), np.array([]), np.array([]))

    # `...` selects every row, i.e. predictions from all folds at once.
    merged_curve = build_curve(...)

    fold_selectors = [...] if results.folds is None else results.folds
    fold_curves = [build_curve(fold) for fold in fold_selectors]

    # Only folds with a defined curve take part in the averaging.
    curves = [c.points for c in fold_curves if c.is_valid]

    if not curves:
        v_avg = ROCAveragedVert(no_points(), no_points(), np.array([]))
        t_avg = ROCAveragedThresh(
            no_points(), no_points(), np.array([]), np.array([]))
        return ROCData(merged_curve, fold_curves, v_avg, t_avg)

    # Vertical averaging: there is no meaningful threshold for the sampled
    # points, so the thresholds are filled with NaN.
    fpr, tpr, std = roc_curve_vertical_average(curves)
    nan_thresh = np.zeros_like(fpr) * np.nan
    v_hull = roc_curve_convex_hull((fpr, tpr, nan_thresh))
    v_avg = ROCAveragedVert(
        ROCPoints(fpr, tpr, nan_thresh), ROCPoints(*v_hull), std)

    # Threshold averaging: pool the thresholds of all folds, clip them to
    # (just outside of) [0, 1], and keep roughly every tenth-of-the-pool
    # unique value, in descending order.
    pooled = np.hstack([t for _, _, t in curves])
    pooled = np.clip(pooled, 0.0 - 1e-10, 1.0 + 1e-10)
    pooled = np.unique(pooled)[::-1]
    step = max(pooled.size // 10, 1)
    thresh = pooled[::step]

    (fpr, fpr_std), (tpr, tpr_std) = \
        roc_curve_threshold_average(curves, thresh)
    t_hull = roc_curve_convex_hull((fpr, tpr, thresh))
    t_avg = ROCAveragedThresh(
        ROCPoints(fpr, tpr, thresh), ROCPoints(*t_hull), tpr_std, fpr_std)

    return ROCData(merged_curve, fold_curves, v_avg, t_avg)

ROCData.from_results = staticmethod(roc_data_from_results)
154
#: The graphics items that display one ROC curve in a plot:
#:   curve      -- ROCCurve source curve
#:   curve_item -- pg.PlotDataItem main curve
#:   hull_item  -- pg.PlotDataItem curve's convex hull
PlotCurve = namedtuple("PlotCurve", ("curve", "curve_item", "hull_item"))
163
164
def plot_curve(curve, pen=None, shadow_pen=None, symbol="+",
               symbol_size=3, name=None):
    """
    Construct a `PlotCurve` for the given `ROCCurve`.

    The main curve is drawn as a line through the points (extended to the
    origin), with the original points shown as scatter symbols parented to
    the line item. The convex hull gets its own line item.

    :param ROCCurve curve:
        Source curve.

    The other parameters are passed on to the pyqtgraph items.

    :rtype: PlotCurve
    """
    def with_origin(pts):
        """Prepend (0, 0) when the first point has a positive tpr or fpr."""
        off_origin = pts.tpr.size and (pts.tpr[0] > 0 or pts.fpr[0] > 0)
        if not off_origin:
            return pts
        return ROCPoints(
            np.r_[0, pts.fpr],
            np.r_[0, pts.tpr],
            # the added point gets a threshold above the current first one
            np.r_[pts.thresholds[0] + 1, pts.thresholds],
        )

    source = curve.points
    line_points = with_origin(source)
    curve_item = pg.PlotCurveItem(
        line_points.fpr, line_points.tpr,
        pen=pen, shadowPen=shadow_pen, name=name, antialias=True,
    )
    # Symbols mark only the original points, not the added origin.
    markers = pg.ScatterPlotItem(
        source.fpr, source.tpr,
        symbol=symbol, size=symbol_size, pen=shadow_pen, name=name,
    )
    markers.setParentItem(curve_item)

    hull_points = with_origin(curve.hull)
    hull_item = pg.PlotDataItem(
        hull_points.fpr, hull_points.tpr, pen=pen, antialias=True,
    )
    return PlotCurve(curve, curve_item, hull_item)

PlotCurve.from_roc_curve = staticmethod(plot_curve)
206
#: The graphics items that display an averaged curve with error bars:
#:   curve        -- ROCCurve
#:   curve_item   -- pg.PlotDataItem
#:   hull_item    -- pg.PlotDataItem
#:   confint_item -- pg.ErrorBarItem
PlotAvgCurve = namedtuple(
    "PlotAvgCurve", ("curve", "curve_item", "hull_item", "confint_item")
)
216
217
def plot_avg_curve(curve, pen=None, shadow_pen=None, symbol="+",
                   symbol_size=4, name=None):
    """
    Construct a `PlotAvgCurve` for the given `curve`.

    :param curve: Source curve.
    :type curve: ROCAveragedVert or ROCAveragedThresh

    The other parameters are passed on to `plot_curve` (and `pen` also to
    the error bar item).

    :rtype: PlotAvgCurve
    :raises TypeError:
        If `curve` is neither a `ROCAveragedVert` nor a `ROCAveragedThresh`.
    """
    # Reject unsupported inputs up front instead of failing later with an
    # UnboundLocalError on the error bar item.
    if isinstance(curve, ROCAveragedVert):
        has_fpr_std = False
    elif isinstance(curve, ROCAveragedThresh):
        has_fpr_std = True
    else:
        raise TypeError(
            "curve must be a ROCAveragedVert or ROCAveragedThresh, "
            "not {}".format(type(curve).__name__)
        )

    pc = plot_curve(curve, pen=pen, shadow_pen=shadow_pen, symbol=symbol,
                    symbol_size=symbol_size, name=name)

    points = curve.points
    # Error bars are drawn for the interior points only (first and last
    # sample are skipped); bar extents are two standard deviations.
    inner = slice(1, -1)
    bar_options = dict(
        x=points.fpr[inner], y=points.tpr[inner],
        height=2 * curve.tpr_std[inner],
        pen=pen, beam=0.025,
        antialias=True,
    )
    if has_fpr_std:
        # Threshold averaging also has a spread along the fpr axis.
        bar_options["width"] = 2 * curve.fpr_std[inner]
    error_item = pg.ErrorBarItem(**bar_options)

    return PlotAvgCurve(curve, pc.curve_item, pc.hull_item, error_item)

PlotAvgCurve.from_roc_curve = staticmethod(plot_avg_curve)
253
#: Box around a computed value, so that a cached `None` result can be told
#: apart from "not computed yet".
Some = namedtuple("Some", "val")


def once(f):
    """
    Return a zero-argument wrapper that calls `f` at most once.

    The first call evaluates `f()` and stores the result; every later call
    returns the stored result without calling `f` again. If `f` raises,
    nothing is stored and the next call tries again.
    """
    box = []  # empty until `f` has returned once, then holds a single Some

    @wraps(f)
    def call_once():
        if not box:
            box.append(Some(f()))
        return box[0].val

    return call_once
270
271
#: Lazily evaluated plot items for one (target, classifier) pair. Every
#: field is a zero-argument callable:
#:   merge         :: () -> PlotCurve
#:   folds         :: () -> [PlotCurve]
#:   avg_vertical  :: () -> PlotAvgCurve
#:   avg_threshold :: () -> PlotAvgCurve
PlotCurves = namedtuple(
    "PlotCurves", ("merge", "folds", "avg_vertical", "avg_threshold")
)
280
281
class InfiniteLine(pg.InfiniteLine):
    """
    A `pyqtgraph.InfiniteLine` that can be painted with antialiasing.

    Takes the same leading arguments as the base class plus `antialias`;
    when it is true the painter's antialiasing hint is switched on before
    the base class paints the line.
    """
    def __init__(self, pos=None, angle=90, pen=None, movable=False,
                 bounds=None, antialias=False):
        super().__init__(pos=pos, angle=angle, pen=pen, movable=movable,
                         bounds=bounds)
        self.antialias = antialias

    def paint(self, p, *args):
        use_antialiasing = self.antialias
        if use_antialiasing:
            p.setRenderHint(QPainter.Antialiasing, True)
        super().paint(p, *args)
294
295
class OWROCAnalysis(widget.OWWidget):
    # Widget that plots ROC curves computed from classifier evaluation
    # results, with optional convex hulls and an iso-performance line.

    # Widget metadata.
    name = "ROC Analysis"
    description = "Display the Receiver Operating Characteristics curve " \
                  "based on the evaluation of classifiers."
    icon = "icons/ROCAnalysis.svg"
    priority = 1010
    keywords = []

    class Inputs:
        # The only input: evaluation results, handled by `set_results`.
        evaluation_results = Input("Evaluation Results", Orange.evaluation.Results)

    buttons_area_orientation = None
    settingsHandler = EvaluationResultsContextHandler()
    # Context settings: index of the target (positive) class and the indices
    # of the classifiers selected in the list box.
    target_index = settings.ContextSetting(0)
    selected_classifiers = settings.ContextSetting([])

    # Toggles for the iso-performance line and the default threshold label.
    display_perf_line = settings.Setting(True)
    display_def_threshold = settings.Setting(True)

    # Misclassification costs and the target class prior (in percent) used
    # to compute the slope of the iso-performance line.
    fp_cost = settings.Setting(500)
    fn_cost = settings.Setting(500)
    target_prior = settings.Setting(50.0, schema_only=True)

    #: ROC Averaging Types
    # The values match the item order of the "Curves" combo box.
    Merge, Vertical, Threshold, NoAveraging = 0, 1, 2, 3
    roc_averaging = settings.Setting(Merge)

    # Toggles for the combined convex hull of all shown curves and for the
    # per-curve convex hull lines.
    display_convex_hull = settings.Setting(False)
    display_convex_curve = settings.Setting(False)

    # Name of the attribute holding the plot (`self.plot`).
    # NOTE(review): presumably consumed by the framework's save-graph/report
    # machinery -- confirm.
    graph_name = "plot"
327
    def __init__(self):
        """Initialize the widget state and build the control and plot areas."""
        super().__init__()

        # Current input and the state derived from it; all of this is reset
        # again in `clear`, except `perf_line`.
        self.results = None
        self.classifier_names = []
        # NOTE(review): `perf_line` is distinct from `_perf_line` below and
        # is never assigned anything but None in the visible code.
        self.perf_line = None
        self.colors = []
        # Caches keyed by (target, classifier index).
        self._curve_data = {}
        self._plot_curves = {}
        # Convex hull over the merged curves and the iso-performance line
        # item; both are (re)created in `_setup_plot`.
        self._rocch = None
        self._perf_line = None
        # Last tooltip shown by `_on_mouse_moved`.
        self._tooltip_cache = None

        # --- "Plot" box: target class and classifier selection ---
        box = gui.vBox(self.controlArea, "Plot")
        self.target_cb = gui.comboBox(
            box, self, "target_index",
            label="Target", orientation=Qt.Horizontal,
            callback=self._on_target_changed,
            contentsLength=8, searchable=True)

        gui.widgetLabel(box, "Classifiers")
        # Size hint tall enough for four lines of text.
        line_height = 4 * QFontMetrics(self.font()).lineSpacing()
        self.classifiers_list_box = gui.listBox(
            box, self, "selected_classifiers", "classifier_names",
            selectionMode=QListView.MultiSelection,
            callback=self._on_classifiers_changed,
            sizeHint=QSize(0, line_height))

        # --- "Curves" box: averaging mode and convex hull toggles ---
        abox = gui.vBox(self.controlArea, "Curves")
        # Item order matches Merge, Vertical, Threshold, NoAveraging.
        gui.comboBox(abox, self, "roc_averaging",
                     items=["Merge Predictions from Folds", "Mean TP Rate",
                            "Mean TP and FP at Threshold", "Show Individual Curves"],
                     callback=self._replot)

        gui.checkBox(abox, self, "display_convex_curve",
                     "Show convex ROC curves", callback=self._replot)
        gui.checkBox(abox, self, "display_convex_hull",
                     "Show ROC convex hull", callback=self._replot)

        # --- "Analysis" box: default threshold and performance line ---
        box = gui.vBox(self.controlArea, "Analysis")

        gui.checkBox(box, self, "display_def_threshold",
                     "Default threshold (0.5) point",
                     callback=self._on_display_def_threshold_changed)

        gui.checkBox(box, self, "display_perf_line", "Show performance line",
                     callback=self._on_display_perf_line_changed)
        # The cost and prior spin boxes are laid out in a grid inside an
        # indented box.
        grid = QGridLayout()
        gui.indentedBox(box, orientation=grid)

        sp = gui.spin(box, self, "fp_cost", 1, 1000, 10,
                      alignment=Qt.AlignRight,
                      callback=self._on_display_perf_line_changed)
        grid.addWidget(QLabel("FP Cost:"), 0, 0)
        grid.addWidget(sp, 0, 1)

        sp = gui.spin(box, self, "fn_cost", 1, 1000, 10,
                      alignment=Qt.AlignRight,
                      callback=self._on_display_perf_line_changed)
        # No explicit cell given here; the spin box is placed in row 1.
        grid.addWidget(QLabel("FN Cost:"))
        grid.addWidget(sp, 1, 1)
        self.target_prior_sp = gui.spin(box, self, "target_prior", 1, 99,
                                        alignment=Qt.AlignRight,
                                        callback=self._on_target_prior_changed)
        self.target_prior_sp.setSuffix(" %")
        # NOTE(review): the action's parent is `sp` (the FN cost spin box),
        # not `target_prior_sp` -- confirm this is intended.
        self.target_prior_sp.addAction(QAction("Auto", sp))
        grid.addWidget(QLabel("Prior probability:"))
        grid.addWidget(self.target_prior_sp, 2, 1)

        # --- Main area: the plot ---
        self.plotview = pg.GraphicsView(background="w")
        self.plotview.setFrameStyle(QFrame.StyledPanel)
        # Mouse movement over the scene drives the threshold tooltips.
        self.plotview.scene().sigMouseMoved.connect(self._on_mouse_moved)

        # The plot is static: no context menu, no mouse panning/zooming and
        # no auto-range buttons.
        self.plot = pg.PlotItem(enableMenu=False)
        self.plot.setMouseEnabled(False, False)
        self.plot.hideButtons()

        pen = QPen(self.palette().color(QPalette.Text))

        # Tick labels use 2/3 of the widget font's pixel size, but at least
        # 11 pixels.
        tickfont = QFont(self.font())
        tickfont.setPixelSize(max(int(tickfont.pixelSize() * 2 // 3), 11))

        axis = self.plot.getAxis("bottom")
        axis.setTickFont(tickfont)
        axis.setPen(pen)
        axis.setLabel("FP Rate (1-Specificity)")
        axis.setGrid(16)

        axis = self.plot.getAxis("left")
        axis.setTickFont(tickfont)
        axis.setPen(pen)
        axis.setLabel("TP Rate (Sensitivity)")
        axis.setGrid(16)

        self.plot.showGrid(True, True, alpha=0.1)
        # Both rates live in [0, 1]; a little padding is added around it.
        self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0), padding=0.05)

        self.plotview.setCentralItem(self.plot)
        self.mainArea.layout().addWidget(self.plotview)
427
428    @Inputs.evaluation_results
429    def set_results(self, results):
430        """Set the input evaluation results."""
431        self.closeContext()
432        self.clear()
433        self.results = check_results_adequacy(results, self.Error)
434        if self.results is not None:
435            self._initialize(self.results)
436            self.openContext(self.results.domain.class_var,
437                             self.classifier_names)
438            self._setup_plot()
439        else:
440            self.warning()
441
442    def clear(self):
443        """Clear the widget state."""
444        self.results = None
445        self.plot.clear()
446        self.classifier_names = []
447        self.selected_classifiers = []
448        self.target_cb.clear()
449        self.colors = []
450        self._curve_data = {}
451        self._plot_curves = {}
452        self._rocch = None
453        self._perf_line = None
454        self._tooltip_cache = None
455
456    def _initialize(self, results):
457        names = getattr(results, "learner_names", None)
458
459        if names is None:
460            names = ["#{}".format(i + 1)
461                     for i in range(len(results.predicted))]
462
463        self.colors = colorpalettes.get_default_curve_colors(len(names))
464
465        self.classifier_names = names
466        self.selected_classifiers = list(range(len(names)))
467        for i in range(len(names)):
468            listitem = self.classifiers_list_box.item(i)
469            listitem.setIcon(colorpalettes.ColorIcon(self.colors[i]))
470
471        class_var = results.data.domain.class_var
472        self.target_cb.addItems(class_var.values)
473        self.target_index = 0
474        self._set_target_prior()
475
476    def _set_target_prior(self):
477        """
478        This function sets the initial target class probability prior value
479        based on the input data.
480        """
481        if self.results.data:
482            # here we can use target_index directly since values in the
483            # dropdown are sorted in same order than values in the table
484            target_values_cnt = np.count_nonzero(
485                self.results.data.Y == self.target_index)
486            count_all = np.count_nonzero(~np.isnan(self.results.data.Y))
487            self.target_prior = np.round(target_values_cnt / count_all * 100)
488
489            # set the spin text to gray color when set automatically
490            self.target_prior_sp.setStyleSheet("color: gray;")
491
492    def curve_data(self, target, clf_idx):
493        """Return `ROCData' for the given target and classifier."""
494        if (target, clf_idx) not in self._curve_data:
495            # pylint: disable=no-member
496            data = ROCData.from_results(self.results, clf_idx, target)
497            self._curve_data[target, clf_idx] = data
498
499        return self._curve_data[target, clf_idx]
500
501    def plot_curves(self, target, clf_idx):
502        """Return a set of functions `plot_curves` generating plot curves."""
503        def generate_pens(basecolor):
504            pen = QPen(basecolor, 1)
505            pen.setCosmetic(True)
506
507            shadow_pen = QPen(pen.color().lighter(160), 2.5)
508            shadow_pen.setCosmetic(True)
509            return pen, shadow_pen
510
511        data = self.curve_data(target, clf_idx)
512
513        if (target, clf_idx) not in self._plot_curves:
514            pen, shadow_pen = generate_pens(self.colors[clf_idx])
515            name = self.classifier_names[clf_idx]
516            @once
517            def merged():
518                return plot_curve(
519                    data.merged, pen=pen, shadow_pen=shadow_pen, name=name)
520            @once
521            def folds():
522                return [plot_curve(fold, pen=pen, shadow_pen=shadow_pen)
523                        for fold in data.folds]
524            @once
525            def avg_vert():
526                return plot_avg_curve(data.avg_vertical, pen=pen,
527                                      shadow_pen=shadow_pen, name=name)
528            @once
529            def avg_thres():
530                return plot_avg_curve(data.avg_threshold, pen=pen,
531                                      shadow_pen=shadow_pen, name=name)
532
533            self._plot_curves[target, clf_idx] = PlotCurves(
534                merge=merged, folds=folds,
535                avg_vertical=avg_vert, avg_threshold=avg_thres
536            )
537
538        return self._plot_curves[target, clf_idx]
539
    def _setup_plot(self):
        """
        Add the items for the current settings to the (already cleared) plot.

        Draws the curves of the selected classifiers in the chosen averaging
        mode, the optional combined convex hull, the diagonal, and updates
        the performance line, the axis ticks and the widget warning.
        """
        # The helpers below close over `curves` (PlotCurves of the selected
        # classifiers) and `selected` (their ROCData); both are assigned
        # further down, before any helper is called. Each helper adds its
        # items to the plot and returns the list of hull ROCPoints.

        def merge_averaging():
            for curve in curves:
                graphics = curve.merge()
                # `curve` is rebound from PlotCurves to the source ROCCurve.
                curve = graphics.curve
                self.plot.addItem(graphics.curve_item)

                if self.display_convex_curve:
                    self.plot.addItem(graphics.hull_item)

                if self.display_def_threshold and curve.is_valid:
                    # Label the point whose threshold is closest to 0.5.
                    points = curve.points
                    ind = np.argmin(np.abs(points.thresholds - 0.5))
                    item = pg.TextItem(
                        text="{:.3f}".format(points.thresholds[ind]),
                    )
                    item.setPos(points.fpr[ind], points.tpr[ind])
                    self.plot.addItem(item)

            hull_curves = [curve.merged.hull for curve in selected]
            if hull_curves:
                # The iso-performance line is positioned on the convex hull
                # of all merged curves; see `_update_perf_line`.
                self._rocch = convex_hull(hull_curves)
                iso_pen = QPen(QColor(Qt.black), 1)
                iso_pen.setCosmetic(True)
                self._perf_line = InfiniteLine(pen=iso_pen, antialias=True)
                self.plot.addItem(self._perf_line)
            return hull_curves

        def vertical_averaging():
            for curve in curves:
                graphics = curve.avg_vertical()

                self.plot.addItem(graphics.curve_item)
                self.plot.addItem(graphics.confint_item)
            return [curve.avg_vertical.hull for curve in selected]

        def threshold_averaging():
            for curve in curves:
                graphics = curve.avg_threshold()
                self.plot.addItem(graphics.curve_item)
                self.plot.addItem(graphics.confint_item)
            return [curve.avg_threshold.hull for curve in selected]

        def no_averaging():
            for curve in curves:
                graphics = curve.folds()
                for fold in graphics:
                    self.plot.addItem(fold.curve_item)
                    if self.display_convex_curve:
                        self.plot.addItem(fold.hull_item)
            return [fold.hull for curve in selected for fold in curve.folds]

        # Dispatch table from the averaging setting to its helper.
        averagings = {
            OWROCAnalysis.Merge: merge_averaging,
            OWROCAnalysis.Vertical: vertical_averaging,
            OWROCAnalysis.Threshold: threshold_averaging,
            OWROCAnalysis.NoAveraging: no_averaging
        }

        target = self.target_index
        selected = self.selected_classifiers

        curves = [self.plot_curves(target, i) for i in selected]
        # `selected` is rebound from classifier indices to their ROCData.
        selected = [self.curve_data(target, i) for i in selected]
        hull_curves = averagings[self.roc_averaging]()

        if self.display_convex_hull and hull_curves:
            # Combined convex hull, filled down to y=0 and kept far behind
            # all other items.
            hull = convex_hull(hull_curves)
            hull_pen = QPen(QColor(200, 200, 200, 100), 2)
            hull_pen.setCosmetic(True)
            item = self.plot.plot(
                hull.fpr, hull.tpr,
                pen=hull_pen,
                brush=QBrush(QColor(200, 200, 200, 50)),
                fillLevel=0)
            item.setZValue(-10000)

        # Dashed diagonal from (0, 0) to (1, 1).
        pen = QPen(QColor(100, 100, 100, 100), 1, Qt.DashLine)
        pen.setCosmetic(True)
        self.plot.plot([0, 1], [0, 1], pen=pen, antialias=True)

        # The performance line is maintained in the Merge mode only.
        if self.roc_averaging == OWROCAnalysis.Merge:
            self._update_perf_line()

        self._update_axes_ticks()

        # Warn when some or all hulls are invalid; an empty message clears
        # the warning.
        warning = ""
        if not all(c.is_valid for c in hull_curves):
            if any(c.is_valid for c in hull_curves):
                warning = "Some ROC curves are undefined"
            else:
                warning = "All ROC curves are undefined"
        self.warning(warning)
633
634    def _update_axes_ticks(self):
635        def enumticks(a):
636            a = np.unique(a)
637            if len(a) > 15:
638                return None
639            return [[(x, f"{x:.2f}") for x in a[::-1]]]
640
641        data = self.curve_data(self.target_index, self.selected_classifiers[0])
642        points = data.merged.points
643
644        axis = self.plot.getAxis("bottom")
645        axis.setTicks(enumticks(points.fpr))
646
647        axis = self.plot.getAxis("left")
648        axis.setTicks(enumticks(points.tpr))
649
650    def _on_mouse_moved(self, pos):
651        target = self.target_index
652        selected = self.selected_classifiers
653        curves = [(clf_idx, self.plot_curves(target, clf_idx))
654                  for clf_idx in selected]  # type: List[Tuple[int, PlotCurves]]
655        valid_thresh, valid_clf = [], []
656        pt, ave_mode = None, self.roc_averaging
657
658        for clf_idx, crv in curves:
659            if self.roc_averaging == OWROCAnalysis.Merge:
660                curve = crv.merge()
661            elif self.roc_averaging == OWROCAnalysis.Vertical:
662                curve = crv.avg_vertical()
663            elif self.roc_averaging == OWROCAnalysis.Threshold:
664                curve = crv.avg_threshold()
665            else:
666                # currently not implemented for 'Show Individual Curves'
667                return
668
669            sp = curve.curve_item.childItems()[0]  # type: pg.ScatterPlotItem
670            act_pos = sp.mapFromScene(pos)
671            pts = sp.pointsAt(act_pos)
672
673            if pts:
674                mouse_pt = pts[0].pos()
675                if self._tooltip_cache:
676                    cache_pt, cache_thresh, cache_clf, cache_ave = self._tooltip_cache
677                    curr_thresh, curr_clf = [], []
678                    if np.linalg.norm(mouse_pt - cache_pt) < 10e-6 \
679                            and cache_ave == self.roc_averaging:
680                        mask = np.equal(cache_clf, clf_idx)
681                        curr_thresh = np.compress(mask, cache_thresh).tolist()
682                        curr_clf = np.compress(mask, cache_clf).tolist()
683                    else:
684                        QToolTip.showText(QCursor.pos(), "")
685                        self._tooltip_cache = None
686
687                    if curr_thresh:
688                        valid_thresh.append(*curr_thresh)
689                        valid_clf.append(*curr_clf)
690                        pt = cache_pt
691                        continue
692
693                curve_pts = curve.curve.points
694                roc_points = np.column_stack((curve_pts.fpr, curve_pts.tpr))
695                diff = np.subtract(roc_points, mouse_pt)
696                # Find closest point on curve and save the corresponding threshold
697                idx_closest = np.argmin(np.linalg.norm(diff, axis=1))
698
699                thresh = curve_pts.thresholds[idx_closest]
700                if not np.isnan(thresh):
701                    valid_thresh.append(thresh)
702                    valid_clf.append(clf_idx)
703                    pt = [curve_pts.fpr[idx_closest], curve_pts.tpr[idx_closest]]
704
705        if valid_thresh:
706            clf_names = self.classifier_names
707            msg = "Thresholds:\n" + "\n".join(["({:s}) {:.3f}".format(clf_names[i], thresh)
708                                               for i, thresh in zip(valid_clf, valid_thresh)])
709            QToolTip.showText(QCursor.pos(), msg)
710            self._tooltip_cache = (pt, valid_thresh, valid_clf, ave_mode)
711
712    def _on_target_changed(self):
713        self.plot.clear()
714        self._set_target_prior()
715        self._setup_plot()
716
717    def _on_classifiers_changed(self):
718        self.plot.clear()
719        if self.results is not None:
720            self._setup_plot()
721
    def _on_target_prior_changed(self):
        """Handle a change of the prior probability spin box."""
        # Black text (as opposed to the gray used by `_set_target_prior`)
        # marks the value as no longer set automatically.
        self.target_prior_sp.setStyleSheet("color: black;")
        self._on_display_perf_line_changed()
725
726    def _on_display_perf_line_changed(self):
727        if self.roc_averaging == OWROCAnalysis.Merge:
728            self._update_perf_line()
729
730        if self.perf_line is not None:
731            self.perf_line.setVisible(self.display_perf_line)
732
    def _on_display_def_threshold_changed(self):
        """Redraw the plot after the default threshold toggle changed."""
        self._replot()
735
736    def _replot(self):
737        self.plot.clear()
738        if self.results is not None:
739            self._setup_plot()
740
741    def _update_perf_line(self):
742        if self._perf_line is None:
743            return
744
745        self._perf_line.setVisible(self.display_perf_line)
746        if self.display_perf_line:
747            m = roc_iso_performance_slope(
748                self.fp_cost, self.fn_cost, self.target_prior / 100.0)
749
750            hull = self._rocch
751            if hull.is_valid:
752                ind = roc_iso_performance_line(m, hull)
753                angle = np.arctan2(m, 1)  # in radians
754                self._perf_line.setAngle(angle * 180 / np.pi)
755                self._perf_line.setPos((hull.fpr[ind[0]], hull.tpr[ind[0]]))
756            else:
757                self._perf_line.setVisible(False)
758
    def onDeleteWidget(self):
        """Clear the widget state (delegates to `clear`)."""
        # NOTE(review): presumably called by the framework when the widget
        # is removed -- confirm against OWWidget.
        self.clear()
761
762    def send_report(self):
763        if self.results is None:
764            return
765        items = OrderedDict()
766        items["Target class"] = self.target_cb.currentText()
767        if self.display_perf_line:
768            items["Costs"] = \
769                "FP = {}, FN = {}".format(self.fp_cost, self.fn_cost)
770            items["Target probability"] = "{} %".format(self.target_prior)
771        caption = report.list_legend(self.classifiers_list_box,
772                                     self.selected_classifiers)
773        self.report_items(items)
774        self.report_plot()
775        self.report_caption(caption)
776
777
def interp(x, xp, fp, left=None, right=None):
    """
    Piecewise linear interpolation of the points (`xp`, `fp`) at `x`.

    Like numpy.interp, except for the handling of runs of equal values in
    `xp`: a sample that falls on such a run is interpolated on the segment
    that starts at the run's last point.

    :param x: 1-D array-like of sample positions.
    :param xp: Array-like of ascending x coordinates.
    :param fp: Array-like of y coordinates, same shape as `xp`.
    :param left: Value for `x < xp[0]`; defaults to `fp[0]`.
    :param right: Value for `x >= xp[-1]`; defaults to `fp[-1]`. When given,
        samples exactly equal to `xp[-1]` still get `fp[-1]`.
    :raises ValueError: If `xp` and `fp` differ in shape.
    :return: Array of interpolated values, one for each element of `x`.
    """
    x, xp, fp = (np.asanyarray(a) for a in (x, xp, fp))
    if xp.shape != fp.shape:
        raise ValueError("xp and fp must have the same shape")

    # For each sample, the index of the first knot strictly greater than it.
    pos = np.searchsorted(xp, x, side="right")
    below = pos == 0
    above = pos == len(xp)
    inside = ~(below | above)

    result = np.zeros(len(x))
    result[below] = fp[0] if left is None else left
    result[above] = fp[-1] if right is None else right

    if right is not None:
        # Fix points exactly on the right boundary.
        result[x == xp[-1]] = fp[-1]

    hi = pos[inside]
    lo = hi - 1
    # xp[hi] > x >= xp[lo], hence the segment is never of zero width.
    slope = (fp[hi] - fp[lo]) / (xp[hi] - xp[lo])
    result[inside] = slope * (x[inside] - xp[hi]) + fp[hi]

    return result
811
812
def roc_curve_for_fold(res, fold, clf_idx, target):
    """
    Compute the ROC curve of one classifier on one fold.

    :param res: Evaluation results; `res.actual` and `res.probabilities`
        are read.
    :param fold: Index expression selecting the fold's rows (`...` selects
        all rows).
    :param int clf_idx: Classifier index in `res.probabilities`.
    :param int target: Target (positive) class index.
    :return: A tuple `(fpr, tpr, thresholds)` as returned by
        `sklearn.metrics.roc_curve`, or three empty arrays when the fold has
        no positive or no negative instances.
    """
    actual = res.actual[fold]
    n_positive = np.sum(actual == target)
    n_negative = actual.size - n_positive

    if 0 in (n_positive, n_negative):
        # Undefined TP and FP rate
        return tuple(np.array([]) for _ in range(3))

    probs = res.probabilities[clf_idx][fold][:, target]
    # Intermediate points are kept for small folds (up to 20 instances).
    fpr, tpr, thresholds = skl_metrics.roc_curve(
        actual, probs, pos_label=target,
        drop_intermediate=len(actual) > 20
    )

    # skl sets the first threshold to the highest threshold in the data + 1
    # since we deal with probabilities, we (carefully) set it to 1
    # Unrelated comparisons, thus pylint: disable=chained-comparison
    if len(thresholds) > 1 and thresholds[1] <= 1:
        thresholds[0] = 1
    return fpr, tpr, thresholds
835
836
def roc_curve_vertical_average(curves, samples=10):
    """
    Average ROC curves vertically, i.e. average tpr at fixed fpr values.

    :param curves: Non-empty sequence of `(fpr, tpr, thresholds)` triples.
    :param int samples: Number of evenly spaced fpr values in [0, 1].
    :raises ValueError: If `curves` is empty.
    :return: A tuple `(fpr_sample, tpr_mean, tpr_std)` of arrays of length
        `samples`.
    """
    if not curves:
        raise ValueError("No curves")

    fpr_sample = np.linspace(0.0, 1.0, samples)
    # One row of interpolated tpr values for each curve.
    tpr_samples = np.array([
        interp(fpr_sample, fpr, tpr, left=0, right=1)
        for fpr, tpr, _ in curves
    ])
    tpr_mean = tpr_samples.mean(axis=0)
    tpr_std = tpr_samples.std(axis=0)
    return fpr_sample, tpr_mean, tpr_std
847
848
def roc_curve_threshold_average(curves, thresh_samples):
    """
    Average ROC curves at the given thresholds.

    For every curve, the curve point used for a sampled threshold is
    looked up with `np.searchsorted`; no interpolation between the points
    of the curve takes place.

    :param curves: non-empty sequence of `(fpr, tpr, thresholds)` array
        triplets, with thresholds in descending order
    :param thresh_samples: array of thresholds at which curves are sampled
    :return: `((fpr_mean, fpr_std), (tpr_mean, tpr_std))`, each an array
        with one element per sampled threshold
    :raises ValueError: if `curves` is empty
    """
    if not curves:
        raise ValueError("No curves")
    fpr_samples, tpr_samples = [], []
    for fpr, tpr, thresh in curves:
        # searchsorted requires ascending order, thresholds are descending
        ind = np.searchsorted(thresh[::-1], thresh_samples, side="left")
        # NOTE(review): this reverses the order of the samples; it does not
        # convert positions in the reversed array into positions in
        # `thresh`. Confirm that this is what callers expect.
        ind = ind[::-1]
        # keep indices within the curve
        ind = np.clip(ind, 0, len(thresh) - 1)
        fpr_samples.append(fpr[ind])
        tpr_samples.append(tpr[ind])

    fpr_samples = np.array(fpr_samples)
    tpr_samples = np.array(tpr_samples)

    # The deviation of tpr must be computed from tpr samples (it was
    # mistakenly computed from fpr samples).
    return ((fpr_samples.mean(axis=0), fpr_samples.std(axis=0)),
            (tpr_samples.mean(axis=0), tpr_samples.std(axis=0)))
865
866
def roc_curve_thresh_avg_interp(curves, thresh_samples):
    """
    Average ROC curves at the given thresholds, interpolating each curve.

    fpr and tpr of every curve are interpolated (with `interp`) at the
    sampled thresholds. Below the lowest threshold of a curve both rates
    are taken to be 1.0 and above the highest threshold 0.0.

    :param curves: sequence of `(fpr, tpr, thresholds)` array triplets,
        with thresholds in descending order
    :param thresh_samples: array of thresholds at which curves are sampled
    :return: `((fpr_mean, fpr_std), (tpr_mean, tpr_std))`, each an array
        with one element per sampled threshold
    """
    fpr_samples, tpr_samples = [], []
    for fpr, tpr, thresh in curves:
        # interp looks x up with searchsorted, so xp must be ascending;
        # reverse thresholds and the rates along with them
        thresh = thresh[::-1]
        fpr = interp(thresh_samples, thresh, fpr[::-1], left=1.0, right=0.0)
        tpr = interp(thresh_samples, thresh, tpr[::-1], left=1.0, right=0.0)
        fpr_samples.append(fpr)
        tpr_samples.append(tpr)

    fpr_samples = np.array(fpr_samples)
    tpr_samples = np.array(tpr_samples)

    # The deviation of tpr must be computed from tpr samples (it was
    # mistakenly computed from fpr samples).
    return ((fpr_samples.mean(axis=0), fpr_samples.std(axis=0)),
            (tpr_samples.mean(axis=0), tpr_samples.std(axis=0)))
881
882
#: A single point on a ROC curve
RocPoint = namedtuple(
    "RocPoint",
    ["fpr",        # false positive rate coordinate
     "tpr",        # true positive rate coordinate
     "threshold"   # threshold associated with the point
     ]
)
884
885
def roc_curve_convex_hull(curve):
    """
    Return the points of a single ROC curve that lie on its convex hull.

    :param curve: a `(fpr, tpr, thresholds)` triplet of equally long
        arrays; points are processed in the given order (presumably by
        ascending fpr -- confirm against callers)
    :return: `curve` itself if it has at most two points, otherwise a
        tuple `(fpr, tpr, thresholds)` of arrays with the hull points
    """
    def slope(p1, p2):
        # Slope of the segment from p1 to p2; infinite for vertical ones
        if p1.fpr == p2.fpr:
            return np.inf
        return (p2.tpr - p1.tpr) / (p2.fpr - p1.fpr)

    fpr, _, _ = curve
    if len(fpr) <= 2:
        # nothing to remove
        return curve

    points = (RocPoint._make(coords) for coords in zip(*curve))
    hull = [next(points)]

    for point in points:
        # Remove the points from the end of the hull for as long as the
        # new point shares fpr with the last point or appending it would
        # not make the slope strictly smaller.
        while len(hull) >= 2:
            last = hull[-1]
            if point.fpr != last.fpr \
                    and slope(hull[-2], last) > slope(last, point):
                break
            hull.pop()
        hull.append(point)

    hull_fpr = np.array([p.fpr for p in hull])
    hull_tpr = np.array([p.tpr for p in hull])
    hull_thresholds = np.array([p.threshold for p in hull])
    return (hull_fpr, hull_tpr, hull_thresholds)
921
922
def convex_hull(curves):
    """
    Return the convex hull of points of multiple ROC curves.

    Points of all curves are merged and sorted (by fpr, then tpr, then
    threshold) before computing the hull.

    :param curves: iterable of `(fpr, tpr, thresholds)` triplets
    :return: `ROCPoints` with points on the hull; empty arrays if the
        curves contain no points
    """
    def slope(p1, p2):
        # Slope of the segment from p1 to p2; infinite for vertical ones
        if p1[0] == p2[0]:
            return np.inf
        return (p2[1] - p1[1]) / (p2[0] - p1[0])

    merged_points = sorted(
        RocPoint._make(coords)
        for curve in curves
        for coords in zip(*curve)
    )

    if not merged_points:
        return ROCPoints(np.array([]), np.array([]), np.array([]))

    if len(merged_points) <= 2:
        # nothing to remove
        return ROCPoints._make(map(np.array, zip(*merged_points)))

    hull = [merged_points[0]]

    for point in merged_points[1:]:
        # Remove the points from the end of the hull for as long as the
        # new point shares fpr with the last point or appending it would
        # not make the slope strictly smaller.
        while len(hull) >= 2:
            last = hull[-1]
            if point[0] != last[0] \
                    and slope(hull[-2], last) > slope(last, point):
                break
            hull.pop()
        hull.append(point)

    return ROCPoints._make(map(np.array, zip(*hull)))
962
963
def roc_iso_performance_line(slope, hull, tol=1e-5):
    """
    Return the indices where a line with `slope` touches the ROC convex hull.

    :param slope: slope of the iso-performance line; may be infinite
        (vertical line), as returned by `roc_iso_performance_slope` when
        false negatives carry no cost
    :param hull: a sequence whose first two elements are arrays with fpr
        and tpr coordinates of the hull points
    :param tol: points whose distance exceeds the minimal distance by at
        most `tol` are also considered to touch the line
    :return: array of indices of hull points that touch the line
    """
    fpr, tpr, *_ = hull

    # Compute the distance of each point to a reference iso line
    # going through point (0, 1). The point(s) with the minimum
    # distance are our result
    if np.isinf(slope):
        # The reference line is vertical (x = 0), so the distance is |x|.
        # The general formula below would yield only NaNs (inf / inf)
        # and, consequently, no indices at all.
        dist = np.abs(np.asarray(fpr, dtype=float))
    else:
        # y = m * x + 1
        # m * x - 1y + 1 = 0
        a, b, c = slope, -1, 1
        dist = distance_to_line(a, b, c, fpr, tpr)
    mindist = np.min(dist)

    return np.flatnonzero((dist - mindist) <= tol)
981
982
def distance_to_line(a, b, c, x0, y0):
    """
    Return the distance of point(s) (x0, y0) to a line ax + by + c = 0

    `x0` and `y0` can be scalars or arrays; `a` and `b` must not both
    be zero.
    """
    assert not (a == 0 and b == 0)
    numerator = np.abs(a * x0 + b * y0 + c)
    normal_length = np.sqrt(a ** 2 + b ** 2)
    return numerator / normal_length
989
990
def roc_iso_performance_slope(fp_cost, fn_cost, p):
    """
    Return the slope of iso-performance lines in the ROC space.

    :param fp_cost: cost of a false positive
    :param fn_cost: cost of a false negative
    :param p: a value between 0 and 1 (presumably the prior probability
        of the positive class -- confirm against callers)
    :return: `fp_cost * (1 - p) / (fn_cost * p)`, or infinity if
        `fn_cost * p` equals 0
    """
    assert 0 <= p <= 1
    denominator = fn_cost * p
    if denominator == 0:
        # vertical line
        return np.inf
    return (fp_cost * (1. - p)) / denominator
997
998
def _create_results():  # pragma: no cover
    """
    Return a small hand-made `Results` instance with 16 instances of a
    binary problem and probabilities of a single classifier.
    """
    # probabilities of class 1, in descending order
    positive_probs = [
        0.984, 0.907, 0.881, 0.865, 0.815, 0.741, 0.735, 0.635,
        0.582, 0.554, 0.413, 0.317, 0.287, 0.225, 0.216, 0.183,
    ]
    n_rows = len(positive_probs)
    # shape (1, 16, 2): one classifier, 16 instances, two classes
    probs = np.array([[[1 - p, p] for p in positive_probs]])
    # NOTE(review): this has the same shape as `probs`, that is, a 0/1
    # indicator per class; `Results` presumably expects predictions of
    # shape (classifiers, instances) -- confirm
    preds = (probs > 0.5).astype(float)
    actual = np.array([int(label) for label in "1100111001001000"])
    return Results(
        data=Orange.data.Table("heart_disease")[:n_rows],
        row_indices=np.arange(n_rows),
        actual=actual,
        probabilities=probs,
        predicted=preds,
    )
1010
1011
if __name__ == "__main__":  # pragma: no cover
    # Show the widget stand-alone, fed with the results returned by
    # results_for_preview(). To use the small hand-made results from
    # _create_results() instead, run:
    # WidgetPreview(OWROCAnalysis).run(_create_results())
    WidgetPreview(OWROCAnalysis).run(results_for_preview())
1015