1from .base import _get_response
2
3from .. import auc
4from .. import roc_curve
5from .._base import _check_pos_label_consistency
6
7from ...utils import check_matplotlib_support, deprecated
8
9
class RocCurveDisplay:
    """ROC Curve visualization.

    It is recommended to use
    :func:`~sklearn.metrics.RocCurveDisplay.from_estimator` or
    :func:`~sklearn.metrics.RocCurveDisplay.from_predictions` to create
    a :class:`~sklearn.metrics.RocCurveDisplay`. All parameters are
    stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    fpr : ndarray
        False positive rate.

    tpr : ndarray
        True positive rate.

    roc_auc : float, default=None
        Area under ROC curve. If None, the roc_auc score is not shown.

    estimator_name : str, default=None
        Name of estimator. If None, the estimator name is not shown.

    pos_label : str or int, default=None
        The class considered as the positive class when computing the roc auc
        metrics. By default, `estimator.classes_[1]` is considered
        as the positive class.

        .. versionadded:: 0.24

    Attributes
    ----------
    line_ : matplotlib Artist
        ROC Curve.

    ax_ : matplotlib Axes
        Axes with ROC Curve.

    figure_ : matplotlib Figure
        Figure containing the curve.

    See Also
    --------
    roc_curve : Compute Receiver operating characteristic (ROC) curve.
    RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
        (ROC) curve given an estimator and some data.
    RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
        (ROC) curve given the true and predicted values.
    roc_auc_score : Compute the area under the ROC curve.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> from sklearn import metrics
    >>> y = np.array([0, 0, 1, 1])
    >>> pred = np.array([0.1, 0.4, 0.35, 0.8])
    >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred)
    >>> roc_auc = metrics.auc(fpr, tpr)
    >>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
    ...                                   estimator_name='example estimator')
    >>> display.plot()
    <...>
    >>> plt.show()
    """

    def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=None):
        self.estimator_name = estimator_name
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.pos_label = pos_label

    def plot(self, ax=None, *, name=None, **kwargs):
        """Plot visualization.

        Extra keyword arguments will be passed to matplotlib's ``plot``.

        Parameters
        ----------
        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of ROC Curve for labeling. If `None`, use `estimator_name` if
            not `None`, otherwise no labeling is shown.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`. They are
            applied after the automatically generated `label`, so an
            explicit `label` overrides it.

        Returns
        -------
        display : :class:`~sklearn.metrics.RocCurveDisplay`
            Object that stores computed values.
        """
        check_matplotlib_support("RocCurveDisplay.plot")

        name = self.estimator_name if name is None else name

        # Build the legend entry from whichever of (name, AUC) is available.
        line_kwargs = {}
        if self.roc_auc is not None and name is not None:
            line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})"
        elif self.roc_auc is not None:
            line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}"
        elif name is not None:
            line_kwargs["label"] = name

        # User-supplied keyword arguments take precedence over the defaults.
        line_kwargs.update(**kwargs)

        # Imported lazily so that matplotlib stays an optional dependency.
        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        (self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs)
        info_pos_label = (
            f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
        )

        xlabel = "False Positive Rate" + info_pos_label
        ylabel = "True Positive Rate" + info_pos_label
        ax.set(xlabel=xlabel, ylabel=ylabel)

        # Only draw a legend when there is something to show in it.
        if "label" in line_kwargs:
            ax.legend(loc="lower right")

        self.ax_ = ax
        self.figure_ = ax.figure
        return self

    @classmethod
    def from_estimator(
        cls,
        estimator,
        X,
        y,
        *,
        sample_weight=None,
        drop_intermediate=True,
        response_method="auto",
        pos_label=None,
        name=None,
        ax=None,
        **kwargs,
    ):
        """Create a ROC Curve display from an estimator.

        Parameters
        ----------
        estimator : estimator instance
            Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
            in which the last estimator is a classifier.

        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input values.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        drop_intermediate : bool, default=True
            Whether to drop some suboptimal thresholds which would not appear
            on a plotted ROC curve. This is useful in order to create lighter
            ROC curves.

        response_method : {'predict_proba', 'decision_function', 'auto'} \
                default='auto'
            Specifies whether to use :term:`predict_proba` or
            :term:`decision_function` as the target response. If set to 'auto',
            :term:`predict_proba` is tried first and if it does not exist
            :term:`decision_function` is tried next.

        pos_label : str or int, default=None
            The class considered as the positive class when computing the roc auc
            metrics. By default, `estimator.classes_[1]` is considered
            as the positive class.

        name : str, default=None
            Name of ROC Curve for labeling. If `None`, use the name of the
            estimator.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is created.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        display : :class:`~sklearn.metrics.RocCurveDisplay`
            The ROC Curve display.

        See Also
        --------
        roc_curve : Compute Receiver operating characteristic (ROC) curve.
        RocCurveDisplay.from_predictions : ROC Curve visualization given the
            probabilities of scores of a classifier.
        roc_auc_score : Compute the area under the ROC curve.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import RocCurveDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.svm import SVC
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...     X, y, random_state=0)
        >>> clf = SVC(random_state=0).fit(X_train, y_train)
        >>> RocCurveDisplay.from_estimator(
        ...    clf, X_test, y_test)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_estimator")

        name = estimator.__class__.__name__ if name is None else name

        # `_get_response` also returns the positive label that was used,
        # which is forwarded so the curve and axis labels stay consistent.
        y_pred, pos_label = _get_response(
            X,
            estimator,
            response_method=response_method,
            pos_label=pos_label,
        )

        return cls.from_predictions(
            y_true=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            drop_intermediate=drop_intermediate,
            name=name,
            ax=ax,
            pos_label=pos_label,
            **kwargs,
        )

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        *,
        sample_weight=None,
        drop_intermediate=True,
        pos_label=None,
        name=None,
        ax=None,
        **kwargs,
    ):
        """Plot ROC curve given the true and predicted values.

        Read more in the :ref:`User Guide <visualizations>`.

        .. versionadded:: 1.0

        Parameters
        ----------
        y_true : array-like of shape (n_samples,)
            True labels.

        y_pred : array-like of shape (n_samples,)
            Target scores, can either be probability estimates of the positive
            class, confidence values, or non-thresholded measure of decisions
            (as returned by "decision_function" on some classifiers).

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        drop_intermediate : bool, default=True
            Whether to drop some suboptimal thresholds which would not appear
            on a plotted ROC curve. This is useful in order to create lighter
            ROC curves.

        pos_label : str or int, default=None
            The label of the positive class. When `pos_label=None`, if `y_true`
            is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
            error will be raised.

        name : str, default=None
            Name of ROC curve for labeling. If `None`, name will be set to
            `"Classifier"`.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        **kwargs : dict
            Additional keywords arguments passed to matplotlib `plot` function.

        Returns
        -------
        display : :class:`~sklearn.metrics.RocCurveDisplay`
            Object that stores computed values.

        See Also
        --------
        roc_curve : Compute Receiver operating characteristic (ROC) curve.
        RocCurveDisplay.from_estimator : ROC Curve visualization given an
            estimator and some data.
        roc_auc_score : Compute the area under the ROC curve.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import RocCurveDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.svm import SVC
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...     X, y, random_state=0)
        >>> clf = SVC(random_state=0).fit(X_train, y_train)
        >>> y_pred = clf.decision_function(X_test)
        >>> RocCurveDisplay.from_predictions(
        ...    y_test, y_pred)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")

        fpr, tpr, _ = roc_curve(
            y_true,
            y_pred,
            pos_label=pos_label,
            sample_weight=sample_weight,
            drop_intermediate=drop_intermediate,
        )
        roc_auc = auc(fpr, tpr)

        name = "Classifier" if name is None else name
        pos_label = _check_pos_label_consistency(pos_label, y_true)

        # Instantiate through `cls` (not the hard-coded base class) so that
        # subclasses calling this constructor get an instance of themselves.
        viz = cls(
            fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label
        )

        return viz.plot(ax=ax, name=name, **kwargs)
