from sklearn.base import is_classifier
from .base import _get_response

from .. import average_precision_score
from .. import precision_recall_curve
from .._base import _check_pos_label_consistency
from .._classification import check_consistent_length

from ...utils import check_matplotlib_support, deprecated


class PrecisionRecallDisplay:
    """Precision Recall visualization.

    It is recommended to use
    :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` or
    :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` to create
    a :class:`~sklearn.metrics.PrecisionRecallDisplay`. All parameters are
    stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    precision : ndarray
        Precision values.

    recall : ndarray
        Recall values.

    average_precision : float, default=None
        Average precision. If None, the average precision is not shown.

    estimator_name : str, default=None
        Name of estimator. If None, then the estimator name is not shown.

    pos_label : str or int, default=None
        The class considered as the positive class. If None, the positive
        class is not shown in the axis labels.

        .. versionadded:: 0.24

    Attributes
    ----------
    line_ : matplotlib Artist
        Precision recall curve.

    ax_ : matplotlib Axes
        Axes with precision recall curve.

    figure_ : matplotlib Figure
        Figure containing the curve.

    See Also
    --------
    precision_recall_curve : Compute precision-recall pairs for different
        probability thresholds.
    PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given
        a binary classifier.
    PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve
        using predictions from a binary classifier.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import (precision_recall_curve,
    ...                              PrecisionRecallDisplay)
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
    ...                                                     random_state=0)
    >>> clf = SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> y_score = clf.decision_function(X_test)
    >>> precision, recall, _ = precision_recall_curve(y_test, y_score)
    >>> disp = PrecisionRecallDisplay(precision=precision, recall=recall)
    >>> disp.plot()
    <...>
    >>> plt.show()
    """

    def __init__(
        self,
        precision,
        recall,
        *,
        average_precision=None,
        estimator_name=None,
        pos_label=None,
    ):
        self.estimator_name = estimator_name
        self.precision = precision
        self.recall = recall
        self.average_precision = average_precision
        self.pos_label = pos_label

    def plot(self, ax=None, *, name=None, **kwargs):
        """Plot visualization.

        Extra keyword arguments will be passed to matplotlib's `plot`.

        Parameters
        ----------
        ax : Matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of precision recall curve for labeling. If `None`, use
            `estimator_name` if not `None`; otherwise the legend label only
            contains the average precision (when it is available).

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`. They take
            precedence over the defaults set here (`drawstyle` and `label`).

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
            Object that stores computed values.
        """
        check_matplotlib_support("PrecisionRecallDisplay.plot")

        name = self.estimator_name if name is None else name

        # The curve is drawn as a step function rather than linearly
        # interpolated between operating points.
        line_kwargs = {"drawstyle": "steps-post"}
        if self.average_precision is not None and name is not None:
            line_kwargs["label"] = f"{name} (AP = {self.average_precision:0.2f})"
        elif self.average_precision is not None:
            line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
        elif name is not None:
            line_kwargs["label"] = name
        # User-supplied kwargs override the defaults computed above.
        line_kwargs.update(**kwargs)

        # matplotlib is imported here, only after check_matplotlib_support
        # has run, instead of at module import time.
        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        (self.line_,) = ax.plot(self.recall, self.precision, **line_kwargs)
        info_pos_label = (
            f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
        )

        xlabel = "Recall" + info_pos_label
        ylabel = "Precision" + info_pos_label
        ax.set(xlabel=xlabel, ylabel=ylabel)

        # Only draw a legend when there is something to put in it.
        if "label" in line_kwargs:
            ax.legend(loc="lower left")

        self.ax_ = ax
        self.figure_ = ax.figure
        return self

    @classmethod
    def from_estimator(
        cls,
        estimator,
        X,
        y,
        *,
        sample_weight=None,
        pos_label=None,
        response_method="auto",
        name=None,
        ax=None,
        **kwargs,
    ):
        """Plot precision-recall curve given an estimator and some data.

        Parameters
        ----------
        estimator : estimator instance
            Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
            in which the last estimator is a classifier.

        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input values.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        pos_label : str or int, default=None
            The class considered as the positive class when computing the
            precision and recall metrics. By default, `estimator.classes_[1]`
            is considered as the positive class.

        response_method : {'predict_proba', 'decision_function', 'auto'}, \
            default='auto'
            Specifies whether to use :term:`predict_proba` or
            :term:`decision_function` as the target response. If set to 'auto',
            :term:`predict_proba` is tried first and if it does not exist
            :term:`decision_function` is tried next.

        name : str, default=None
            Name for labeling curve. If `None`, the name of the estimator's
            class is used.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is created.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
            Object that stores computed values.

        Raises
        ------
        ValueError
            If `estimator` is not a classifier.

        See Also
        --------
        PrecisionRecallDisplay.from_predictions : Plot precision-recall curve
            using estimated probabilities or output of decision function.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import PrecisionRecallDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.linear_model import LogisticRegression
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...         X, y, random_state=0)
        >>> clf = LogisticRegression()
        >>> clf.fit(X_train, y_train)
        LogisticRegression()
        >>> PrecisionRecallDisplay.from_estimator(
        ...    clf, X_test, y_test)
        <...>
        >>> plt.show()
        """
        method_name = f"{cls.__name__}.from_estimator"
        check_matplotlib_support(method_name)
        if not is_classifier(estimator):
            raise ValueError(f"{method_name} only supports classifiers")
        # _get_response also returns the positive label it resolved, which is
        # forwarded so the curve and axis labels use the same class.
        y_pred, pos_label = _get_response(
            X,
            estimator,
            response_method,
            pos_label=pos_label,
        )

        name = name if name is not None else estimator.__class__.__name__

        return cls.from_predictions(
            y,
            y_pred,
            sample_weight=sample_weight,
            name=name,
            pos_label=pos_label,
            ax=ax,
            **kwargs,
        )

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        *,
        sample_weight=None,
        pos_label=None,
        name=None,
        ax=None,
        **kwargs,
    ):
        """Plot precision-recall curve given binary class predictions.

        Parameters
        ----------
        y_true : array-like of shape (n_samples,)
            True binary labels.

        y_pred : array-like of shape (n_samples,)
            Estimated probabilities or output of decision function.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        pos_label : str or int, default=None
            The class considered as the positive class when computing the
            precision and recall metrics.

        name : str, default=None
            Name for labeling curve. If `None`, name will be set to
            `"Classifier"`.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is created.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
            Object that stores computed values. When called on a subclass,
            an instance of that subclass is returned.

        See Also
        --------
        PrecisionRecallDisplay.from_estimator : Plot precision-recall curve
            using an estimator.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import PrecisionRecallDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.linear_model import LogisticRegression
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...         X, y, random_state=0)
        >>> clf = LogisticRegression()
        >>> clf.fit(X_train, y_train)
        LogisticRegression()
        >>> y_pred = clf.predict_proba(X_test)[:, 1]
        >>> PrecisionRecallDisplay.from_predictions(
        ...    y_test, y_pred)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")

        # Validate inputs before computing any metric.
        check_consistent_length(y_true, y_pred, sample_weight)
        pos_label = _check_pos_label_consistency(pos_label, y_true)

        precision, recall, _ = precision_recall_curve(
            y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
        )
        average_precision = average_precision_score(
            y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
        )

        name = name if name is not None else "Classifier"

        # Use `cls` (not the hard-coded class name) so that subclasses calling
        # this alternate constructor get an instance of the subclass back.
        viz = cls(
            precision=precision,
            recall=recall,
            average_precision=average_precision,
            estimator_name=name,
            pos_label=pos_label,
        )

        return viz.plot(ax=ax, name=name, **kwargs)


@deprecated(
    "Function `plot_precision_recall_curve` is deprecated in 1.0 and will be "
    "removed in 1.2. Use one of the class methods: "
    "PrecisionRecallDisplay.from_predictions or "
    "PrecisionRecallDisplay.from_estimator."
)
def plot_precision_recall_curve(
    estimator,
    X,
    y,
    *,
    sample_weight=None,
    response_method="auto",
    name=None,
    ax=None,
    pos_label=None,
    **kwargs,
):
    """Plot Precision Recall Curve for binary classifiers.

    Extra keyword arguments will be passed to matplotlib's `plot`.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    .. deprecated:: 1.0
       `plot_precision_recall_curve` is deprecated in 1.0 and will be removed in
       1.2. Use one of the following class methods:
       :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` or
       :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Binary target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    response_method : {'predict_proba', 'decision_function', 'auto'}, \
                      default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name for labeling curve. If `None`, the name of the
        estimator is used.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    pos_label : str or int, default=None
        The class considered as the positive class when computing the precision
        and recall metrics. By default, `estimator.classes_[1]` is considered
        as the positive class.

        .. versionadded:: 0.24

    **kwargs : dict
        Keyword arguments to be passed to matplotlib's `plot`.

    Returns
    -------
    display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
        Object that stores computed values.

    See Also
    --------
    precision_recall_curve : Compute precision-recall pairs for different
        probability thresholds.
    PrecisionRecallDisplay : Precision Recall visualization.
    """
    check_matplotlib_support("plot_precision_recall_curve")

    y_pred, pos_label = _get_response(
        X, estimator, response_method, pos_label=pos_label
    )

    precision, recall, _ = precision_recall_curve(
        y, y_pred, pos_label=pos_label, sample_weight=sample_weight
    )
    average_precision = average_precision_score(
        y, y_pred, pos_label=pos_label, sample_weight=sample_weight
    )

    name = name if name is not None else estimator.__class__.__name__

    viz = PrecisionRecallDisplay(
        precision=precision,
        recall=recall,
        average_precision=average_precision,
        estimator_name=name,
        pos_label=pos_label,
    )

    return viz.plot(ax=ax, name=name, **kwargs)