1from sklearn.base import is_classifier
2from .base import _get_response
3
4from .. import average_precision_score
5from .. import precision_recall_curve
6from .._base import _check_pos_label_consistency
7from .._classification import check_consistent_length
8
9from ...utils import check_matplotlib_support, deprecated
10
11
class PrecisionRecallDisplay:
    """Precision Recall visualization.

    It is recommended to use
    :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` or
    :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` to create
    a :class:`~sklearn.metrics.PrecisionRecallDisplay`. All parameters are
    stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    precision : ndarray
        Precision values.

    recall : ndarray
        Recall values.

    average_precision : float, default=None
        Average precision. If None, the average precision is not shown.

    estimator_name : str, default=None
        Name of estimator. If None, then the estimator name is not shown.

    pos_label : str or int, default=None
        The class considered as the positive class. If None, the class will not
        be shown in the legend.

        .. versionadded:: 0.24

    Attributes
    ----------
    line_ : matplotlib Artist
        Precision recall curve.

    ax_ : matplotlib Axes
        Axes with precision recall curve.

    figure_ : matplotlib Figure
        Figure containing the curve.

    See Also
    --------
    precision_recall_curve : Compute precision-recall pairs for different
        probability thresholds.
    PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given
        a binary classifier.
    PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve
        using predictions from a binary classifier.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import (precision_recall_curve,
    ...                              PrecisionRecallDisplay)
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
    ...                                                     random_state=0)
    >>> clf = SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> y_score = clf.decision_function(X_test)
    >>> precision, recall, _ = precision_recall_curve(y_test, y_score)
    >>> disp = PrecisionRecallDisplay(precision=precision, recall=recall)
    >>> disp.plot()
    <...>
    >>> plt.show()
    """

    def __init__(
        self,
        precision,
        recall,
        *,
        average_precision=None,
        estimator_name=None,
        pos_label=None,
    ):
        self.estimator_name = estimator_name
        self.precision = precision
        self.recall = recall
        self.average_precision = average_precision
        self.pos_label = pos_label

    def plot(self, ax=None, *, name=None, **kwargs):
        """Plot visualization.

        Extra keyword arguments will be passed to matplotlib's `plot`.
        The curve is drawn with ``drawstyle="steps-post"`` unless overridden
        through ``kwargs``. When `pos_label` is set, it is appended to both
        axis labels.

        Parameters
        ----------
        ax : Matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of precision recall curve for labeling. If `None`, use
            `estimator_name` if not `None`; otherwise the legend only shows
            the average precision (if available), and no legend is drawn
            when neither is available.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`. They take
            precedence over the defaults computed here (including `label`).

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
            Object that stores computed values.
        """
        check_matplotlib_support("PrecisionRecallDisplay.plot")

        name = self.estimator_name if name is None else name

        # Build a default legend label from whatever is known; user-supplied
        # kwargs are merged last so they can override any default.
        line_kwargs = {"drawstyle": "steps-post"}
        if self.average_precision is not None and name is not None:
            line_kwargs["label"] = f"{name} (AP = {self.average_precision:0.2f})"
        elif self.average_precision is not None:
            line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
        elif name is not None:
            line_kwargs["label"] = name
        line_kwargs.update(kwargs)

        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        (self.line_,) = ax.plot(self.recall, self.precision, **line_kwargs)
        info_pos_label = (
            f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
        )

        xlabel = "Recall" + info_pos_label
        ylabel = "Precision" + info_pos_label
        ax.set(xlabel=xlabel, ylabel=ylabel)

        # Only draw a legend when there is something to label.
        if "label" in line_kwargs:
            ax.legend(loc="lower left")

        self.ax_ = ax
        self.figure_ = ax.figure
        return self

    @classmethod
    def from_estimator(
        cls,
        estimator,
        X,
        y,
        *,
        sample_weight=None,
        pos_label=None,
        response_method="auto",
        name=None,
        ax=None,
        **kwargs,
    ):
        """Plot precision-recall curve given an estimator and some data.

        Parameters
        ----------
        estimator : estimator instance
            Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
            in which the last estimator is a classifier.

        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input values.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        pos_label : str or int, default=None
            The class considered as the positive class when computing the
            precision and recall metrics. By default, `estimator.classes_[1]`
            is considered as the positive class.

        response_method : {'predict_proba', 'decision_function', 'auto'}, \
            default='auto'
            Specifies whether to use :term:`predict_proba` or
            :term:`decision_function` as the target response. If set to 'auto',
            :term:`predict_proba` is tried first and if it does not exist
            :term:`decision_function` is tried next.

        name : str, default=None
            Name for labeling curve. If `None`, the name of the estimator's
            class is used.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is created.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`

        Raises
        ------
        ValueError
            If `estimator` is not a classifier.

        See Also
        --------
        PrecisionRecallDisplay.from_predictions : Plot precision-recall curve
            using estimated probabilities or output of decision function.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import PrecisionRecallDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.linear_model import LogisticRegression
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...         X, y, random_state=0)
        >>> clf = LogisticRegression()
        >>> clf.fit(X_train, y_train)
        LogisticRegression()
        >>> PrecisionRecallDisplay.from_estimator(
        ...    clf, X_test, y_test)
        <...>
        >>> plt.show()
        """
        method_name = f"{cls.__name__}.from_estimator"
        check_matplotlib_support(method_name)
        if not is_classifier(estimator):
            raise ValueError(f"{method_name} only supports classifiers")
        # `_get_response` also resolves `pos_label` when it is None.
        y_pred, pos_label = _get_response(
            X,
            estimator,
            response_method,
            pos_label=pos_label,
        )

        name = name if name is not None else estimator.__class__.__name__

        return cls.from_predictions(
            y,
            y_pred,
            sample_weight=sample_weight,
            name=name,
            pos_label=pos_label,
            ax=ax,
            **kwargs,
        )

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        *,
        sample_weight=None,
        pos_label=None,
        name=None,
        ax=None,
        **kwargs,
    ):
        """Plot precision-recall curve given binary class predictions.

        Parameters
        ----------
        y_true : array-like of shape (n_samples,)
            True binary labels.

        y_pred : array-like of shape (n_samples,)
            Estimated probabilities or output of decision function.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        pos_label : str or int, default=None
            The class considered as the positive class when computing the
            precision and recall metrics.

        name : str, default=None
            Name for labeling curve. If `None`, name will be set to
            `"Classifier"`.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is created.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
            An instance of the class this method was called on, so that
            subclasses get an instance of themselves.

        See Also
        --------
        PrecisionRecallDisplay.from_estimator : Plot precision-recall curve
            using an estimator.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import PrecisionRecallDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.linear_model import LogisticRegression
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...         X, y, random_state=0)
        >>> clf = LogisticRegression()
        >>> clf.fit(X_train, y_train)
        LogisticRegression()
        >>> y_pred = clf.predict_proba(X_test)[:, 1]
        >>> PrecisionRecallDisplay.from_predictions(
        ...    y_test, y_pred)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")

        check_consistent_length(y_true, y_pred, sample_weight)
        pos_label = _check_pos_label_consistency(pos_label, y_true)

        precision, recall, _ = precision_recall_curve(
            y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
        )
        average_precision = average_precision_score(
            y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
        )

        name = name if name is not None else "Classifier"

        # Use `cls` rather than the hard-coded class so that subclasses of
        # this display are constructed correctly by this alternate constructor.
        viz = cls(
            precision=precision,
            recall=recall,
            average_precision=average_precision,
            estimator_name=name,
            pos_label=pos_label,
        )

        return viz.plot(ax=ax, name=name, **kwargs)
349
350
@deprecated(
    "Function `plot_precision_recall_curve` is deprecated in 1.0 and will be "
    "removed in 1.2. Use one of the class methods: "
    "PrecisionRecallDisplay.from_predictions or "
    "PrecisionRecallDisplay.from_estimator."
)
def plot_precision_recall_curve(
    estimator,
    X,
    y,
    *,
    sample_weight=None,
    response_method="auto",
    name=None,
    ax=None,
    pos_label=None,
    **kwargs,
):
    """Plot Precision Recall Curve for binary classifiers.

    Extra keyword arguments will be passed to matplotlib's `plot`.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    .. deprecated:: 1.0
       `plot_precision_recall_curve` is deprecated in 1.0 and will be removed in
       1.2. Use one of the following class methods:
       :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` or
       :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Binary target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    response_method : {'predict_proba', 'decision_function', 'auto'}, \
                      default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name for labeling curve. If `None`, the name of the
        estimator is used.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    pos_label : str or int, default=None
        The class considered as the positive class when computing the precision
        and recall metrics. By default, `estimator.classes_[1]` is considered
        as the positive class.

        .. versionadded:: 0.24

    **kwargs : dict
        Keyword arguments to be passed to matplotlib's `plot`.

    Returns
    -------
    display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
        Object that stores computed values.

    See Also
    --------
    precision_recall_curve : Compute precision-recall pairs for different
        probability thresholds.
    PrecisionRecallDisplay : Precision Recall visualization.
    """
    check_matplotlib_support("plot_precision_recall_curve")

    # `_get_response` computes the scores and also resolves `pos_label`
    # when it was left as None.
    y_pred, pos_label = _get_response(
        X, estimator, response_method, pos_label=pos_label
    )

    precision, recall, _ = precision_recall_curve(
        y, y_pred, pos_label=pos_label, sample_weight=sample_weight
    )
    average_precision = average_precision_score(
        y, y_pred, pos_label=pos_label, sample_weight=sample_weight
    )

    name = name if name is not None else estimator.__class__.__name__

    viz = PrecisionRecallDisplay(
        precision=precision,
        recall=recall,
        average_precision=average_precision,
        estimator_name=name,
        pos_label=pos_label,
    )

    return viz.plot(ax=ax, name=name, **kwargs)
455