350
351
@deprecated(
    "Function :func:`plot_roc_curve` is deprecated in 1.0 and will be "
    "removed in 1.2. Use one of the class methods: "
    ":meth:`sklearn.metrics.RocCurveDisplay.from_predictions` or "
    ":meth:`sklearn.metrics.RocCurveDisplay.from_estimator`."
)
def plot_roc_curve(
    estimator,
    X,
    y,
    *,
    sample_weight=None,
    drop_intermediate=True,
    response_method="auto",
    name=None,
    ax=None,
    pos_label=None,
    **kwargs,
):
    """Plot Receiver operating characteristic (ROC) curve.

    Extra keyword arguments will be passed to matplotlib's `plot`.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    drop_intermediate : bool, default=True
        Whether to drop some suboptimal thresholds which would not appear
        on a plotted ROC curve. This is useful in order to create lighter
        ROC curves.

    response_method : {'predict_proba', 'decision_function', 'auto'} \
            default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name of ROC Curve for labeling. If `None`, use the name of the
        estimator.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    pos_label : str or int, default=None
        The class considered as the positive class when computing the roc auc
        metrics. By default, `estimator.classes_[1]` is considered
        as the positive class.

        .. versionadded:: 0.24

    **kwargs : dict
        Additional keywords arguments passed to matplotlib `plot` function.

    Returns
    -------
    display : :class:`~sklearn.metrics.RocCurveDisplay`
        Object that stores computed values.

    See Also
    --------
    roc_curve : Compute Receiver operating characteristic (ROC) curve.
    RocCurveDisplay.from_estimator : ROC Curve visualization given an estimator
        and some data.
    RocCurveDisplay.from_predictions : ROC Curve visualisation given the
        true and predicted values.
    roc_auc_score : Compute the area under the ROC curve.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn import datasets, metrics, model_selection, svm
    >>> X, y = datasets.make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = model_selection.train_test_split(
    ...     X, y, random_state=0)
    >>> clf = svm.SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> metrics.plot_roc_curve(clf, X_test, y_test) # doctest: +SKIP
    <...>
    >>> plt.show()
    """
    check_matplotlib_support("plot_roc_curve")

    # `_get_response` returns the scores together with the positive label
    # that was used, which is reused for the curve and axis labels below.
    y_pred, pos_label = _get_response(
        X, estimator, response_method, pos_label=pos_label
    )

    fpr, tpr, _ = roc_curve(
        y,
        y_pred,
        pos_label=pos_label,
        sample_weight=sample_weight,
        drop_intermediate=drop_intermediate,
    )
    roc_auc = auc(fpr, tpr)

    # Fall back to the estimator's class name for the legend entry.
    name = estimator.__class__.__name__ if name is None else name

    viz = RocCurveDisplay(
        fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label
    )

    return viz.plot(ax=ax, name=name, **kwargs)
471