import operator
from functools import reduce, wraps
from collections import namedtuple, deque, OrderedDict

import numpy as np
import sklearn.metrics as skl_metrics

from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction, \
    QToolTip
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont, \
    QCursor, QFontMetrics
from AnyQt.QtCore import Qt, QSize
import pyqtgraph as pg

import Orange
from Orange.widgets import widget, gui, settings
from Orange.widgets.evaluate.contexthandlers import \
    EvaluationResultsContextHandler
from Orange.widgets.evaluate.utils import check_results_adequacy
from Orange.widgets.utils import colorpalettes
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import Input
from Orange.widgets import report

from Orange.widgets.evaluate.utils import results_for_preview
from Orange.evaluation.testing import Results


#: Points on a ROC curve
ROCPoints = namedtuple(
    "ROCPoints",
    ["fpr",        # (N,) array of false positive rate coordinates (ascending)
     "tpr",        # (N,) array of true positive rate coordinates
     "thresholds"  # (N,) array of thresholds (in descending order)
     ]
)
ROCPoints.is_valid = property(lambda self: self.fpr.size > 0)

#: ROC Curve and its convex hull
ROCCurve = namedtuple(
    "ROCCurve",
    ["points",  # ROCPoints
     "hull"     # ROCPoints of the convex hull
     ]
)
ROCCurve.is_valid = property(lambda self: self.points.is_valid)

#: A ROC Curve averaged vertically
ROCAveragedVert = namedtuple(
    "ROCAveragedVert",
    ["points",   # ROCPoints sampled by fpr
     "hull",     # ROCPoints of the convex hull
     "tpr_std",  # array standard deviation of tpr at each fpr point
     ]
)
ROCAveragedVert.is_valid = property(lambda self: self.points.is_valid)

#: A ROC Curve averaged by thresholds
ROCAveragedThresh = namedtuple(
    "ROCAveragedThresh",
    ["points",   # ROCPoints sampled by threshold
     "hull",     # ROCPoints of the convex hull
     "tpr_std",  # array standard deviations of tpr at each threshold
     "fpr_std"   # array standard deviations of fpr at each threshold
     ]
)
ROCAveragedThresh.is_valid = property(lambda self: self.points.is_valid)

#: Combined data for a ROC curve of a single algorithm
ROCData = namedtuple(
    "ROCData",
    ["merged",         # ROCCurve merged over all folds
     "folds",          # ROCCurve list, one for each fold
     "avg_vertical",   # ROCAveragedVert
     "avg_threshold",  # ROCAveragedThresh
     ]
)


def roc_data_from_results(results, clf_index, target):
    """
    Compute ROC Curve(s) from evaluation results.

    :param Orange.evaluation.Results results:
        Evaluation results.
    :param int clf_index:
        Learner index in the `results`.
    :param int target:
        Target class index (i.e. positive class).
    :rtype: ROCData
    :return:
        An instance holding the merged curve, the per-fold curves and the
        vertically / threshold averaged curves. The averaged curves are
        empty (invalid) when no fold yields a valid curve.
    """
    # `...` (Ellipsis) indexes all instances, i.e. all folds merged.
    merged = roc_curve_for_fold(results, ..., clf_index, target)
    merged_curve = ROCCurve(ROCPoints(*merged),
                            ROCPoints(*roc_curve_convex_hull(merged)))

    # Without fold information everything is treated as a single fold.
    folds = results.folds if results.folds is not None else [...]
    fold_curves = []
    for fold in folds:
        points = roc_curve_for_fold(results, fold, clf_index, target)
        hull = roc_curve_convex_hull(points)
        c = ROCCurve(ROCPoints(*points), ROCPoints(*hull))
        fold_curves.append(c)

    # Only folds with a defined curve take part in the averaging.
    curves = [fold.points for fold in fold_curves
              if fold.is_valid]

    if curves:
        fpr, tpr, std = roc_curve_vertical_average(curves)

        # Vertically averaged points have no associated threshold.
        thresh = np.zeros_like(fpr) * np.nan
        hull = roc_curve_convex_hull((fpr, tpr, thresh))
        v_avg = ROCAveragedVert(
            ROCPoints(fpr, tpr, thresh),
            ROCPoints(*hull),
            std
        )
    else:
        # return an invalid vertical averaged ROC
        v_avg = ROCAveragedVert(
            ROCPoints(np.array([]), np.array([]), np.array([])),
            ROCPoints(np.array([]), np.array([]), np.array([])),
            np.array([])
        )

    if curves:
        all_thresh = np.hstack([t for _, _, t in curves])
        all_thresh = np.clip(all_thresh, 0.0 - 1e-10, 1.0 + 1e-10)
        all_thresh = np.unique(all_thresh)[::-1]
        # Subsample the (descending) thresholds to roughly ten points.
        thresh = all_thresh[::max(all_thresh.size // 10, 1)]

        (fpr, fpr_std), (tpr, tpr_std) = \
            roc_curve_threshold_average(curves, thresh)

        hull = roc_curve_convex_hull((fpr, tpr, thresh))

        t_avg = ROCAveragedThresh(
            ROCPoints(fpr, tpr, thresh),
            ROCPoints(*hull),
            tpr_std,
            fpr_std
        )
    else:
        # return an invalid threshold averaged ROC
        t_avg = ROCAveragedThresh(
            ROCPoints(np.array([]), np.array([]), np.array([])),
            ROCPoints(np.array([]), np.array([]), np.array([])),
            np.array([]),
            np.array([])
        )
    return ROCData(merged_curve, fold_curves, v_avg, t_avg)


ROCData.from_results = staticmethod(roc_data_from_results)

#: A curve item to be displayed in a plot
PlotCurve = namedtuple(
    "PlotCurve",
    ["curve",       # ROCCurve source curve
     "curve_item",  # pg.PlotDataItem main curve
     "hull_item"    # pg.PlotDataItem curve's convex hull
     ]
)


def plot_curve(curve, pen=None, shadow_pen=None, symbol="+",
               symbol_size=3, name=None):
    """
    Construct a `PlotCurve` for the given `ROCCurve`.

    :param ROCCurve curve:
        Source curve.

    The other parameters are passed to pg.PlotDataItem

    :rtype: PlotCurve
    """
    def extend_to_origin(points):
        "Extend ROCPoints to include coordinate origin if not already present"
        if points.tpr.size and (points.tpr[0] > 0 or points.fpr[0] > 0):
            points = ROCPoints(
                np.r_[0, points.fpr], np.r_[0, points.tpr],
                np.r_[points.thresholds[0] + 1, points.thresholds]
            )
        return points

    points = extend_to_origin(curve.points)
    item = pg.PlotCurveItem(
        points.fpr, points.tpr, pen=pen, shadowPen=shadow_pen,
        name=name, antialias=True
    )
    # The scatter item marks the original (non-extended) points; it is
    # parented to the curve item so it is added/removed together with it.
    sp = pg.ScatterPlotItem(
        curve.points.fpr, curve.points.tpr, symbol=symbol,
        size=symbol_size, pen=shadow_pen,
        name=name
    )
    sp.setParentItem(item)

    hull = extend_to_origin(curve.hull)

    hull_item = pg.PlotDataItem(
        hull.fpr, hull.tpr, pen=pen, antialias=True
    )
    return PlotCurve(curve, item, hull_item)


PlotCurve.from_roc_curve = staticmethod(plot_curve)

#: A curve displayed in a plot with error bars
PlotAvgCurve = namedtuple(
    "PlotAvgCurve",
    ["curve",         # ROCCurve
     "curve_item",    # pg.PlotDataItem
     "hull_item",     # pg.PlotDataItem
     "confint_item",  # pg.ErrorBarItem
     ]
)


def plot_avg_curve(curve, pen=None, shadow_pen=None, symbol="+",
                   symbol_size=4, name=None):
    """
    Construct a `PlotAvgCurve` for the given `curve`.

    :param curve: Source curve.
    :type curve: ROCAveragedVert or ROCAveragedThresh

    The other parameters are passed to pg.PlotDataItem

    :rtype: PlotAvgCurve
    :raises TypeError: if `curve` is neither of the supported types.
    """
    pc = plot_curve(curve, pen=pen, shadow_pen=shadow_pen, symbol=symbol,
                    symbol_size=symbol_size, name=name)

    points = curve.points
    # Error bars are drawn for the interior points only ([1:-1]).
    if isinstance(curve, ROCAveragedVert):
        tpr_std = curve.tpr_std
        error_item = pg.ErrorBarItem(
            x=points.fpr[1:-1], y=points.tpr[1:-1],
            height=2 * tpr_std[1:-1],
            pen=pen, beam=0.025,
            antialias=True,
        )
    elif isinstance(curve, ROCAveragedThresh):
        tpr_std, fpr_std = curve.tpr_std, curve.fpr_std
        error_item = pg.ErrorBarItem(
            x=points.fpr[1:-1], y=points.tpr[1:-1],
            height=2 * tpr_std[1:-1], width=2 * fpr_std[1:-1],
            pen=pen, beam=0.025,
            antialias=True,
        )
    else:
        # Previously fell through with `error_item` unbound.
        raise TypeError(
            "curve must be a ROCAveragedVert or ROCAveragedThresh")
    return PlotAvgCurve(curve, pc.curve_item, pc.hull_item, error_item)


PlotAvgCurve.from_roc_curve = staticmethod(plot_avg_curve)

#: Wrapper marking a computed value (lets `once` cache a `None` result)
Some = namedtuple("Some", ["val"])


def once(f):
    """
    Return a function that will be called only once, and its result cached.

    :param f: A callable taking no arguments.
    :return: A zero-argument wrapper that calls `f` on first use and
        returns the stored result on every later call.
    """
    cached = None

    @wraps(f)
    def wraped():
        nonlocal cached
        if cached is None:
            cached = Some(f())
        return cached.val
    return wraped


PlotCurves = namedtuple(
    "PlotCurves",
    ["merge",          # :: () -> PlotCurve
     "folds",          # :: () -> [PlotCurve]
     "avg_vertical",   # :: () -> PlotAvgCurve
     "avg_threshold",  # :: () -> PlotAvgCurve
     ]
)


class InfiniteLine(pg.InfiniteLine):
    """pyqtgraph.InfiniteLine extended to support antialiasing.
    """
    def __init__(self, pos=None, angle=90, pen=None, movable=False,
                 bounds=None, antialias=False):
        super().__init__(pos, angle, pen, movable, bounds)
        self.antialias = antialias

    def paint(self, p, *args):
        if self.antialias:
            p.setRenderHint(QPainter.Antialiasing, True)
        super().paint(p, *args)


class OWROCAnalysis(widget.OWWidget):
    """Widget plotting ROC curves computed from evaluation results."""
    name = "ROC Analysis"
    description = "Display the Receiver Operating Characteristics curve " \
                  "based on the evaluation of classifiers."
    icon = "icons/ROCAnalysis.svg"
    priority = 1010
    keywords = []

    class Inputs:
        evaluation_results = Input("Evaluation Results", Orange.evaluation.Results)

    buttons_area_orientation = None
    settingsHandler = EvaluationResultsContextHandler()
    target_index = settings.ContextSetting(0)
    selected_classifiers = settings.ContextSetting([])

    display_perf_line = settings.Setting(True)
    display_def_threshold = settings.Setting(True)

    fp_cost = settings.Setting(500)
    fn_cost = settings.Setting(500)
    target_prior = settings.Setting(50.0, schema_only=True)

    #: ROC Averaging Types
    Merge, Vertical, Threshold, NoAveraging = 0, 1, 2, 3
    roc_averaging = settings.Setting(Merge)

    display_convex_hull = settings.Setting(False)
    display_convex_curve = settings.Setting(False)

    graph_name = "plot"

    def __init__(self):
        super().__init__()

        self.results = None
        self.classifier_names = []
        self.perf_line = None
        self.colors = []
        # (target, clf_idx) -> ROCData
        self._curve_data = {}
        # (target, clf_idx) -> PlotCurves
        self._plot_curves = {}
        # ROCPoints of the convex hull over the merged curves
        self._rocch = None
        # InfiniteLine item showing the iso-performance line
        self._perf_line = None
        # (point, thresholds, classifier indices, averaging mode)
        self._tooltip_cache = None

        box = gui.vBox(self.controlArea, "Plot")
        self.target_cb = gui.comboBox(
            box, self, "target_index",
            label="Target", orientation=Qt.Horizontal,
            callback=self._on_target_changed,
            contentsLength=8, searchable=True)

        gui.widgetLabel(box, "Classifiers")
        line_height = 4 * QFontMetrics(self.font()).lineSpacing()
        self.classifiers_list_box = gui.listBox(
            box, self, "selected_classifiers", "classifier_names",
            selectionMode=QListView.MultiSelection,
            callback=self._on_classifiers_changed,
            sizeHint=QSize(0, line_height))

        abox = gui.vBox(self.controlArea, "Curves")
        gui.comboBox(abox, self, "roc_averaging",
                     items=["Merge Predictions from Folds", "Mean TP Rate",
                            "Mean TP and FP at Threshold", "Show Individual Curves"],
                     callback=self._replot)

        gui.checkBox(abox, self, "display_convex_curve",
                     "Show convex ROC curves", callback=self._replot)
        gui.checkBox(abox, self, "display_convex_hull",
                     "Show ROC convex hull", callback=self._replot)

        box = gui.vBox(self.controlArea, "Analysis")

        gui.checkBox(box, self, "display_def_threshold",
                     "Default threshold (0.5) point",
                     callback=self._on_display_def_threshold_changed)

        gui.checkBox(box, self, "display_perf_line", "Show performance line",
                     callback=self._on_display_perf_line_changed)
        grid = QGridLayout()
        gui.indentedBox(box, orientation=grid)

        sp = gui.spin(box, self, "fp_cost", 1, 1000, 10,
                      alignment=Qt.AlignRight,
                      callback=self._on_display_perf_line_changed)
        grid.addWidget(QLabel("FP Cost:"), 0, 0)
        grid.addWidget(sp, 0, 1)

        sp = gui.spin(box, self, "fn_cost", 1, 1000, 10,
                      alignment=Qt.AlignRight,
                      callback=self._on_display_perf_line_changed)
        grid.addWidget(QLabel("FN Cost:"))
        grid.addWidget(sp, 1, 1)
        self.target_prior_sp = gui.spin(box, self, "target_prior", 1, 99,
                                        alignment=Qt.AlignRight,
                                        callback=self._on_target_prior_changed)
        self.target_prior_sp.setSuffix(" %")
        self.target_prior_sp.addAction(QAction("Auto", sp))
        grid.addWidget(QLabel("Prior probability:"))
        grid.addWidget(self.target_prior_sp, 2, 1)

        self.plotview = pg.GraphicsView(background="w")
        self.plotview.setFrameStyle(QFrame.StyledPanel)
        self.plotview.scene().sigMouseMoved.connect(self._on_mouse_moved)

        self.plot = pg.PlotItem(enableMenu=False)
        self.plot.setMouseEnabled(False, False)
        self.plot.hideButtons()

        pen = QPen(self.palette().color(QPalette.Text))

        tickfont = QFont(self.font())
        tickfont.setPixelSize(max(int(tickfont.pixelSize() * 2 // 3), 11))

        axis = self.plot.getAxis("bottom")
        axis.setTickFont(tickfont)
        axis.setPen(pen)
        axis.setLabel("FP Rate (1-Specificity)")
        axis.setGrid(16)

        axis = self.plot.getAxis("left")
        axis.setTickFont(tickfont)
        axis.setPen(pen)
        axis.setLabel("TP Rate (Sensitivity)")
        axis.setGrid(16)

        self.plot.showGrid(True, True, alpha=0.1)
        self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0), padding=0.05)

        self.plotview.setCentralItem(self.plot)
        self.mainArea.layout().addWidget(self.plotview)

    @Inputs.evaluation_results
    def set_results(self, results):
        """Set the input evaluation results."""
        self.closeContext()
        self.clear()
        self.results = check_results_adequacy(results, self.Error)
        if self.results is not None:
            self._initialize(self.results)
            self.openContext(self.results.domain.class_var,
                             self.classifier_names)
            self._setup_plot()
        else:
            self.warning()

    def clear(self):
        """Clear the widget state."""
        self.results = None
        self.plot.clear()
        self.classifier_names = []
        self.selected_classifiers = []
        self.target_cb.clear()
        self.colors = []
        self._curve_data = {}
        self._plot_curves = {}
        self._rocch = None
        self._perf_line = None
        self._tooltip_cache = None

    def _initialize(self, results):
        """Populate classifier list, colors and target combo from `results`."""
        names = getattr(results, "learner_names", None)

        if names is None:
            names = ["#{}".format(i + 1)
                     for i in range(len(results.predicted))]

        self.colors = colorpalettes.get_default_curve_colors(len(names))

        self.classifier_names = names
        self.selected_classifiers = list(range(len(names)))
        for i in range(len(names)):
            listitem = self.classifiers_list_box.item(i)
            listitem.setIcon(colorpalettes.ColorIcon(self.colors[i]))

        class_var = results.data.domain.class_var
        self.target_cb.addItems(class_var.values)
        self.target_index = 0
        self._set_target_prior()

    def _set_target_prior(self):
        """
        This function sets the initial target class probability prior value
        based on the input data.
        """
        if self.results.data:
            # here we can use target_index directly since values in the
            # dropdown are sorted in same order than values in the table
            target_values_cnt = np.count_nonzero(
                self.results.data.Y == self.target_index)
            count_all = np.count_nonzero(~np.isnan(self.results.data.Y))
            # Skip when every class value is missing: the ratio would be
            # NaN and later trip the range check on the prior.
            if count_all:
                self.target_prior = np.round(
                    target_values_cnt / count_all * 100)

                # set the spin text to gray color when set automatically
                self.target_prior_sp.setStyleSheet("color: gray;")

    def curve_data(self, target, clf_idx):
        """Return `ROCData' for the given target and classifier."""
        if (target, clf_idx) not in self._curve_data:
            # pylint: disable=no-member
            data = ROCData.from_results(self.results, clf_idx, target)
            self._curve_data[target, clf_idx] = data

        return self._curve_data[target, clf_idx]

    def plot_curves(self, target, clf_idx):
        """Return a set of functions `plot_curves` generating plot curves."""
        def generate_pens(basecolor):
            pen = QPen(basecolor, 1)
            pen.setCosmetic(True)

            shadow_pen = QPen(pen.color().lighter(160), 2.5)
            shadow_pen.setCosmetic(True)
            return pen, shadow_pen

        data = self.curve_data(target, clf_idx)

        if (target, clf_idx) not in self._plot_curves:
            pen, shadow_pen = generate_pens(self.colors[clf_idx])
            name = self.classifier_names[clf_idx]

            # Each generator builds its graphics items lazily, on first use.
            @once
            def merged():
                return plot_curve(
                    data.merged, pen=pen, shadow_pen=shadow_pen, name=name)

            @once
            def folds():
                return [plot_curve(fold, pen=pen, shadow_pen=shadow_pen)
                        for fold in data.folds]

            @once
            def avg_vert():
                return plot_avg_curve(data.avg_vertical, pen=pen,
                                      shadow_pen=shadow_pen, name=name)

            @once
            def avg_thres():
                return plot_avg_curve(data.avg_threshold, pen=pen,
                                      shadow_pen=shadow_pen, name=name)

            self._plot_curves[target, clf_idx] = PlotCurves(
                merge=merged, folds=folds,
                avg_vertical=avg_vert, avg_threshold=avg_thres
            )

        return self._plot_curves[target, clf_idx]

    def _setup_plot(self):
        """Add the items for the current target/selection/averaging to the plot.

        Each nested helper adds the curve items for one averaging mode and
        returns the list of convex hulls (ROCPoints) of the shown curves.
        """
        def merge_averaging():
            for curve in curves:
                graphics = curve.merge()
                curve = graphics.curve
                self.plot.addItem(graphics.curve_item)

                if self.display_convex_curve:
                    self.plot.addItem(graphics.hull_item)

                if self.display_def_threshold and curve.is_valid:
                    points = curve.points
                    # Label the point whose threshold is closest to 0.5.
                    ind = np.argmin(np.abs(points.thresholds - 0.5))
                    item = pg.TextItem(
                        text="{:.3f}".format(points.thresholds[ind]),
                    )
                    item.setPos(points.fpr[ind], points.tpr[ind])
                    self.plot.addItem(item)

            hull_curves = [curve.merged.hull for curve in selected]
            if hull_curves:
                self._rocch = convex_hull(hull_curves)
                iso_pen = QPen(QColor(Qt.black), 1)
                iso_pen.setCosmetic(True)
                self._perf_line = InfiniteLine(pen=iso_pen, antialias=True)
                self.plot.addItem(self._perf_line)
            return hull_curves

        def vertical_averaging():
            for curve in curves:
                graphics = curve.avg_vertical()

                self.plot.addItem(graphics.curve_item)
                self.plot.addItem(graphics.confint_item)
            return [curve.avg_vertical.hull for curve in selected]

        def threshold_averaging():
            for curve in curves:
                graphics = curve.avg_threshold()
                self.plot.addItem(graphics.curve_item)
                self.plot.addItem(graphics.confint_item)
            return [curve.avg_threshold.hull for curve in selected]

        def no_averaging():
            for curve in curves:
                graphics = curve.folds()
                for fold in graphics:
                    self.plot.addItem(fold.curve_item)
                    if self.display_convex_curve:
                        self.plot.addItem(fold.hull_item)
            return [fold.hull for curve in selected for fold in curve.folds]

        averagings = {
            OWROCAnalysis.Merge: merge_averaging,
            OWROCAnalysis.Vertical: vertical_averaging,
            OWROCAnalysis.Threshold: threshold_averaging,
            OWROCAnalysis.NoAveraging: no_averaging
        }

        target = self.target_index
        selected = self.selected_classifiers

        # `curves` holds the lazy plot-item generators, `selected` is
        # rebound to the matching ROCData; the helpers above read both.
        curves = [self.plot_curves(target, i) for i in selected]
        selected = [self.curve_data(target, i) for i in selected]
        hull_curves = averagings[self.roc_averaging]()

        if self.display_convex_hull and hull_curves:
            hull = convex_hull(hull_curves)
            hull_pen = QPen(QColor(200, 200, 200, 100), 2)
            hull_pen.setCosmetic(True)
            item = self.plot.plot(
                hull.fpr, hull.tpr,
                pen=hull_pen,
                brush=QBrush(QColor(200, 200, 200, 50)),
                fillLevel=0)
            item.setZValue(-10000)

        # Diagonal reference line.
        pen = QPen(QColor(100, 100, 100, 100), 1, Qt.DashLine)
        pen.setCosmetic(True)
        self.plot.plot([0, 1], [0, 1], pen=pen, antialias=True)

        if self.roc_averaging == OWROCAnalysis.Merge:
            self._update_perf_line()

        self._update_axes_ticks()

        warning = ""
        if not all(c.is_valid for c in hull_curves):
            if any(c.is_valid for c in hull_curves):
                warning = "Some ROC curves are undefined"
            else:
                warning = "All ROC curves are undefined"
        self.warning(warning)

    def _update_axes_ticks(self):
        """Set axis ticks from the merged curve of the first selected classifier.

        With no classifier selected the ticks are reset (``None``) instead
        of indexing into the empty selection.
        """
        def enumticks(a):
            a = np.unique(a)
            if len(a) > 15:
                return None
            return [[(x, f"{x:.2f}") for x in a[::-1]]]

        axis_bottom = self.plot.getAxis("bottom")
        axis_left = self.plot.getAxis("left")

        if not self.selected_classifiers:
            # Nothing is plotted; `selected_classifiers[0]` would raise.
            axis_bottom.setTicks(None)
            axis_left.setTicks(None)
            return

        data = self.curve_data(self.target_index, self.selected_classifiers[0])
        points = data.merged.points

        axis_bottom.setTicks(enumticks(points.fpr))
        axis_left.setTicks(enumticks(points.tpr))

    def _on_mouse_moved(self, pos):
        """Show a tooltip with the thresholds of the curve points under `pos`."""
        target = self.target_index
        selected = self.selected_classifiers
        curves = [(clf_idx, self.plot_curves(target, clf_idx))
                  for clf_idx in selected]  # type: List[Tuple[int, PlotCurves]]
        valid_thresh, valid_clf = [], []
        pt, ave_mode = None, self.roc_averaging

        for clf_idx, crv in curves:
            if self.roc_averaging == OWROCAnalysis.Merge:
                curve = crv.merge()
            elif self.roc_averaging == OWROCAnalysis.Vertical:
                curve = crv.avg_vertical()
            elif self.roc_averaging == OWROCAnalysis.Threshold:
                curve = crv.avg_threshold()
            else:
                # currently not implemented for 'Show Individual Curves'
                return

            sp = curve.curve_item.childItems()[0]  # type: pg.ScatterPlotItem
            act_pos = sp.mapFromScene(pos)
            pts = sp.pointsAt(act_pos)

            if pts:
                mouse_pt = pts[0].pos()
                if self._tooltip_cache:
                    cache_pt, cache_thresh, cache_clf, cache_ave = self._tooltip_cache
                    curr_thresh, curr_clf = [], []
                    if np.linalg.norm(mouse_pt - cache_pt) < 10e-6 \
                            and cache_ave == self.roc_averaging:
                        mask = np.equal(cache_clf, clf_idx)
                        curr_thresh = np.compress(mask, cache_thresh).tolist()
                        curr_clf = np.compress(mask, cache_clf).tolist()
                    else:
                        QToolTip.showText(QCursor.pos(), "")
                        self._tooltip_cache = None

                    if curr_thresh:
                        # `extend` rather than `append(*...)`, which raises
                        # unless the list has exactly one element.
                        valid_thresh.extend(curr_thresh)
                        valid_clf.extend(curr_clf)
                        pt = cache_pt
                        continue

                curve_pts = curve.curve.points
                roc_points = np.column_stack((curve_pts.fpr, curve_pts.tpr))
                diff = np.subtract(roc_points, mouse_pt)
                # Find closest point on curve and save the corresponding threshold
                idx_closest = np.argmin(np.linalg.norm(diff, axis=1))

                thresh = curve_pts.thresholds[idx_closest]
                if not np.isnan(thresh):
                    valid_thresh.append(thresh)
                    valid_clf.append(clf_idx)
                    pt = [curve_pts.fpr[idx_closest], curve_pts.tpr[idx_closest]]

        if valid_thresh:
            clf_names = self.classifier_names
            msg = "Thresholds:\n" + "\n".join(["({:s}) {:.3f}".format(clf_names[i], thresh)
                                               for i, thresh in zip(valid_clf, valid_thresh)])
            QToolTip.showText(QCursor.pos(), msg)
            self._tooltip_cache = (pt, valid_thresh, valid_clf, ave_mode)

    def _on_target_changed(self):
        self.plot.clear()
        self._set_target_prior()
        self._setup_plot()

    def _on_classifiers_changed(self):
        self.plot.clear()
        if self.results is not None:
            self._setup_plot()

    def _on_target_prior_changed(self):
        # Black text marks a prior entered by the user (gray = automatic).
        self.target_prior_sp.setStyleSheet("color: black;")
        self._on_display_perf_line_changed()

    def _on_display_perf_line_changed(self):
        if self.roc_averaging == OWROCAnalysis.Merge:
            self._update_perf_line()

        # NOTE(review): `perf_line` is only ever initialised to None in this
        # module, so this branch does nothing unless something external
        # assigns it; the drawn line is `_perf_line`, whose visibility is
        # handled by `_update_perf_line`.
        if self.perf_line is not None:
            self.perf_line.setVisible(self.display_perf_line)

    def _on_display_def_threshold_changed(self):
        self._replot()

    def _replot(self):
        self.plot.clear()
        if self.results is not None:
            self._setup_plot()

    def _update_perf_line(self):
        """Update visibility, angle and position of the iso-performance line."""
        if self._perf_line is None:
            return

        self._perf_line.setVisible(self.display_perf_line)
        if self.display_perf_line:
            m = roc_iso_performance_slope(
                self.fp_cost, self.fn_cost, self.target_prior / 100.0)

            hull = self._rocch
            if hull.is_valid:
                ind = roc_iso_performance_line(m, hull)
                angle = np.arctan2(m, 1)  # in radians
                self._perf_line.setAngle(angle * 180 / np.pi)
                self._perf_line.setPos((hull.fpr[ind[0]], hull.tpr[ind[0]]))
            else:
                self._perf_line.setVisible(False)

    def onDeleteWidget(self):
        self.clear()

    def send_report(self):
        if self.results is None:
            return
        items = OrderedDict()
        items["Target class"] = self.target_cb.currentText()
        if self.display_perf_line:
            items["Costs"] = \
                "FP = {}, FN = {}".format(self.fp_cost, self.fn_cost)
            items["Target probability"] = "{} %".format(self.target_prior)
        caption = report.list_legend(self.classifiers_list_box,
                                     self.selected_classifiers)
        self.report_items(items)
        self.report_plot()
        self.report_caption(caption)


def interp(x, xp, fp, left=None, right=None):
    """
    Like numpy.interp except for handling of running sequences of
    same values in `xp`.

    :param x: Points at which to evaluate.
    :param xp: Sample x coordinates (searched with `np.searchsorted`).
    :param fp: Sample values, same shape as `xp`.
    :param left: Value for `x` below `xp[0]` (default `fp[0]`).
    :param right: Value for `x` at or above `xp[-1]` (default `fp[-1]`);
        when given, points exactly equal to `xp[-1]` still get `fp[-1]`.
    :raises ValueError: if `xp` and `fp` differ in shape.
    """
    x = np.asanyarray(x)
    xp = np.asanyarray(xp)
    fp = np.asanyarray(fp)

    if xp.shape != fp.shape:
        raise ValueError("xp and fp must have the same shape")

    ind = np.searchsorted(xp, x, side="right")
    fx = np.zeros(len(x))

    under = ind == 0
    over = ind == len(xp)
    between = ~under & ~over

    fx[under] = left if left is not None else fp[0]
    fx[over] = right if right is not None else fp[-1]

    if right is not None:
        # Fix points exactly on the right boundary.
        fx[x == xp[-1]] = fp[-1]

    ind = ind[between]

    df = (fp[ind] - fp[ind - 1]) / (xp[ind] - xp[ind - 1])

    fx[between] = df * (x[between] - xp[ind]) + fp[ind]

    return fx


def roc_curve_for_fold(res, fold, clf_idx, target):
    """
    Compute ROC points for the instances selected by `fold`.

    :param res: Evaluation results (uses `res.actual`, `res.probabilities`).
    :param fold: Index into the instances (a fold slice, or `...` for all).
    :param int clf_idx: Learner index.
    :param int target: Positive class index.
    :return: `(fpr, tpr, thresholds)` arrays; three empty arrays when the
        selection contains no positive or no negative instances.
    """
    fold_actual = res.actual[fold]
    P = np.sum(fold_actual == target)
    N = fold_actual.size - P

    if P == 0 or N == 0:
        # Undefined TP and FP rate
        return np.array([]), np.array([]), np.array([])

    fold_probs = res.probabilities[clf_idx][fold][:, target]
    # Keep every point for small folds.
    drop_intermediate = len(fold_actual) > 20
    fpr, tpr, thresholds = skl_metrics.roc_curve(
        fold_actual, fold_probs, pos_label=target,
        drop_intermediate=drop_intermediate
    )

    # skl sets the first threshold to the highest threshold in the data + 1
    # since we deal with probabilities, we (carefully) set it to 1
    # Unrelated comparisons, thus pylint: disable=chained-comparison
    if len(thresholds) > 1 and thresholds[1] <= 1:
        thresholds[0] = 1
    return fpr, tpr, thresholds


def roc_curve_vertical_average(curves, samples=10):
    """
    Average ROC curves vertically (TPR at fixed FPR samples).

    :param curves: Non-empty sequence of `(fpr, tpr, thresholds)` triples.
    :param int samples: Number of evenly spaced FPR samples on [0, 1].
    :return: `(fpr_sample, tpr_mean, tpr_std)` arrays.
    :raises ValueError: if `curves` is empty.
    """
    if not curves:
        raise ValueError("No curves")
    fpr_sample = np.linspace(0.0, 1.0, samples)
    tpr_samples = []
    for fpr, tpr, _ in curves:
        tpr_samples.append(interp(fpr_sample, fpr, tpr, left=0, right=1))

    tpr_samples = np.array(tpr_samples)
    return fpr_sample, tpr_samples.mean(axis=0), tpr_samples.std(axis=0)


def roc_curve_threshold_average(curves, thresh_samples):
    """
    Average ROC curves at the given threshold samples.

    For every curve one point per sampled threshold is looked up (via
    `np.searchsorted` over the curve's thresholds); FPR and TPR of those
    points are then averaged across curves.

    :param curves: Non-empty sequence of `(fpr, tpr, thresholds)` triples.
    :param thresh_samples: Array of thresholds to sample at.
    :return: `((fpr_mean, fpr_std), (tpr_mean, tpr_std))`.
    :raises ValueError: if `curves` is empty.
    """
    if not curves:
        raise ValueError("No curves")
    fpr_samples, tpr_samples = [], []
    for fpr, tpr, thresh in curves:
        ind = np.searchsorted(thresh[::-1], thresh_samples, side="left")
        ind = ind[::-1]
        ind = np.clip(ind, 0, len(thresh) - 1)
        fpr_samples.append(fpr[ind])
        tpr_samples.append(tpr[ind])

    fpr_samples = np.array(fpr_samples)
    tpr_samples = np.array(tpr_samples)

    # The TPR deviation must come from the TPR samples (it was previously
    # computed from `fpr_samples`).
    return ((fpr_samples.mean(axis=0), fpr_samples.std(axis=0)),
            (tpr_samples.mean(axis=0), tpr_samples.std(axis=0)))


def roc_curve_thresh_avg_interp(curves, thresh_samples):
    """
    Average ROC curves at threshold samples using interpolation.

    Same return shape as `roc_curve_threshold_average`, but FPR and TPR at
    each sampled threshold are interpolated with `interp`.

    :param curves: Sequence of `(fpr, tpr, thresholds)` triples.
    :param thresh_samples: Array of thresholds to sample at.
    :return: `((fpr_mean, fpr_std), (tpr_mean, tpr_std))`.
    """
    fpr_samples, tpr_samples = [], []
    for fpr, tpr, thresh in curves:
        thresh = thresh[::-1]
        fpr = interp(thresh_samples, thresh, fpr[::-1], left=1.0, right=0.0)
        tpr = interp(thresh_samples, thresh, tpr[::-1], left=1.0, right=0.0)
        fpr_samples.append(fpr)
        tpr_samples.append(tpr)

    fpr_samples = np.array(fpr_samples)
    tpr_samples = np.array(tpr_samples)

    # TPR deviation from the TPR samples (was `fpr_samples`).
    return ((fpr_samples.mean(axis=0), fpr_samples.std(axis=0)),
            (tpr_samples.mean(axis=0), tpr_samples.std(axis=0)))


RocPoint = namedtuple("RocPoint", ["fpr", "tpr", "threshold"])


def roc_curve_convex_hull(curve):
    """
    Return the convex hull of a single ROC curve.

    :param curve: `(fpr, tpr, thresholds)` triple of equal-length arrays.
    :return: `(fpr, tpr, thresholds)` arrays of the hull points; `curve`
        itself is returned unchanged when it has two points or fewer.
    """
    def slope(p1, p2):
        x1, y1, _ = p1
        x2, y2, _ = p2
        if x1 != x2:
            return (y2 - y1) / (x2 - x1)
        else:
            return np.inf

    fpr, _, _ = curve

    if len(fpr) <= 2:
        return curve
    points = map(RocPoint._make, zip(*curve))

    hull = deque([next(points)])

    for point in points:
        # Pop hull points until adding `point` keeps the slopes decreasing.
        while True:
            if len(hull) < 2:
                hull.append(point)
                break
            else:
                last = hull[-1]
                if point.fpr != last.fpr and \
                        slope(hull[-2], last) > slope(last, point):
                    hull.append(point)
                    break
                else:
                    hull.pop()

    fpr = np.array([p.fpr for p in hull])
    tpr = np.array([p.tpr for p in hull])
    thres = np.array([p.threshold for p in hull])
    return (fpr, tpr, thres)


def convex_hull(curves):
    """
    Return the convex hull over the points of several ROC curves.

    :param curves: Sequence of `(fpr, tpr, thresholds)` triples.
    :rtype: ROCPoints
    :return: Hull points; empty arrays when `curves` contains no points.
    """
    def slope(p1, p2):
        x1, y1, *_ = p1
        x2, y2, *_ = p2
        if x1 != x2:
            return (y2 - y1) / (x2 - x1)
        else:
            return np.inf

    curves = [list(map(RocPoint._make, zip(*curve))) for curve in curves]

    merged_points = reduce(operator.iadd, curves, [])
    merged_points = sorted(merged_points)

    if not merged_points:
        return ROCPoints(np.array([]), np.array([]), np.array([]))

    if len(merged_points) <= 2:
        return ROCPoints._make(map(np.array, zip(*merged_points)))

    points = iter(merged_points)

    hull = deque([next(points)])

    for point in points:
        # Same scan as in `roc_curve_convex_hull`, over the merged points.
        while True:
            if len(hull) < 2:
                hull.append(point)
                break
            else:
                last = hull[-1]
                if point[0] != last[0] and \
                        slope(hull[-2], last) > slope(last, point):
                    hull.append(point)
                    break
                else:
                    hull.pop()

    return ROCPoints._make(map(np.array, zip(*hull)))


def roc_iso_performance_line(slope, hull, tol=1e-5):
    """
    Return the indices where a line with `slope` touches the ROC convex hull.

    :param float slope: Slope of the iso-performance line.
    :param hull: Hull points; the first two items are `fpr` and `tpr`.
    :param float tol: Points within `tol` of the minimum distance count.
    """
    fpr, tpr, *_ = hull

    # Compute the distance of each point to a reference iso line
    # going through point (0, 1). The point(s) with the minimum
    # distance are our result

    # y = m * x + 1
    # m * x - 1y + 1 = 0
    a, b, c = slope, -1, 1
    dist = distance_to_line(a, b, c, fpr, tpr)
    mindist = np.min(dist)

    return np.flatnonzero((dist - mindist) <= tol)


def distance_to_line(a, b, c, x0, y0):
    """
    Return the distance to a line ax + by + c = 0
    """
    assert a != 0 or b != 0
    return np.abs(a * x0 + b * y0 + c) / np.sqrt(a ** 2 + b ** 2)


def roc_iso_performance_slope(fp_cost, fn_cost, p):
    """
    Return the slope of the iso-performance line.

    :param fp_cost: Cost of a false positive.
    :param fn_cost: Cost of a false negative.
    :param float p: Prior probability of the target class, in [0, 1].
    :return: `fp_cost * (1 - p) / (fn_cost * p)`, or `np.inf` when
        `fn_cost * p` is zero.
    """
    assert 0 <= p <= 1
    if fn_cost * p == 0:
        return np.inf
    else:
        return (fp_cost * (1. - p)) / (fn_cost * p)


def _create_results():  # pragma: no cover
    """Build a small hand-made `Results` instance for previewing the widget."""
    probs1 = [0.984, 0.907, 0.881, 0.865, 0.815, 0.741, 0.735, 0.635,
              0.582, 0.554, 0.413, 0.317, 0.287, 0.225, 0.216, 0.183]
    probs = np.array([[[1 - x, x] for x in probs1]])
    preds = (probs > 0.5).astype(float)
    return Results(
        data=Orange.data.Table("heart_disease")[:16],
        row_indices=np.arange(16),
        actual=np.array(list(map(int, "1100111001001000"))),
        probabilities=probs, predicted=preds
    )


if __name__ == "__main__":  # pragma: no cover
    # WidgetPreview(OWROCAnalysis).run(_create_results())
    WidgetPreview(OWROCAnalysis).run(results_for_preview())