1import re
2import pytest
3import numpy as np
4import warnings
5from scipy.sparse import csr_matrix
6
7from sklearn import datasets
8from sklearn import svm
9
10from sklearn.utils.extmath import softmax
11from sklearn.datasets import make_multilabel_classification
12from sklearn.random_projection import _sparse_random_matrix
13from sklearn.utils.validation import check_array, check_consistent_length
14from sklearn.utils.validation import check_random_state
15
16from sklearn.utils._testing import assert_allclose
17from sklearn.utils._testing import assert_almost_equal
18from sklearn.utils._testing import assert_array_equal
19from sklearn.utils._testing import assert_array_almost_equal
20
21from sklearn.metrics import accuracy_score
22from sklearn.metrics import auc
23from sklearn.metrics import average_precision_score
24from sklearn.metrics import coverage_error
25from sklearn.metrics import det_curve
26from sklearn.metrics import label_ranking_average_precision_score
27from sklearn.metrics import precision_recall_curve
28from sklearn.metrics import label_ranking_loss
29from sklearn.metrics import roc_auc_score
30from sklearn.metrics import roc_curve
31from sklearn.metrics._ranking import _ndcg_sample_scores, _dcg_sample_scores
32from sklearn.metrics import ndcg_score, dcg_score
33from sklearn.metrics import top_k_accuracy_score
34
35from sklearn.exceptions import UndefinedMetricWarning
36from sklearn.model_selection import train_test_split
37from sklearn.linear_model import LogisticRegression
38
39
40###############################################################################
41# Utilities for testing
42
43CURVE_FUNCS = [
44    det_curve,
45    precision_recall_curve,
46    roc_curve,
47]
48
49
50def make_prediction(dataset=None, binary=False):
51    """Make some classification predictions on a toy dataset using a SVC
52
53    If binary is True restrict to a binary classification problem instead of a
54    multiclass classification problem
55    """
56
57    if dataset is None:
58        # import some data to play with
59        dataset = datasets.load_iris()
60
61    X = dataset.data
62    y = dataset.target
63
64    if binary:
65        # restrict to a binary classification task
66        X, y = X[y < 2], y[y < 2]
67
68    n_samples, n_features = X.shape
69    p = np.arange(n_samples)
70
71    rng = check_random_state(37)
72    rng.shuffle(p)
73    X, y = X[p], y[p]
74    half = int(n_samples / 2)
75
76    # add noisy features to make the problem harder and avoid perfect results
77    rng = np.random.RandomState(0)
78    X = np.c_[X, rng.randn(n_samples, 200 * n_features)]
79
80    # run classifier, get class probabilities and label predictions
81    clf = svm.SVC(kernel="linear", probability=True, random_state=0)
82    y_score = clf.fit(X[:half], y[:half]).predict_proba(X[half:])
83
84    if binary:
85        # only interested in probabilities of the positive case
86        # XXX: do we really want a special API for the binary case?
87        y_score = y_score[:, 1]
88
89    y_pred = clf.predict(X[half:])
90    y_true = y[half:]
91    return y_true, y_pred, y_score
92
93
94###############################################################################
95# Tests
96
97
98def _auc(y_true, y_score):
99    """Alternative implementation to check for correctness of
100    `roc_auc_score`."""
101    pos_label = np.unique(y_true)[1]
102
103    # Count the number of times positive samples are correctly ranked above
104    # negative samples.
105    pos = y_score[y_true == pos_label]
106    neg = y_score[y_true != pos_label]
107    diff_matrix = pos.reshape(1, -1) - neg.reshape(-1, 1)
108    n_correct = np.sum(diff_matrix > 0)
109
110    return n_correct / float(len(pos) * len(neg))
111
112
113def _average_precision(y_true, y_score):
114    """Alternative implementation to check for correctness of
115    `average_precision_score`.
116
117    Note that this implementation fails on some edge cases.
118    For example, for constant predictions e.g. [0.5, 0.5, 0.5],
119    y_true = [1, 0, 0] returns an average precision of 0.33...
120    but y_true = [0, 0, 1] returns 1.0.
121    """
122    pos_label = np.unique(y_true)[1]
123    n_pos = np.sum(y_true == pos_label)
124    order = np.argsort(y_score)[::-1]
125    y_score = y_score[order]
126    y_true = y_true[order]
127
128    score = 0
129    for i in range(len(y_score)):
130        if y_true[i] == pos_label:
131            # Compute precision up to document i
132            # i.e, percentage of relevant documents up to document i.
133            prec = 0
134            for j in range(0, i + 1):
135                if y_true[j] == pos_label:
136                    prec += 1.0
137            prec /= i + 1.0
138            score += prec
139
140    return score / n_pos
141
142
143def _average_precision_slow(y_true, y_score):
144    """A second alternative implementation of average precision that closely
145    follows the Wikipedia article's definition (see References). This should
146    give identical results as `average_precision_score` for all inputs.
147
148    References
149    ----------
150    .. [1] `Wikipedia entry for the Average precision
151       <https://en.wikipedia.org/wiki/Average_precision>`_
152    """
153    precision, recall, threshold = precision_recall_curve(y_true, y_score)
154    precision = list(reversed(precision))
155    recall = list(reversed(recall))
156    average_precision = 0
157    for i in range(1, len(precision)):
158        average_precision += precision[i] * (recall[i] - recall[i - 1])
159    return average_precision
160
161
162def _partial_roc_auc_score(y_true, y_predict, max_fpr):
163    """Alternative implementation to check for correctness of `roc_auc_score`
164    with `max_fpr` set.
165    """
166
167    def _partial_roc(y_true, y_predict, max_fpr):
168        fpr, tpr, _ = roc_curve(y_true, y_predict)
169        new_fpr = fpr[fpr <= max_fpr]
170        new_fpr = np.append(new_fpr, max_fpr)
171        new_tpr = tpr[fpr <= max_fpr]
172        idx_out = np.argmax(fpr > max_fpr)
173        idx_in = idx_out - 1
174        x_interp = [fpr[idx_in], fpr[idx_out]]
175        y_interp = [tpr[idx_in], tpr[idx_out]]
176        new_tpr = np.append(new_tpr, np.interp(max_fpr, x_interp, y_interp))
177        return (new_fpr, new_tpr)
178
179    new_fpr, new_tpr = _partial_roc(y_true, y_predict, max_fpr)
180    partial_auc = auc(new_fpr, new_tpr)
181
182    # Formula (5) from McClish 1989
183    fpr1 = 0
184    fpr2 = max_fpr
185    min_area = 0.5 * (fpr2 - fpr1) * (fpr2 + fpr1)
186    max_area = fpr2 - fpr1
187    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
188
189
190@pytest.mark.parametrize("drop", [True, False])
191def test_roc_curve(drop):
192    # Test Area under Receiver Operating Characteristic (ROC) curve
193    y_true, _, y_score = make_prediction(binary=True)
194    expected_auc = _auc(y_true, y_score)
195
196    fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=drop)
197    roc_auc = auc(fpr, tpr)
198    assert_array_almost_equal(roc_auc, expected_auc, decimal=2)
199    assert_almost_equal(roc_auc, roc_auc_score(y_true, y_score))
200    assert fpr.shape == tpr.shape
201    assert fpr.shape == thresholds.shape
202
203
204def test_roc_curve_end_points():
205    # Make sure that roc_curve returns a curve start at 0 and ending and
206    # 1 even in corner cases
207    rng = np.random.RandomState(0)
208    y_true = np.array([0] * 50 + [1] * 50)
209    y_pred = rng.randint(3, size=100)
210    fpr, tpr, thr = roc_curve(y_true, y_pred, drop_intermediate=True)
211    assert fpr[0] == 0
212    assert fpr[-1] == 1
213    assert fpr.shape == tpr.shape
214    assert fpr.shape == thr.shape
215
216
217def test_roc_returns_consistency():
218    # Test whether the returned threshold matches up with tpr
219    # make small toy dataset
220    y_true, _, y_score = make_prediction(binary=True)
221    fpr, tpr, thresholds = roc_curve(y_true, y_score)
222
223    # use the given thresholds to determine the tpr
224    tpr_correct = []
225    for t in thresholds:
226        tp = np.sum((y_score >= t) & y_true)
227        p = np.sum(y_true)
228        tpr_correct.append(1.0 * tp / p)
229
230    # compare tpr and tpr_correct to see if the thresholds' order was correct
231    assert_array_almost_equal(tpr, tpr_correct, decimal=2)
232    assert fpr.shape == tpr.shape
233    assert fpr.shape == thresholds.shape
234
235
236def test_roc_curve_multi():
237    # roc_curve not applicable for multi-class problems
238    y_true, _, y_score = make_prediction(binary=False)
239
240    with pytest.raises(ValueError):
241        roc_curve(y_true, y_score)
242
243
244def test_roc_curve_confidence():
245    # roc_curve for confidence scores
246    y_true, _, y_score = make_prediction(binary=True)
247
248    fpr, tpr, thresholds = roc_curve(y_true, y_score - 0.5)
249    roc_auc = auc(fpr, tpr)
250    assert_array_almost_equal(roc_auc, 0.90, decimal=2)
251    assert fpr.shape == tpr.shape
252    assert fpr.shape == thresholds.shape
253
254
255def test_roc_curve_hard():
256    # roc_curve for hard decisions
257    y_true, pred, y_score = make_prediction(binary=True)
258
259    # always predict one
260    trivial_pred = np.ones(y_true.shape)
261    fpr, tpr, thresholds = roc_curve(y_true, trivial_pred)
262    roc_auc = auc(fpr, tpr)
263    assert_array_almost_equal(roc_auc, 0.50, decimal=2)
264    assert fpr.shape == tpr.shape
265    assert fpr.shape == thresholds.shape
266
267    # always predict zero
268    trivial_pred = np.zeros(y_true.shape)
269    fpr, tpr, thresholds = roc_curve(y_true, trivial_pred)
270    roc_auc = auc(fpr, tpr)
271    assert_array_almost_equal(roc_auc, 0.50, decimal=2)
272    assert fpr.shape == tpr.shape
273    assert fpr.shape == thresholds.shape
274
275    # hard decisions
276    fpr, tpr, thresholds = roc_curve(y_true, pred)
277    roc_auc = auc(fpr, tpr)
278    assert_array_almost_equal(roc_auc, 0.78, decimal=2)
279    assert fpr.shape == tpr.shape
280    assert fpr.shape == thresholds.shape
281
282
283def test_roc_curve_one_label():
284    y_true = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
285    y_pred = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
286    # assert there are warnings
287    expected_message = (
288        "No negative samples in y_true, false positive value should be meaningless"
289    )
290    with pytest.warns(UndefinedMetricWarning, match=expected_message):
291        fpr, tpr, thresholds = roc_curve(y_true, y_pred)
292
293    # all true labels, all fpr should be nan
294    assert_array_equal(fpr, np.full(len(thresholds), np.nan))
295    assert fpr.shape == tpr.shape
296    assert fpr.shape == thresholds.shape
297
298    # assert there are warnings
299    expected_message = (
300        "No positive samples in y_true, true positive value should be meaningless"
301    )
302    with pytest.warns(UndefinedMetricWarning, match=expected_message):
303        fpr, tpr, thresholds = roc_curve([1 - x for x in y_true], y_pred)
304    # all negative labels, all tpr should be nan
305    assert_array_equal(tpr, np.full(len(thresholds), np.nan))
306    assert fpr.shape == tpr.shape
307    assert fpr.shape == thresholds.shape
308
309
310def test_roc_curve_toydata():
311    # Binary classification
312    y_true = [0, 1]
313    y_score = [0, 1]
314    tpr, fpr, _ = roc_curve(y_true, y_score)
315    roc_auc = roc_auc_score(y_true, y_score)
316    assert_array_almost_equal(tpr, [0, 0, 1])
317    assert_array_almost_equal(fpr, [0, 1, 1])
318    assert_almost_equal(roc_auc, 1.0)
319
320    y_true = [0, 1]
321    y_score = [1, 0]
322    tpr, fpr, _ = roc_curve(y_true, y_score)
323    roc_auc = roc_auc_score(y_true, y_score)
324    assert_array_almost_equal(tpr, [0, 1, 1])
325    assert_array_almost_equal(fpr, [0, 0, 1])
326    assert_almost_equal(roc_auc, 0.0)
327
328    y_true = [1, 0]
329    y_score = [1, 1]
330    tpr, fpr, _ = roc_curve(y_true, y_score)
331    roc_auc = roc_auc_score(y_true, y_score)
332    assert_array_almost_equal(tpr, [0, 1])
333    assert_array_almost_equal(fpr, [0, 1])
334    assert_almost_equal(roc_auc, 0.5)
335
336    y_true = [1, 0]
337    y_score = [1, 0]
338    tpr, fpr, _ = roc_curve(y_true, y_score)
339    roc_auc = roc_auc_score(y_true, y_score)
340    assert_array_almost_equal(tpr, [0, 0, 1])
341    assert_array_almost_equal(fpr, [0, 1, 1])
342    assert_almost_equal(roc_auc, 1.0)
343
344    y_true = [1, 0]
345    y_score = [0.5, 0.5]
346    tpr, fpr, _ = roc_curve(y_true, y_score)
347    roc_auc = roc_auc_score(y_true, y_score)
348    assert_array_almost_equal(tpr, [0, 1])
349    assert_array_almost_equal(fpr, [0, 1])
350    assert_almost_equal(roc_auc, 0.5)
351
352    y_true = [0, 0]
353    y_score = [0.25, 0.75]
354    # assert UndefinedMetricWarning because of no positive sample in y_true
355    expected_message = (
356        "No positive samples in y_true, true positive value should be meaningless"
357    )
358    with pytest.warns(UndefinedMetricWarning, match=expected_message):
359        tpr, fpr, _ = roc_curve(y_true, y_score)
360
361    with pytest.raises(ValueError):
362        roc_auc_score(y_true, y_score)
363    assert_array_almost_equal(tpr, [0.0, 0.5, 1.0])
364    assert_array_almost_equal(fpr, [np.nan, np.nan, np.nan])
365
366    y_true = [1, 1]
367    y_score = [0.25, 0.75]
368    # assert UndefinedMetricWarning because of no negative sample in y_true
369    expected_message = (
370        "No negative samples in y_true, false positive value should be meaningless"
371    )
372    with pytest.warns(UndefinedMetricWarning, match=expected_message):
373        tpr, fpr, _ = roc_curve(y_true, y_score)
374
375    with pytest.raises(ValueError):
376        roc_auc_score(y_true, y_score)
377    assert_array_almost_equal(tpr, [np.nan, np.nan, np.nan])
378    assert_array_almost_equal(fpr, [0.0, 0.5, 1.0])
379
380    # Multi-label classification task
381    y_true = np.array([[0, 1], [0, 1]])
382    y_score = np.array([[0, 1], [0, 1]])
383    with pytest.raises(ValueError):
384        roc_auc_score(y_true, y_score, average="macro")
385    with pytest.raises(ValueError):
386        roc_auc_score(y_true, y_score, average="weighted")
387    assert_almost_equal(roc_auc_score(y_true, y_score, average="samples"), 1.0)
388    assert_almost_equal(roc_auc_score(y_true, y_score, average="micro"), 1.0)
389
390    y_true = np.array([[0, 1], [0, 1]])
391    y_score = np.array([[0, 1], [1, 0]])
392    with pytest.raises(ValueError):
393        roc_auc_score(y_true, y_score, average="macro")
394    with pytest.raises(ValueError):
395        roc_auc_score(y_true, y_score, average="weighted")
396    assert_almost_equal(roc_auc_score(y_true, y_score, average="samples"), 0.5)
397    assert_almost_equal(roc_auc_score(y_true, y_score, average="micro"), 0.5)
398
399    y_true = np.array([[1, 0], [0, 1]])
400    y_score = np.array([[0, 1], [1, 0]])
401    assert_almost_equal(roc_auc_score(y_true, y_score, average="macro"), 0)
402    assert_almost_equal(roc_auc_score(y_true, y_score, average="weighted"), 0)
403    assert_almost_equal(roc_auc_score(y_true, y_score, average="samples"), 0)
404    assert_almost_equal(roc_auc_score(y_true, y_score, average="micro"), 0)
405
406    y_true = np.array([[1, 0], [0, 1]])
407    y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
408    assert_almost_equal(roc_auc_score(y_true, y_score, average="macro"), 0.5)
409    assert_almost_equal(roc_auc_score(y_true, y_score, average="weighted"), 0.5)
410    assert_almost_equal(roc_auc_score(y_true, y_score, average="samples"), 0.5)
411    assert_almost_equal(roc_auc_score(y_true, y_score, average="micro"), 0.5)
412
413
414def test_roc_curve_drop_intermediate():
415    # Test that drop_intermediate drops the correct thresholds
416    y_true = [0, 0, 0, 0, 1, 1]
417    y_score = [0.0, 0.2, 0.5, 0.6, 0.7, 1.0]
418    tpr, fpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
419    assert_array_almost_equal(thresholds, [2.0, 1.0, 0.7, 0.0])
420
421    # Test dropping thresholds with repeating scores
422    y_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
423    y_score = [0.0, 0.1, 0.6, 0.6, 0.7, 0.8, 0.9, 0.6, 0.7, 0.8, 0.9, 0.9, 1.0]
424    tpr, fpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
425    assert_array_almost_equal(thresholds, [2.0, 1.0, 0.9, 0.7, 0.6, 0.0])
426
427
428def test_roc_curve_fpr_tpr_increasing():
429    # Ensure that fpr and tpr returned by roc_curve are increasing.
430    # Construct an edge case with float y_score and sample_weight
431    # when some adjacent values of fpr and tpr are actually the same.
432    y_true = [0, 0, 1, 1, 1]
433    y_score = [0.1, 0.7, 0.3, 0.4, 0.5]
434    sample_weight = np.repeat(0.2, 5)
435    fpr, tpr, _ = roc_curve(y_true, y_score, sample_weight=sample_weight)
436    assert (np.diff(fpr) < 0).sum() == 0
437    assert (np.diff(tpr) < 0).sum() == 0
438
439
440def test_auc():
441    # Test Area Under Curve (AUC) computation
442    x = [0, 1]
443    y = [0, 1]
444    assert_array_almost_equal(auc(x, y), 0.5)
445    x = [1, 0]
446    y = [0, 1]
447    assert_array_almost_equal(auc(x, y), 0.5)
448    x = [1, 0, 0]
449    y = [0, 1, 1]
450    assert_array_almost_equal(auc(x, y), 0.5)
451    x = [0, 1]
452    y = [1, 1]
453    assert_array_almost_equal(auc(x, y), 1)
454    x = [0, 0.5, 1]
455    y = [0, 0.5, 1]
456    assert_array_almost_equal(auc(x, y), 0.5)
457
458
459def test_auc_errors():
460    # Incompatible shapes
461    with pytest.raises(ValueError):
462        auc([0.0, 0.5, 1.0], [0.1, 0.2])
463
464    # Too few x values
465    with pytest.raises(ValueError):
466        auc([0.0], [0.1])
467
468    # x is not in order
469    x = [2, 1, 3, 4]
470    y = [5, 6, 7, 8]
471    error_message = "x is neither increasing nor decreasing : {}".format(np.array(x))
472    with pytest.raises(ValueError, match=re.escape(error_message)):
473        auc(x, y)
474
475
476@pytest.mark.parametrize(
477    "y_true, labels",
478    [
479        (np.array([0, 1, 0, 2]), [0, 1, 2]),
480        (np.array([0, 1, 0, 2]), None),
481        (["a", "b", "a", "c"], ["a", "b", "c"]),
482        (["a", "b", "a", "c"], None),
483    ],
484)
485def test_multiclass_ovo_roc_auc_toydata(y_true, labels):
486    # Tests the one-vs-one multiclass ROC AUC algorithm
487    # on a small example, representative of an expected use case.
488    y_scores = np.array(
489        [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.35, 0.5, 0.15], [0, 0.2, 0.8]]
490    )
491
492    # Used to compute the expected output.
493    # Consider labels 0 and 1:
494    # positive label is 0, negative label is 1
495    score_01 = roc_auc_score([1, 0, 1], [0.1, 0.3, 0.35])
496    # positive label is 1, negative label is 0
497    score_10 = roc_auc_score([0, 1, 0], [0.8, 0.4, 0.5])
498    average_score_01 = (score_01 + score_10) / 2
499
500    # Consider labels 0 and 2:
501    score_02 = roc_auc_score([1, 1, 0], [0.1, 0.35, 0])
502    score_20 = roc_auc_score([0, 0, 1], [0.1, 0.15, 0.8])
503    average_score_02 = (score_02 + score_20) / 2
504
505    # Consider labels 1 and 2:
506    score_12 = roc_auc_score([1, 0], [0.4, 0.2])
507    score_21 = roc_auc_score([0, 1], [0.3, 0.8])
508    average_score_12 = (score_12 + score_21) / 2
509
510    # Unweighted, one-vs-one multiclass ROC AUC algorithm
511    ovo_unweighted_score = (average_score_01 + average_score_02 + average_score_12) / 3
512    assert_almost_equal(
513        roc_auc_score(y_true, y_scores, labels=labels, multi_class="ovo"),
514        ovo_unweighted_score,
515    )
516
517    # Weighted, one-vs-one multiclass ROC AUC algorithm
518    # Each term is weighted by the prevalence for the positive label.
519    pair_scores = [average_score_01, average_score_02, average_score_12]
520    prevalence = [0.75, 0.75, 0.50]
521    ovo_weighted_score = np.average(pair_scores, weights=prevalence)
522    assert_almost_equal(
523        roc_auc_score(
524            y_true, y_scores, labels=labels, multi_class="ovo", average="weighted"
525        ),
526        ovo_weighted_score,
527    )
528
529
530@pytest.mark.parametrize(
531    "y_true, labels",
532    [
533        (np.array([0, 2, 0, 2]), [0, 1, 2]),
534        (np.array(["a", "d", "a", "d"]), ["a", "b", "d"]),
535    ],
536)
537def test_multiclass_ovo_roc_auc_toydata_binary(y_true, labels):
538    # Tests the one-vs-one multiclass ROC AUC algorithm for binary y_true
539    #
540    # on a small example, representative of an expected use case.
541    y_scores = np.array(
542        [[0.2, 0.0, 0.8], [0.6, 0.0, 0.4], [0.55, 0.0, 0.45], [0.4, 0.0, 0.6]]
543    )
544
545    # Used to compute the expected output.
546    # Consider labels 0 and 1:
547    # positive label is 0, negative label is 1
548    score_01 = roc_auc_score([1, 0, 1, 0], [0.2, 0.6, 0.55, 0.4])
549    # positive label is 1, negative label is 0
550    score_10 = roc_auc_score([0, 1, 0, 1], [0.8, 0.4, 0.45, 0.6])
551    ovo_score = (score_01 + score_10) / 2
552
553    assert_almost_equal(
554        roc_auc_score(y_true, y_scores, labels=labels, multi_class="ovo"), ovo_score
555    )
556
557    # Weighted, one-vs-one multiclass ROC AUC algorithm
558    assert_almost_equal(
559        roc_auc_score(
560            y_true, y_scores, labels=labels, multi_class="ovo", average="weighted"
561        ),
562        ovo_score,
563    )
564
565
566@pytest.mark.parametrize(
567    "y_true, labels",
568    [
569        (np.array([0, 1, 2, 2]), None),
570        (["a", "b", "c", "c"], None),
571        ([0, 1, 2, 2], [0, 1, 2]),
572        (["a", "b", "c", "c"], ["a", "b", "c"]),
573    ],
574)
575def test_multiclass_ovr_roc_auc_toydata(y_true, labels):
576    # Tests the unweighted, one-vs-rest multiclass ROC AUC algorithm
577    # on a small example, representative of an expected use case.
578    y_scores = np.array(
579        [[1.0, 0.0, 0.0], [0.1, 0.5, 0.4], [0.1, 0.1, 0.8], [0.3, 0.3, 0.4]]
580    )
581    # Compute the expected result by individually computing the 'one-vs-rest'
582    # ROC AUC scores for classes 0, 1, and 2.
583    out_0 = roc_auc_score([1, 0, 0, 0], y_scores[:, 0])
584    out_1 = roc_auc_score([0, 1, 0, 0], y_scores[:, 1])
585    out_2 = roc_auc_score([0, 0, 1, 1], y_scores[:, 2])
586    result_unweighted = (out_0 + out_1 + out_2) / 3.0
587
588    assert_almost_equal(
589        roc_auc_score(y_true, y_scores, multi_class="ovr", labels=labels),
590        result_unweighted,
591    )
592
593    # Tests the weighted, one-vs-rest multiclass ROC AUC algorithm
594    # on the same input (Provost & Domingos, 2000)
595    result_weighted = out_0 * 0.25 + out_1 * 0.25 + out_2 * 0.5
596    assert_almost_equal(
597        roc_auc_score(
598            y_true, y_scores, multi_class="ovr", labels=labels, average="weighted"
599        ),
600        result_weighted,
601    )
602
603
604@pytest.mark.parametrize(
605    "msg, y_true, labels",
606    [
607        ("Parameter 'labels' must be unique", np.array([0, 1, 2, 2]), [0, 2, 0]),
608        (
609            "Parameter 'labels' must be unique",
610            np.array(["a", "b", "c", "c"]),
611            ["a", "a", "b"],
612        ),
613        (
614            "Number of classes in y_true not equal to the number of columns "
615            "in 'y_score'",
616            np.array([0, 2, 0, 2]),
617            None,
618        ),
619        (
620            "Parameter 'labels' must be ordered",
621            np.array(["a", "b", "c", "c"]),
622            ["a", "c", "b"],
623        ),
624        (
625            "Number of given labels, 2, not equal to the number of columns in "
626            "'y_score', 3",
627            np.array([0, 1, 2, 2]),
628            [0, 1],
629        ),
630        (
631            "Number of given labels, 2, not equal to the number of columns in "
632            "'y_score', 3",
633            np.array(["a", "b", "c", "c"]),
634            ["a", "b"],
635        ),
636        (
637            "Number of given labels, 4, not equal to the number of columns in "
638            "'y_score', 3",
639            np.array([0, 1, 2, 2]),
640            [0, 1, 2, 3],
641        ),
642        (
643            "Number of given labels, 4, not equal to the number of columns in "
644            "'y_score', 3",
645            np.array(["a", "b", "c", "c"]),
646            ["a", "b", "c", "d"],
647        ),
648        (
649            "'y_true' contains labels not in parameter 'labels'",
650            np.array(["a", "b", "c", "e"]),
651            ["a", "b", "c"],
652        ),
653        (
654            "'y_true' contains labels not in parameter 'labels'",
655            np.array(["a", "b", "c", "d"]),
656            ["a", "b", "c"],
657        ),
658        (
659            "'y_true' contains labels not in parameter 'labels'",
660            np.array([0, 1, 2, 3]),
661            [0, 1, 2],
662        ),
663    ],
664)
665@pytest.mark.parametrize("multi_class", ["ovo", "ovr"])
666def test_roc_auc_score_multiclass_labels_error(msg, y_true, labels, multi_class):
667    y_scores = np.array(
668        [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.35, 0.5, 0.15], [0, 0.2, 0.8]]
669    )
670
671    with pytest.raises(ValueError, match=msg):
672        roc_auc_score(y_true, y_scores, labels=labels, multi_class=multi_class)
673
674
675@pytest.mark.parametrize(
676    "msg, kwargs",
677    [
678        (
679            (
680                r"average must be one of \('macro', 'weighted'\) for "
681                r"multiclass problems"
682            ),
683            {"average": "samples", "multi_class": "ovo"},
684        ),
685        (
686            (
687                r"average must be one of \('macro', 'weighted'\) for "
688                r"multiclass problems"
689            ),
690            {"average": "micro", "multi_class": "ovr"},
691        ),
692        (
693            (
694                r"sample_weight is not supported for multiclass one-vs-one "
695                r"ROC AUC, 'sample_weight' must be None in this case"
696            ),
697            {"multi_class": "ovo", "sample_weight": []},
698        ),
699        (
700            (
701                r"Partial AUC computation not available in multiclass setting, "
702                r"'max_fpr' must be set to `None`, received `max_fpr=0.5` "
703                r"instead"
704            ),
705            {"multi_class": "ovo", "max_fpr": 0.5},
706        ),
707        (
708            (
709                r"multi_class='ovp' is not supported for multiclass ROC AUC, "
710                r"multi_class must be in \('ovo', 'ovr'\)"
711            ),
712            {"multi_class": "ovp"},
713        ),
714        (r"multi_class must be in \('ovo', 'ovr'\)", {}),
715    ],
716)
717def test_roc_auc_score_multiclass_error(msg, kwargs):
718    # Test that roc_auc_score function returns an error when trying
719    # to compute multiclass AUC for parameters where an output
720    # is not defined.
721    rng = check_random_state(404)
722    y_score = rng.rand(20, 3)
723    y_prob = softmax(y_score)
724    y_true = rng.randint(0, 3, size=20)
725    with pytest.raises(ValueError, match=msg):
726        roc_auc_score(y_true, y_prob, **kwargs)
727
728
729def test_auc_score_non_binary_class():
730    # Test that roc_auc_score function returns an error when trying
731    # to compute AUC for non-binary class values.
732    rng = check_random_state(404)
733    y_pred = rng.rand(10)
734    # y_true contains only one class value
735    y_true = np.zeros(10, dtype="int")
736    err_msg = "ROC AUC score is not defined"
737    with pytest.raises(ValueError, match=err_msg):
738        roc_auc_score(y_true, y_pred)
739    y_true = np.ones(10, dtype="int")
740    with pytest.raises(ValueError, match=err_msg):
741        roc_auc_score(y_true, y_pred)
742    y_true = np.full(10, -1, dtype="int")
743    with pytest.raises(ValueError, match=err_msg):
744        roc_auc_score(y_true, y_pred)
745
746    with warnings.catch_warnings(record=True):
747        rng = check_random_state(404)
748        y_pred = rng.rand(10)
749        # y_true contains only one class value
750        y_true = np.zeros(10, dtype="int")
751        with pytest.raises(ValueError, match=err_msg):
752            roc_auc_score(y_true, y_pred)
753        y_true = np.ones(10, dtype="int")
754        with pytest.raises(ValueError, match=err_msg):
755            roc_auc_score(y_true, y_pred)
756        y_true = np.full(10, -1, dtype="int")
757        with pytest.raises(ValueError, match=err_msg):
758            roc_auc_score(y_true, y_pred)
759
760
761@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
762def test_binary_clf_curve_multiclass_error(curve_func):
763    rng = check_random_state(404)
764    y_true = rng.randint(0, 3, size=10)
765    y_pred = rng.rand(10)
766    msg = "multiclass format is not supported"
767    with pytest.raises(ValueError, match=msg):
768        curve_func(y_true, y_pred)
769
770
771@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
772def test_binary_clf_curve_implicit_pos_label(curve_func):
773    # Check that using string class labels raises an informative
774    # error for any supported string dtype:
775    msg = (
776        "y_true takes value in {'a', 'b'} and pos_label is "
777        "not specified: either make y_true take "
778        "value in {0, 1} or {-1, 1} or pass pos_label "
779        "explicitly."
780    )
781    with pytest.raises(ValueError, match=msg):
782        curve_func(np.array(["a", "b"], dtype="<U1"), [0.0, 1.0])
783
784    with pytest.raises(ValueError, match=msg):
785        curve_func(np.array(["a", "b"], dtype=object), [0.0, 1.0])
786
787    # The error message is slightly different for bytes-encoded
788    # class labels, but otherwise the behavior is the same:
789    msg = (
790        "y_true takes value in {b'a', b'b'} and pos_label is "
791        "not specified: either make y_true take "
792        "value in {0, 1} or {-1, 1} or pass pos_label "
793        "explicitly."
794    )
795    with pytest.raises(ValueError, match=msg):
796        curve_func(np.array([b"a", b"b"], dtype="<S1"), [0.0, 1.0])
797
798    # Check that it is possible to use floating point class labels
799    # that are interpreted similarly to integer class labels:
800    y_pred = [0.0, 1.0, 0.2, 0.42]
801    int_curve = curve_func([0, 1, 1, 0], y_pred)
802    float_curve = curve_func([0.0, 1.0, 1.0, 0.0], y_pred)
803    for int_curve_part, float_curve_part in zip(int_curve, float_curve):
804        np.testing.assert_allclose(int_curve_part, float_curve_part)
805
806
807@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
808def test_binary_clf_curve_zero_sample_weight(curve_func):
809    y_true = [0, 0, 1, 1, 1]
810    y_score = [0.1, 0.2, 0.3, 0.4, 0.5]
811    sample_weight = [1, 1, 1, 0.5, 0]
812
813    result_1 = curve_func(y_true, y_score, sample_weight=sample_weight)
814    result_2 = curve_func(y_true[:-1], y_score[:-1], sample_weight=sample_weight[:-1])
815
816    for arr_1, arr_2 in zip(result_1, result_2):
817        assert_allclose(arr_1, arr_2)
818
819
820def test_precision_recall_curve():
821    y_true, _, y_score = make_prediction(binary=True)
822    _test_precision_recall_curve(y_true, y_score)
823
824    # Use {-1, 1} for labels; make sure original labels aren't modified
825    y_true[np.where(y_true == 0)] = -1
826    y_true_copy = y_true.copy()
827    _test_precision_recall_curve(y_true, y_score)
828    assert_array_equal(y_true_copy, y_true)
829
830    labels = [1, 0, 0, 1]
831    predict_probas = [1, 2, 3, 4]
832    p, r, t = precision_recall_curve(labels, predict_probas)
833    assert_array_almost_equal(p, np.array([0.5, 0.33333333, 0.5, 1.0, 1.0]))
834    assert_array_almost_equal(r, np.array([1.0, 0.5, 0.5, 0.5, 0.0]))
835    assert_array_almost_equal(t, np.array([1, 2, 3, 4]))
836    assert p.size == r.size
837    assert p.size == t.size + 1
838
839
840def _test_precision_recall_curve(y_true, y_score):
841    # Test Precision-Recall and aread under PR curve
842    p, r, thresholds = precision_recall_curve(y_true, y_score)
843    precision_recall_auc = _average_precision_slow(y_true, y_score)
844    assert_array_almost_equal(precision_recall_auc, 0.859, 3)
845    assert_array_almost_equal(
846        precision_recall_auc, average_precision_score(y_true, y_score)
847    )
848    # `_average_precision` is not very precise in case of 0.5 ties: be tolerant
849    assert_almost_equal(
850        _average_precision(y_true, y_score), precision_recall_auc, decimal=2
851    )
852    assert p.size == r.size
853    assert p.size == thresholds.size + 1
854    # Smoke test in the case of proba having only one value
855    p, r, thresholds = precision_recall_curve(y_true, np.zeros_like(y_score))
856    assert p.size == r.size
857    assert p.size == thresholds.size + 1
858
859
860def test_precision_recall_curve_toydata():
861    with np.errstate(all="raise"):
862        # Binary classification
863        y_true = [0, 1]
864        y_score = [0, 1]
865        p, r, _ = precision_recall_curve(y_true, y_score)
866        auc_prc = average_precision_score(y_true, y_score)
867        assert_array_almost_equal(p, [1, 1])
868        assert_array_almost_equal(r, [1, 0])
869        assert_almost_equal(auc_prc, 1.0)
870
871        y_true = [0, 1]
872        y_score = [1, 0]
873        p, r, _ = precision_recall_curve(y_true, y_score)
874        auc_prc = average_precision_score(y_true, y_score)
875        assert_array_almost_equal(p, [0.5, 0.0, 1.0])
876        assert_array_almost_equal(r, [1.0, 0.0, 0.0])
877        # Here we are doing a terrible prediction: we are always getting
878        # it wrong, hence the average_precision_score is the accuracy at
879        # chance: 50%
880        assert_almost_equal(auc_prc, 0.5)
881
882        y_true = [1, 0]
883        y_score = [1, 1]
884        p, r, _ = precision_recall_curve(y_true, y_score)
885        auc_prc = average_precision_score(y_true, y_score)
886        assert_array_almost_equal(p, [0.5, 1])
887        assert_array_almost_equal(r, [1.0, 0])
888        assert_almost_equal(auc_prc, 0.5)
889
890        y_true = [1, 0]
891        y_score = [1, 0]
892        p, r, _ = precision_recall_curve(y_true, y_score)
893        auc_prc = average_precision_score(y_true, y_score)
894        assert_array_almost_equal(p, [1, 1])
895        assert_array_almost_equal(r, [1, 0])
896        assert_almost_equal(auc_prc, 1.0)
897
898        y_true = [1, 0]
899        y_score = [0.5, 0.5]
900        p, r, _ = precision_recall_curve(y_true, y_score)
901        auc_prc = average_precision_score(y_true, y_score)
902        assert_array_almost_equal(p, [0.5, 1])
903        assert_array_almost_equal(r, [1, 0.0])
904        assert_almost_equal(auc_prc, 0.5)
905
906        y_true = [0, 0]
907        y_score = [0.25, 0.75]
908        with pytest.raises(Exception):
909            precision_recall_curve(y_true, y_score)
910        with pytest.raises(Exception):
911            average_precision_score(y_true, y_score)
912
913        y_true = [1, 1]
914        y_score = [0.25, 0.75]
915        p, r, _ = precision_recall_curve(y_true, y_score)
916        assert_almost_equal(average_precision_score(y_true, y_score), 1.0)
917        assert_array_almost_equal(p, [1.0, 1.0, 1.0])
918        assert_array_almost_equal(r, [1, 0.5, 0.0])
919
920        # Multi-label classification task
921        y_true = np.array([[0, 1], [0, 1]])
922        y_score = np.array([[0, 1], [0, 1]])
923        with pytest.raises(Exception):
924            average_precision_score(y_true, y_score, average="macro")
925        with pytest.raises(Exception):
926            average_precision_score(y_true, y_score, average="weighted")
927        assert_almost_equal(
928            average_precision_score(y_true, y_score, average="samples"), 1.0
929        )
930        assert_almost_equal(
931            average_precision_score(y_true, y_score, average="micro"), 1.0
932        )
933
934        y_true = np.array([[0, 1], [0, 1]])
935        y_score = np.array([[0, 1], [1, 0]])
936        with pytest.raises(Exception):
937            average_precision_score(y_true, y_score, average="macro")
938        with pytest.raises(Exception):
939            average_precision_score(y_true, y_score, average="weighted")
940        assert_almost_equal(
941            average_precision_score(y_true, y_score, average="samples"), 0.75
942        )
943        assert_almost_equal(
944            average_precision_score(y_true, y_score, average="micro"), 0.5
945        )
946
947        y_true = np.array([[1, 0], [0, 1]])
948        y_score = np.array([[0, 1], [1, 0]])
949        assert_almost_equal(
950            average_precision_score(y_true, y_score, average="macro"), 0.5
951        )
952        assert_almost_equal(
953            average_precision_score(y_true, y_score, average="weighted"), 0.5
954        )
955        assert_almost_equal(
956            average_precision_score(y_true, y_score, average="samples"), 0.5
957        )
958        assert_almost_equal(
959            average_precision_score(y_true, y_score, average="micro"), 0.5
960        )
961
962        y_true = np.array([[1, 0], [0, 1]])
963        y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
964        assert_almost_equal(
965            average_precision_score(y_true, y_score, average="macro"), 0.5
966        )
967        assert_almost_equal(
968            average_precision_score(y_true, y_score, average="weighted"), 0.5
969        )
970        assert_almost_equal(
971            average_precision_score(y_true, y_score, average="samples"), 0.5
972        )
973        assert_almost_equal(
974            average_precision_score(y_true, y_score, average="micro"), 0.5
975        )
976
977    with np.errstate(all="ignore"):
978        # if one class is never present weighted should not be NaN
979        y_true = np.array([[0, 0], [0, 1]])
980        y_score = np.array([[0, 0], [0, 1]])
981        assert_almost_equal(
982            average_precision_score(y_true, y_score, average="weighted"), 1
983        )
984
985
986def test_average_precision_constant_values():
987    # Check the average_precision_score of a constant predictor is
988    # the TPR
989
990    # Generate a dataset with 25% of positives
991    y_true = np.zeros(100, dtype=int)
992    y_true[::4] = 1
993    # And a constant score
994    y_score = np.ones(100)
995    # The precision is then the fraction of positive whatever the recall
996    # is, as there is only one threshold:
997    assert average_precision_score(y_true, y_score) == 0.25
998
999
1000def test_average_precision_score_pos_label_errors():
1001    # Raise an error when pos_label is not in binary y_true
1002    y_true = np.array([0, 1])
1003    y_pred = np.array([0, 1])
1004    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
1005    with pytest.raises(ValueError, match=err_msg):
1006        average_precision_score(y_true, y_pred, pos_label=2)
1007    # Raise an error for multilabel-indicator y_true with
1008    # pos_label other than 1
1009    y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
1010    y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
1011    err_msg = (
1012        "Parameter pos_label is fixed to 1 for multilabel-indicator y_true. "
1013        "Do not set pos_label or set pos_label to 1."
1014    )
1015    with pytest.raises(ValueError, match=err_msg):
1016        average_precision_score(y_true, y_pred, pos_label=0)
1017
1018
1019def test_score_scale_invariance():
1020    # Test that average_precision_score and roc_auc_score are invariant by
1021    # the scaling or shifting of probabilities
1022    # This test was expanded (added scaled_down) in response to github
1023    # issue #3864 (and others), where overly aggressive rounding was causing
1024    # problems for users with very small y_score values
1025    y_true, _, y_score = make_prediction(binary=True)
1026
1027    roc_auc = roc_auc_score(y_true, y_score)
1028    roc_auc_scaled_up = roc_auc_score(y_true, 100 * y_score)
1029    roc_auc_scaled_down = roc_auc_score(y_true, 1e-6 * y_score)
1030    roc_auc_shifted = roc_auc_score(y_true, y_score - 10)
1031    assert roc_auc == roc_auc_scaled_up
1032    assert roc_auc == roc_auc_scaled_down
1033    assert roc_auc == roc_auc_shifted
1034
1035    pr_auc = average_precision_score(y_true, y_score)
1036    pr_auc_scaled_up = average_precision_score(y_true, 100 * y_score)
1037    pr_auc_scaled_down = average_precision_score(y_true, 1e-6 * y_score)
1038    pr_auc_shifted = average_precision_score(y_true, y_score - 10)
1039    assert pr_auc == pr_auc_scaled_up
1040    assert pr_auc == pr_auc_scaled_down
1041    assert pr_auc == pr_auc_shifted
1042
1043
1044@pytest.mark.parametrize(
1045    "y_true,y_score,expected_fpr,expected_fnr",
1046    [
1047        ([0, 0, 1], [0, 0.5, 1], [0], [0]),
1048        ([0, 0, 1], [0, 0.25, 0.5], [0], [0]),
1049        ([0, 0, 1], [0.5, 0.75, 1], [0], [0]),
1050        ([0, 0, 1], [0.25, 0.5, 0.75], [0], [0]),
1051        ([0, 1, 0], [0, 0.5, 1], [0.5], [0]),
1052        ([0, 1, 0], [0, 0.25, 0.5], [0.5], [0]),
1053        ([0, 1, 0], [0.5, 0.75, 1], [0.5], [0]),
1054        ([0, 1, 0], [0.25, 0.5, 0.75], [0.5], [0]),
1055        ([0, 1, 1], [0, 0.5, 1], [0.0], [0]),
1056        ([0, 1, 1], [0, 0.25, 0.5], [0], [0]),
1057        ([0, 1, 1], [0.5, 0.75, 1], [0], [0]),
1058        ([0, 1, 1], [0.25, 0.5, 0.75], [0], [0]),
1059        ([1, 0, 0], [0, 0.5, 1], [1, 1, 0.5], [0, 1, 1]),
1060        ([1, 0, 0], [0, 0.25, 0.5], [1, 1, 0.5], [0, 1, 1]),
1061        ([1, 0, 0], [0.5, 0.75, 1], [1, 1, 0.5], [0, 1, 1]),
1062        ([1, 0, 0], [0.25, 0.5, 0.75], [1, 1, 0.5], [0, 1, 1]),
1063        ([1, 0, 1], [0, 0.5, 1], [1, 1, 0], [0, 0.5, 0.5]),
1064        ([1, 0, 1], [0, 0.25, 0.5], [1, 1, 0], [0, 0.5, 0.5]),
1065        ([1, 0, 1], [0.5, 0.75, 1], [1, 1, 0], [0, 0.5, 0.5]),
1066        ([1, 0, 1], [0.25, 0.5, 0.75], [1, 1, 0], [0, 0.5, 0.5]),
1067    ],
1068)
1069def test_det_curve_toydata(y_true, y_score, expected_fpr, expected_fnr):
1070    # Check on a batch of small examples.
1071    fpr, fnr, _ = det_curve(y_true, y_score)
1072
1073    assert_allclose(fpr, expected_fpr)
1074    assert_allclose(fnr, expected_fnr)
1075
1076
1077@pytest.mark.parametrize(
1078    "y_true,y_score,expected_fpr,expected_fnr",
1079    [
1080        ([1, 0], [0.5, 0.5], [1], [0]),
1081        ([0, 1], [0.5, 0.5], [1], [0]),
1082        ([0, 0, 1], [0.25, 0.5, 0.5], [0.5], [0]),
1083        ([0, 1, 0], [0.25, 0.5, 0.5], [0.5], [0]),
1084        ([0, 1, 1], [0.25, 0.5, 0.5], [0], [0]),
1085        ([1, 0, 0], [0.25, 0.5, 0.5], [1], [0]),
1086        ([1, 0, 1], [0.25, 0.5, 0.5], [1], [0]),
1087        ([1, 1, 0], [0.25, 0.5, 0.5], [1], [0]),
1088    ],
1089)
1090def test_det_curve_tie_handling(y_true, y_score, expected_fpr, expected_fnr):
1091    fpr, fnr, _ = det_curve(y_true, y_score)
1092
1093    assert_allclose(fpr, expected_fpr)
1094    assert_allclose(fnr, expected_fnr)
1095
1096
1097def test_det_curve_sanity_check():
1098    # Exactly duplicated inputs yield the same result.
1099    assert_allclose(
1100        det_curve([0, 0, 1], [0, 0.5, 1]),
1101        det_curve([0, 0, 0, 0, 1, 1], [0, 0, 0.5, 0.5, 1, 1]),
1102    )
1103
1104
1105@pytest.mark.parametrize("y_score", [(0), (0.25), (0.5), (0.75), (1)])
1106def test_det_curve_constant_scores(y_score):
1107    fpr, fnr, threshold = det_curve(
1108        y_true=[0, 1, 0, 1, 0, 1], y_score=np.full(6, y_score)
1109    )
1110
1111    assert_allclose(fpr, [1])
1112    assert_allclose(fnr, [0])
1113    assert_allclose(threshold, [y_score])
1114
1115
1116@pytest.mark.parametrize(
1117    "y_true",
1118    [
1119        ([0, 0, 0, 0, 0, 1]),
1120        ([0, 0, 0, 0, 1, 1]),
1121        ([0, 0, 0, 1, 1, 1]),
1122        ([0, 0, 1, 1, 1, 1]),
1123        ([0, 1, 1, 1, 1, 1]),
1124    ],
1125)
1126def test_det_curve_perfect_scores(y_true):
1127    fpr, fnr, _ = det_curve(y_true=y_true, y_score=y_true)
1128
1129    assert_allclose(fpr, [0])
1130    assert_allclose(fnr, [0])
1131
1132
1133@pytest.mark.parametrize(
1134    "y_true, y_pred, err_msg",
1135    [
1136        ([0, 1], [0, 0.5, 1], "inconsistent numbers of samples"),
1137        ([0, 1, 1], [0, 0.5], "inconsistent numbers of samples"),
1138        ([0, 0, 0], [0, 0.5, 1], "Only one class present in y_true"),
1139        ([1, 1, 1], [0, 0.5, 1], "Only one class present in y_true"),
1140        (
1141            ["cancer", "cancer", "not cancer"],
1142            [0.2, 0.3, 0.8],
1143            "pos_label is not specified",
1144        ),
1145    ],
1146)
1147def test_det_curve_bad_input(y_true, y_pred, err_msg):
1148    # input variables with inconsistent numbers of samples
1149    with pytest.raises(ValueError, match=err_msg):
1150        det_curve(y_true, y_pred)
1151
1152
1153def test_det_curve_pos_label():
1154    y_true = ["cancer"] * 3 + ["not cancer"] * 7
1155    y_pred_pos_not_cancer = np.array([0.1, 0.4, 0.6, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9])
1156    y_pred_pos_cancer = 1 - y_pred_pos_not_cancer
1157
1158    fpr_pos_cancer, fnr_pos_cancer, th_pos_cancer = det_curve(
1159        y_true,
1160        y_pred_pos_cancer,
1161        pos_label="cancer",
1162    )
1163    fpr_pos_not_cancer, fnr_pos_not_cancer, th_pos_not_cancer = det_curve(
1164        y_true,
1165        y_pred_pos_not_cancer,
1166        pos_label="not cancer",
1167    )
1168
1169    # check that the first threshold will change depending which label we
1170    # consider positive
1171    assert th_pos_cancer[0] == pytest.approx(0.4)
1172    assert th_pos_not_cancer[0] == pytest.approx(0.2)
1173
1174    # check for the symmetry of the fpr and fnr
1175    assert_allclose(fpr_pos_cancer, fnr_pos_not_cancer[::-1])
1176    assert_allclose(fnr_pos_cancer, fpr_pos_not_cancer[::-1])
1177
1178
1179def check_lrap_toy(lrap_score):
1180    # Check on several small example that it works
1181    assert_almost_equal(lrap_score([[0, 1]], [[0.25, 0.75]]), 1)
1182    assert_almost_equal(lrap_score([[0, 1]], [[0.75, 0.25]]), 1 / 2)
1183    assert_almost_equal(lrap_score([[1, 1]], [[0.75, 0.25]]), 1)
1184
1185    assert_almost_equal(lrap_score([[0, 0, 1]], [[0.25, 0.5, 0.75]]), 1)
1186    assert_almost_equal(lrap_score([[0, 1, 0]], [[0.25, 0.5, 0.75]]), 1 / 2)
1187    assert_almost_equal(lrap_score([[0, 1, 1]], [[0.25, 0.5, 0.75]]), 1)
1188    assert_almost_equal(lrap_score([[1, 0, 0]], [[0.25, 0.5, 0.75]]), 1 / 3)
1189    assert_almost_equal(
1190        lrap_score([[1, 0, 1]], [[0.25, 0.5, 0.75]]), (2 / 3 + 1 / 1) / 2
1191    )
1192    assert_almost_equal(
1193        lrap_score([[1, 1, 0]], [[0.25, 0.5, 0.75]]), (2 / 3 + 1 / 2) / 2
1194    )
1195
1196    assert_almost_equal(lrap_score([[0, 0, 1]], [[0.75, 0.5, 0.25]]), 1 / 3)
1197    assert_almost_equal(lrap_score([[0, 1, 0]], [[0.75, 0.5, 0.25]]), 1 / 2)
1198    assert_almost_equal(
1199        lrap_score([[0, 1, 1]], [[0.75, 0.5, 0.25]]), (1 / 2 + 2 / 3) / 2
1200    )
1201    assert_almost_equal(lrap_score([[1, 0, 0]], [[0.75, 0.5, 0.25]]), 1)
1202    assert_almost_equal(lrap_score([[1, 0, 1]], [[0.75, 0.5, 0.25]]), (1 + 2 / 3) / 2)
1203    assert_almost_equal(lrap_score([[1, 1, 0]], [[0.75, 0.5, 0.25]]), 1)
1204    assert_almost_equal(lrap_score([[1, 1, 1]], [[0.75, 0.5, 0.25]]), 1)
1205
1206    assert_almost_equal(lrap_score([[0, 0, 1]], [[0.5, 0.75, 0.25]]), 1 / 3)
1207    assert_almost_equal(lrap_score([[0, 1, 0]], [[0.5, 0.75, 0.25]]), 1)
1208    assert_almost_equal(lrap_score([[0, 1, 1]], [[0.5, 0.75, 0.25]]), (1 + 2 / 3) / 2)
1209    assert_almost_equal(lrap_score([[1, 0, 0]], [[0.5, 0.75, 0.25]]), 1 / 2)
1210    assert_almost_equal(
1211        lrap_score([[1, 0, 1]], [[0.5, 0.75, 0.25]]), (1 / 2 + 2 / 3) / 2
1212    )
1213    assert_almost_equal(lrap_score([[1, 1, 0]], [[0.5, 0.75, 0.25]]), 1)
1214    assert_almost_equal(lrap_score([[1, 1, 1]], [[0.5, 0.75, 0.25]]), 1)
1215
1216    # Tie handling
1217    assert_almost_equal(lrap_score([[1, 0]], [[0.5, 0.5]]), 0.5)
1218    assert_almost_equal(lrap_score([[0, 1]], [[0.5, 0.5]]), 0.5)
1219    assert_almost_equal(lrap_score([[1, 1]], [[0.5, 0.5]]), 1)
1220
1221    assert_almost_equal(lrap_score([[0, 0, 1]], [[0.25, 0.5, 0.5]]), 0.5)
1222    assert_almost_equal(lrap_score([[0, 1, 0]], [[0.25, 0.5, 0.5]]), 0.5)
1223    assert_almost_equal(lrap_score([[0, 1, 1]], [[0.25, 0.5, 0.5]]), 1)
1224    assert_almost_equal(lrap_score([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 1 / 3)
1225    assert_almost_equal(
1226        lrap_score([[1, 0, 1]], [[0.25, 0.5, 0.5]]), (2 / 3 + 1 / 2) / 2
1227    )
1228    assert_almost_equal(
1229        lrap_score([[1, 1, 0]], [[0.25, 0.5, 0.5]]), (2 / 3 + 1 / 2) / 2
1230    )
1231    assert_almost_equal(lrap_score([[1, 1, 1]], [[0.25, 0.5, 0.5]]), 1)
1232
1233    assert_almost_equal(lrap_score([[1, 1, 0]], [[0.5, 0.5, 0.5]]), 2 / 3)
1234
1235    assert_almost_equal(lrap_score([[1, 1, 1, 0]], [[0.5, 0.5, 0.5, 0.5]]), 3 / 4)
1236
1237
1238def check_zero_or_all_relevant_labels(lrap_score):
1239    random_state = check_random_state(0)
1240
1241    for n_labels in range(2, 5):
1242        y_score = random_state.uniform(size=(1, n_labels))
1243        y_score_ties = np.zeros_like(y_score)
1244
1245        # No relevant labels
1246        y_true = np.zeros((1, n_labels))
1247        assert lrap_score(y_true, y_score) == 1.0
1248        assert lrap_score(y_true, y_score_ties) == 1.0
1249
1250        # Only relevant labels
1251        y_true = np.ones((1, n_labels))
1252        assert lrap_score(y_true, y_score) == 1.0
1253        assert lrap_score(y_true, y_score_ties) == 1.0
1254
1255    # Degenerate case: only one label
1256    assert_almost_equal(
1257        lrap_score([[1], [0], [1], [0]], [[0.5], [0.5], [0.5], [0.5]]), 1.0
1258    )
1259
1260
1261def check_lrap_error_raised(lrap_score):
1262    # Raise value error if not appropriate format
1263    with pytest.raises(ValueError):
1264        lrap_score([0, 1, 0], [0.25, 0.3, 0.2])
1265    with pytest.raises(ValueError):
1266        lrap_score([0, 1, 2], [[0.25, 0.75, 0.0], [0.7, 0.3, 0.0], [0.8, 0.2, 0.0]])
1267    with pytest.raises(ValueError):
1268        lrap_score(
1269            [(0), (1), (2)], [[0.25, 0.75, 0.0], [0.7, 0.3, 0.0], [0.8, 0.2, 0.0]]
1270        )
1271
1272    # Check that y_true.shape != y_score.shape raise the proper exception
1273    with pytest.raises(ValueError):
1274        lrap_score([[0, 1], [0, 1]], [0, 1])
1275    with pytest.raises(ValueError):
1276        lrap_score([[0, 1], [0, 1]], [[0, 1]])
1277    with pytest.raises(ValueError):
1278        lrap_score([[0, 1], [0, 1]], [[0], [1]])
1279    with pytest.raises(ValueError):
1280        lrap_score([[0, 1]], [[0, 1], [0, 1]])
1281    with pytest.raises(ValueError):
1282        lrap_score([[0], [1]], [[0, 1], [0, 1]])
1283    with pytest.raises(ValueError):
1284        lrap_score([[0, 1], [0, 1]], [[0], [1]])
1285
1286
1287def check_lrap_only_ties(lrap_score):
1288    # Check tie handling in score
1289    # Basic check with only ties and increasing label space
1290    for n_labels in range(2, 10):
1291        y_score = np.ones((1, n_labels))
1292
1293        # Check for growing number of consecutive relevant
1294        for n_relevant in range(1, n_labels):
1295            # Check for a bunch of positions
1296            for pos in range(n_labels - n_relevant):
1297                y_true = np.zeros((1, n_labels))
1298                y_true[0, pos : pos + n_relevant] = 1
1299                assert_almost_equal(lrap_score(y_true, y_score), n_relevant / n_labels)
1300
1301
1302def check_lrap_without_tie_and_increasing_score(lrap_score):
1303    # Check that Label ranking average precision works for various
1304    # Basic check with increasing label space size and decreasing score
1305    for n_labels in range(2, 10):
1306        y_score = n_labels - (np.arange(n_labels).reshape((1, n_labels)) + 1)
1307
1308        # First and last
1309        y_true = np.zeros((1, n_labels))
1310        y_true[0, 0] = 1
1311        y_true[0, -1] = 1
1312        assert_almost_equal(lrap_score(y_true, y_score), (2 / n_labels + 1) / 2)
1313
1314        # Check for growing number of consecutive relevant label
1315        for n_relevant in range(1, n_labels):
1316            # Check for a bunch of position
1317            for pos in range(n_labels - n_relevant):
1318                y_true = np.zeros((1, n_labels))
1319                y_true[0, pos : pos + n_relevant] = 1
1320                assert_almost_equal(
1321                    lrap_score(y_true, y_score),
1322                    sum(
1323                        (r + 1) / ((pos + r + 1) * n_relevant)
1324                        for r in range(n_relevant)
1325                    ),
1326                )
1327
1328
1329def _my_lrap(y_true, y_score):
1330    """Simple implementation of label ranking average precision"""
1331    check_consistent_length(y_true, y_score)
1332    y_true = check_array(y_true)
1333    y_score = check_array(y_score)
1334    n_samples, n_labels = y_true.shape
1335    score = np.empty((n_samples,))
1336    for i in range(n_samples):
1337        # The best rank correspond to 1. Rank higher than 1 are worse.
1338        # The best inverse ranking correspond to n_labels.
1339        unique_rank, inv_rank = np.unique(y_score[i], return_inverse=True)
1340        n_ranks = unique_rank.size
1341        rank = n_ranks - inv_rank
1342
1343        # Rank need to be corrected to take into account ties
1344        # ex: rank 1 ex aequo means that both label are rank 2.
1345        corr_rank = np.bincount(rank, minlength=n_ranks + 1).cumsum()
1346        rank = corr_rank[rank]
1347
1348        relevant = y_true[i].nonzero()[0]
1349        if relevant.size == 0 or relevant.size == n_labels:
1350            score[i] = 1
1351            continue
1352
1353        score[i] = 0.0
1354        for label in relevant:
1355            # Let's count the number of relevant label with better rank
1356            # (smaller rank).
1357            n_ranked_above = sum(rank[r] <= rank[label] for r in relevant)
1358
1359            # Weight by the rank of the actual label
1360            score[i] += n_ranked_above / rank[label]
1361
1362        score[i] /= relevant.size
1363
1364    return score.mean()
1365
1366
1367def check_alternative_lrap_implementation(
1368    lrap_score, n_classes=5, n_samples=20, random_state=0
1369):
1370    _, y_true = make_multilabel_classification(
1371        n_features=1,
1372        allow_unlabeled=False,
1373        random_state=random_state,
1374        n_classes=n_classes,
1375        n_samples=n_samples,
1376    )
1377
1378    # Score with ties
1379    y_score = _sparse_random_matrix(
1380        n_components=y_true.shape[0],
1381        n_features=y_true.shape[1],
1382        random_state=random_state,
1383    )
1384
1385    if hasattr(y_score, "toarray"):
1386        y_score = y_score.toarray()
1387    score_lrap = label_ranking_average_precision_score(y_true, y_score)
1388    score_my_lrap = _my_lrap(y_true, y_score)
1389    assert_almost_equal(score_lrap, score_my_lrap)
1390
1391    # Uniform score
1392    random_state = check_random_state(random_state)
1393    y_score = random_state.uniform(size=(n_samples, n_classes))
1394    score_lrap = label_ranking_average_precision_score(y_true, y_score)
1395    score_my_lrap = _my_lrap(y_true, y_score)
1396    assert_almost_equal(score_lrap, score_my_lrap)
1397
1398
1399@pytest.mark.parametrize(
1400    "check",
1401    (
1402        check_lrap_toy,
1403        check_lrap_without_tie_and_increasing_score,
1404        check_lrap_only_ties,
1405        check_zero_or_all_relevant_labels,
1406    ),
1407)
1408@pytest.mark.parametrize("func", (label_ranking_average_precision_score, _my_lrap))
1409def test_label_ranking_avp(check, func):
1410    check(func)
1411
1412
1413def test_lrap_error_raised():
1414    check_lrap_error_raised(label_ranking_average_precision_score)
1415
1416
1417@pytest.mark.parametrize("n_samples", (1, 2, 8, 20))
1418@pytest.mark.parametrize("n_classes", (2, 5, 10))
1419@pytest.mark.parametrize("random_state", range(1))
1420def test_alternative_lrap_implementation(n_samples, n_classes, random_state):
1421
1422    check_alternative_lrap_implementation(
1423        label_ranking_average_precision_score, n_classes, n_samples, random_state
1424    )
1425
1426
1427def test_lrap_sample_weighting_zero_labels():
1428    # Degenerate sample labeling (e.g., zero labels for a sample) is a valid
1429    # special case for lrap (the sample is considered to achieve perfect
1430    # precision), but this case is not tested in test_common.
1431    # For these test samples, the APs are 0.5, 0.75, and 1.0 (default for zero
1432    # labels).
1433    y_true = np.array([[1, 0, 0, 0], [1, 0, 0, 1], [0, 0, 0, 0]], dtype=bool)
1434    y_score = np.array(
1435        [[0.3, 0.4, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]]
1436    )
1437    samplewise_lraps = np.array([0.5, 0.75, 1.0])
1438    sample_weight = np.array([1.0, 1.0, 0.0])
1439
1440    assert_almost_equal(
1441        label_ranking_average_precision_score(
1442            y_true, y_score, sample_weight=sample_weight
1443        ),
1444        np.sum(sample_weight * samplewise_lraps) / np.sum(sample_weight),
1445    )
1446
1447
1448def test_coverage_error():
1449    # Toy case
1450    assert_almost_equal(coverage_error([[0, 1]], [[0.25, 0.75]]), 1)
1451    assert_almost_equal(coverage_error([[0, 1]], [[0.75, 0.25]]), 2)
1452    assert_almost_equal(coverage_error([[1, 1]], [[0.75, 0.25]]), 2)
1453    assert_almost_equal(coverage_error([[0, 0]], [[0.75, 0.25]]), 0)
1454
1455    assert_almost_equal(coverage_error([[0, 0, 0]], [[0.25, 0.5, 0.75]]), 0)
1456    assert_almost_equal(coverage_error([[0, 0, 1]], [[0.25, 0.5, 0.75]]), 1)
1457    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.25, 0.5, 0.75]]), 2)
1458    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.25, 0.5, 0.75]]), 2)
1459    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.25, 0.5, 0.75]]), 3)
1460    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.25, 0.5, 0.75]]), 3)
1461    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.25, 0.5, 0.75]]), 3)
1462    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.25, 0.5, 0.75]]), 3)
1463
1464    assert_almost_equal(coverage_error([[0, 0, 0]], [[0.75, 0.5, 0.25]]), 0)
1465    assert_almost_equal(coverage_error([[0, 0, 1]], [[0.75, 0.5, 0.25]]), 3)
1466    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.75, 0.5, 0.25]]), 2)
1467    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.75, 0.5, 0.25]]), 3)
1468    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.75, 0.5, 0.25]]), 1)
1469    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.75, 0.5, 0.25]]), 3)
1470    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.75, 0.5, 0.25]]), 2)
1471    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.75, 0.5, 0.25]]), 3)
1472
1473    assert_almost_equal(coverage_error([[0, 0, 0]], [[0.5, 0.75, 0.25]]), 0)
1474    assert_almost_equal(coverage_error([[0, 0, 1]], [[0.5, 0.75, 0.25]]), 3)
1475    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.5, 0.75, 0.25]]), 1)
1476    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.5, 0.75, 0.25]]), 3)
1477    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.5, 0.75, 0.25]]), 2)
1478    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.5, 0.75, 0.25]]), 3)
1479    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.5, 0.75, 0.25]]), 2)
1480    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.5, 0.75, 0.25]]), 3)
1481
1482    # Non trivial case
1483    assert_almost_equal(
1484        coverage_error([[0, 1, 0], [1, 1, 0]], [[0.1, 10.0, -3], [0, 1, 3]]),
1485        (1 + 3) / 2.0,
1486    )
1487
1488    assert_almost_equal(
1489        coverage_error(
1490            [[0, 1, 0], [1, 1, 0], [0, 1, 1]], [[0.1, 10, -3], [0, 1, 3], [0, 2, 0]]
1491        ),
1492        (1 + 3 + 3) / 3.0,
1493    )
1494
1495    assert_almost_equal(
1496        coverage_error(
1497            [[0, 1, 0], [1, 1, 0], [0, 1, 1]], [[0.1, 10, -3], [3, 1, 3], [0, 2, 0]]
1498        ),
1499        (1 + 3 + 3) / 3.0,
1500    )
1501
1502
1503def test_coverage_tie_handling():
1504    assert_almost_equal(coverage_error([[0, 0]], [[0.5, 0.5]]), 0)
1505    assert_almost_equal(coverage_error([[1, 0]], [[0.5, 0.5]]), 2)
1506    assert_almost_equal(coverage_error([[0, 1]], [[0.5, 0.5]]), 2)
1507    assert_almost_equal(coverage_error([[1, 1]], [[0.5, 0.5]]), 2)
1508
1509    assert_almost_equal(coverage_error([[0, 0, 0]], [[0.25, 0.5, 0.5]]), 0)
1510    assert_almost_equal(coverage_error([[0, 0, 1]], [[0.25, 0.5, 0.5]]), 2)
1511    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.25, 0.5, 0.5]]), 2)
1512    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.25, 0.5, 0.5]]), 2)
1513    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 3)
1514    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 3)
1515    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 3)
1516    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.25, 0.5, 0.5]]), 3)
1517
1518
1519def test_label_ranking_loss():
1520    assert_almost_equal(label_ranking_loss([[0, 1]], [[0.25, 0.75]]), 0)
1521    assert_almost_equal(label_ranking_loss([[0, 1]], [[0.75, 0.25]]), 1)
1522
1523    assert_almost_equal(label_ranking_loss([[0, 0, 1]], [[0.25, 0.5, 0.75]]), 0)
1524    assert_almost_equal(label_ranking_loss([[0, 1, 0]], [[0.25, 0.5, 0.75]]), 1 / 2)
1525    assert_almost_equal(label_ranking_loss([[0, 1, 1]], [[0.25, 0.5, 0.75]]), 0)
1526    assert_almost_equal(label_ranking_loss([[1, 0, 0]], [[0.25, 0.5, 0.75]]), 2 / 2)
1527    assert_almost_equal(label_ranking_loss([[1, 0, 1]], [[0.25, 0.5, 0.75]]), 1 / 2)
1528    assert_almost_equal(label_ranking_loss([[1, 1, 0]], [[0.25, 0.5, 0.75]]), 2 / 2)
1529
1530    # Undefined metrics -  the ranking doesn't matter
1531    assert_almost_equal(label_ranking_loss([[0, 0]], [[0.75, 0.25]]), 0)
1532    assert_almost_equal(label_ranking_loss([[1, 1]], [[0.75, 0.25]]), 0)
1533    assert_almost_equal(label_ranking_loss([[0, 0]], [[0.5, 0.5]]), 0)
1534    assert_almost_equal(label_ranking_loss([[1, 1]], [[0.5, 0.5]]), 0)
1535
1536    assert_almost_equal(label_ranking_loss([[0, 0, 0]], [[0.5, 0.75, 0.25]]), 0)
1537    assert_almost_equal(label_ranking_loss([[1, 1, 1]], [[0.5, 0.75, 0.25]]), 0)
1538    assert_almost_equal(label_ranking_loss([[0, 0, 0]], [[0.25, 0.5, 0.5]]), 0)
1539    assert_almost_equal(label_ranking_loss([[1, 1, 1]], [[0.25, 0.5, 0.5]]), 0)
1540
1541    # Non trivial case
1542    assert_almost_equal(
1543        label_ranking_loss([[0, 1, 0], [1, 1, 0]], [[0.1, 10.0, -3], [0, 1, 3]]),
1544        (0 + 2 / 2) / 2.0,
1545    )
1546
1547    assert_almost_equal(
1548        label_ranking_loss(
1549            [[0, 1, 0], [1, 1, 0], [0, 1, 1]], [[0.1, 10, -3], [0, 1, 3], [0, 2, 0]]
1550        ),
1551        (0 + 2 / 2 + 1 / 2) / 3.0,
1552    )
1553
1554    assert_almost_equal(
1555        label_ranking_loss(
1556            [[0, 1, 0], [1, 1, 0], [0, 1, 1]], [[0.1, 10, -3], [3, 1, 3], [0, 2, 0]]
1557        ),
1558        (0 + 2 / 2 + 1 / 2) / 3.0,
1559    )
1560
1561    # Sparse csr matrices
1562    assert_almost_equal(
1563        label_ranking_loss(
1564            csr_matrix(np.array([[0, 1, 0], [1, 1, 0]])), [[0.1, 10, -3], [3, 1, 3]]
1565        ),
1566        (0 + 2 / 2) / 2.0,
1567    )
1568
1569
1570def test_ranking_appropriate_input_shape():
1571    # Check that y_true.shape != y_score.shape raise the proper exception
1572    with pytest.raises(ValueError):
1573        label_ranking_loss([[0, 1], [0, 1]], [0, 1])
1574    with pytest.raises(ValueError):
1575        label_ranking_loss([[0, 1], [0, 1]], [[0, 1]])
1576    with pytest.raises(ValueError):
1577        label_ranking_loss([[0, 1], [0, 1]], [[0], [1]])
1578    with pytest.raises(ValueError):
1579        label_ranking_loss([[0, 1]], [[0, 1], [0, 1]])
1580    with pytest.raises(ValueError):
1581        label_ranking_loss([[0], [1]], [[0, 1], [0, 1]])
1582    with pytest.raises(ValueError):
1583        label_ranking_loss([[0, 1], [0, 1]], [[0], [1]])
1584
1585
1586def test_ranking_loss_ties_handling():
1587    # Tie handling
1588    assert_almost_equal(label_ranking_loss([[1, 0]], [[0.5, 0.5]]), 1)
1589    assert_almost_equal(label_ranking_loss([[0, 1]], [[0.5, 0.5]]), 1)
1590    assert_almost_equal(label_ranking_loss([[0, 0, 1]], [[0.25, 0.5, 0.5]]), 1 / 2)
1591    assert_almost_equal(label_ranking_loss([[0, 1, 0]], [[0.25, 0.5, 0.5]]), 1 / 2)
1592    assert_almost_equal(label_ranking_loss([[0, 1, 1]], [[0.25, 0.5, 0.5]]), 0)
1593    assert_almost_equal(label_ranking_loss([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 1)
1594    assert_almost_equal(label_ranking_loss([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 1)
1595    assert_almost_equal(label_ranking_loss([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 1)
1596
1597
1598def test_dcg_score():
1599    _, y_true = make_multilabel_classification(random_state=0, n_classes=10)
1600    y_score = -y_true + 1
1601    _test_dcg_score_for(y_true, y_score)
1602    y_true, y_score = np.random.RandomState(0).random_sample((2, 100, 10))
1603    _test_dcg_score_for(y_true, y_score)
1604
1605
1606def _test_dcg_score_for(y_true, y_score):
1607    discount = np.log2(np.arange(y_true.shape[1]) + 2)
1608    ideal = _dcg_sample_scores(y_true, y_true)
1609    score = _dcg_sample_scores(y_true, y_score)
1610    assert (score <= ideal).all()
1611    assert (_dcg_sample_scores(y_true, y_true, k=5) <= ideal).all()
1612    assert ideal.shape == (y_true.shape[0],)
1613    assert score.shape == (y_true.shape[0],)
1614    assert ideal == pytest.approx((np.sort(y_true)[:, ::-1] / discount).sum(axis=1))
1615
1616
1617def test_dcg_ties():
1618    y_true = np.asarray([np.arange(5)])
1619    y_score = np.zeros(y_true.shape)
1620    dcg = _dcg_sample_scores(y_true, y_score)
1621    dcg_ignore_ties = _dcg_sample_scores(y_true, y_score, ignore_ties=True)
1622    discounts = 1 / np.log2(np.arange(2, 7))
1623    assert dcg == pytest.approx([discounts.sum() * y_true.mean()])
1624    assert dcg_ignore_ties == pytest.approx([(discounts * y_true[:, ::-1]).sum()])
1625    y_score[0, 3:] = 1
1626    dcg = _dcg_sample_scores(y_true, y_score)
1627    dcg_ignore_ties = _dcg_sample_scores(y_true, y_score, ignore_ties=True)
1628    assert dcg_ignore_ties == pytest.approx([(discounts * y_true[:, ::-1]).sum()])
1629    assert dcg == pytest.approx(
1630        [
1631            discounts[:2].sum() * y_true[0, 3:].mean()
1632            + discounts[2:].sum() * y_true[0, :3].mean()
1633        ]
1634    )
1635
1636
1637def test_ndcg_ignore_ties_with_k():
1638    a = np.arange(12).reshape((2, 6))
1639    assert ndcg_score(a, a, k=3, ignore_ties=True) == pytest.approx(
1640        ndcg_score(a, a, k=3, ignore_ties=True)
1641    )
1642
1643
1644def test_ndcg_invariant():
1645    y_true = np.arange(70).reshape(7, 10)
1646    y_score = y_true + np.random.RandomState(0).uniform(-0.2, 0.2, size=y_true.shape)
1647    ndcg = ndcg_score(y_true, y_score)
1648    ndcg_no_ties = ndcg_score(y_true, y_score, ignore_ties=True)
1649    assert ndcg == pytest.approx(ndcg_no_ties)
1650    assert ndcg == pytest.approx(1.0)
1651    y_score += 1000
1652    assert ndcg_score(y_true, y_score) == pytest.approx(1.0)
1653
1654
1655@pytest.mark.parametrize("ignore_ties", [True, False])
1656def test_ndcg_toy_examples(ignore_ties):
1657    y_true = 3 * np.eye(7)[:5]
1658    y_score = np.tile(np.arange(6, -1, -1), (5, 1))
1659    y_score_noisy = y_score + np.random.RandomState(0).uniform(
1660        -0.2, 0.2, size=y_score.shape
1661    )
1662    assert _dcg_sample_scores(
1663        y_true, y_score, ignore_ties=ignore_ties
1664    ) == pytest.approx(3 / np.log2(np.arange(2, 7)))
1665    assert _dcg_sample_scores(
1666        y_true, y_score_noisy, ignore_ties=ignore_ties
1667    ) == pytest.approx(3 / np.log2(np.arange(2, 7)))
1668    assert _ndcg_sample_scores(
1669        y_true, y_score, ignore_ties=ignore_ties
1670    ) == pytest.approx(1 / np.log2(np.arange(2, 7)))
1671    assert _dcg_sample_scores(
1672        y_true, y_score, log_base=10, ignore_ties=ignore_ties
1673    ) == pytest.approx(3 / np.log10(np.arange(2, 7)))
1674    assert ndcg_score(y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(
1675        (1 / np.log2(np.arange(2, 7))).mean()
1676    )
1677    assert dcg_score(y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(
1678        (3 / np.log2(np.arange(2, 7))).mean()
1679    )
1680    y_true = 3 * np.ones((5, 7))
1681    expected_dcg_score = (3 / np.log2(np.arange(2, 9))).sum()
1682    assert _dcg_sample_scores(
1683        y_true, y_score, ignore_ties=ignore_ties
1684    ) == pytest.approx(expected_dcg_score * np.ones(5))
1685    assert _ndcg_sample_scores(
1686        y_true, y_score, ignore_ties=ignore_ties
1687    ) == pytest.approx(np.ones(5))
1688    assert dcg_score(y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(
1689        expected_dcg_score
1690    )
1691    assert ndcg_score(y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(1.0)
1692
1693
1694def test_ndcg_score():
1695    _, y_true = make_multilabel_classification(random_state=0, n_classes=10)
1696    y_score = -y_true + 1
1697    _test_ndcg_score_for(y_true, y_score)
1698    y_true, y_score = np.random.RandomState(0).random_sample((2, 100, 10))
1699    _test_ndcg_score_for(y_true, y_score)
1700
1701
1702def _test_ndcg_score_for(y_true, y_score):
1703    ideal = _ndcg_sample_scores(y_true, y_true)
1704    score = _ndcg_sample_scores(y_true, y_score)
1705    assert (score <= ideal).all()
1706    all_zero = (y_true == 0).all(axis=1)
1707    assert ideal[~all_zero] == pytest.approx(np.ones((~all_zero).sum()))
1708    assert ideal[all_zero] == pytest.approx(np.zeros(all_zero.sum()))
1709    assert score[~all_zero] == pytest.approx(
1710        _dcg_sample_scores(y_true, y_score)[~all_zero]
1711        / _dcg_sample_scores(y_true, y_true)[~all_zero]
1712    )
1713    assert score[all_zero] == pytest.approx(np.zeros(all_zero.sum()))
1714    assert ideal.shape == (y_true.shape[0],)
1715    assert score.shape == (y_true.shape[0],)
1716
1717
1718def test_partial_roc_auc_score():
1719    # Check `roc_auc_score` for max_fpr != `None`
1720    y_true = np.array([0, 0, 1, 1])
1721    assert roc_auc_score(y_true, y_true, max_fpr=1) == 1
1722    assert roc_auc_score(y_true, y_true, max_fpr=0.001) == 1
1723    with pytest.raises(ValueError):
1724        assert roc_auc_score(y_true, y_true, max_fpr=-0.1)
1725    with pytest.raises(ValueError):
1726        assert roc_auc_score(y_true, y_true, max_fpr=1.1)
1727    with pytest.raises(ValueError):
1728        assert roc_auc_score(y_true, y_true, max_fpr=0)
1729
1730    y_scores = np.array([0.1, 0, 0.1, 0.01])
1731    roc_auc_with_max_fpr_one = roc_auc_score(y_true, y_scores, max_fpr=1)
1732    unconstrained_roc_auc = roc_auc_score(y_true, y_scores)
1733    assert roc_auc_with_max_fpr_one == unconstrained_roc_auc
1734    assert roc_auc_score(y_true, y_scores, max_fpr=0.3) == 0.5
1735
1736    y_true, y_pred, _ = make_prediction(binary=True)
1737    for max_fpr in np.linspace(1e-4, 1, 5):
1738        assert_almost_equal(
1739            roc_auc_score(y_true, y_pred, max_fpr=max_fpr),
1740            _partial_roc_auc_score(y_true, y_pred, max_fpr),
1741        )
1742
1743
1744@pytest.mark.parametrize(
1745    "y_true, k, true_score",
1746    [
1747        ([0, 1, 2, 3], 1, 0.25),
1748        ([0, 1, 2, 3], 2, 0.5),
1749        ([0, 1, 2, 3], 3, 0.75),
1750    ],
1751)
1752def test_top_k_accuracy_score(y_true, k, true_score):
1753    y_score = np.array(
1754        [
1755            [0.4, 0.3, 0.2, 0.1],
1756            [0.1, 0.3, 0.4, 0.2],
1757            [0.4, 0.1, 0.2, 0.3],
1758            [0.3, 0.2, 0.4, 0.1],
1759        ]
1760    )
1761    score = top_k_accuracy_score(y_true, y_score, k=k)
1762    assert score == pytest.approx(true_score)
1763
1764
1765@pytest.mark.parametrize(
1766    "y_score, k, true_score",
1767    [
1768        (np.array([-1, -1, 1, 1]), 1, 1),
1769        (np.array([-1, 1, -1, 1]), 1, 0.5),
1770        (np.array([-1, 1, -1, 1]), 2, 1),
1771        (np.array([0.2, 0.2, 0.7, 0.7]), 1, 1),
1772        (np.array([0.2, 0.7, 0.2, 0.7]), 1, 0.5),
1773        (np.array([0.2, 0.7, 0.2, 0.7]), 2, 1),
1774    ],
1775)
1776def test_top_k_accuracy_score_binary(y_score, k, true_score):
1777    y_true = [0, 0, 1, 1]
1778
1779    threshold = 0.5 if y_score.min() >= 0 and y_score.max() <= 1 else 0
1780    y_pred = (y_score > threshold).astype(np.int64) if k == 1 else y_true
1781
1782    score = top_k_accuracy_score(y_true, y_score, k=k)
1783    score_acc = accuracy_score(y_true, y_pred)
1784
1785    assert score == score_acc == pytest.approx(true_score)
1786
1787
1788@pytest.mark.parametrize(
1789    "y_true, true_score, labels",
1790    [
1791        (np.array([0, 1, 1, 2]), 0.75, [0, 1, 2, 3]),
1792        (np.array([0, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
1793        (np.array([1, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
1794        (np.array(["a", "e", "e", "a"]), 0.75, ["a", "b", "d", "e"]),
1795    ],
1796)
1797@pytest.mark.parametrize("labels_as_ndarray", [True, False])
1798def test_top_k_accuracy_score_multiclass_with_labels(
1799    y_true, true_score, labels, labels_as_ndarray
1800):
1801    """Test when labels and y_score are multiclass."""
1802    if labels_as_ndarray:
1803        labels = np.asarray(labels)
1804    y_score = np.array(
1805        [
1806            [0.4, 0.3, 0.2, 0.1],
1807            [0.1, 0.3, 0.4, 0.2],
1808            [0.4, 0.1, 0.2, 0.3],
1809            [0.3, 0.2, 0.4, 0.1],
1810        ]
1811    )
1812
1813    score = top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
1814    assert score == pytest.approx(true_score)
1815
1816
1817def test_top_k_accuracy_score_increasing():
1818    # Make sure increasing k leads to a higher score
1819    X, y = datasets.make_classification(
1820        n_classes=10, n_samples=1000, n_informative=10, random_state=0
1821    )
1822
1823    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
1824
1825    clf = LogisticRegression(random_state=0)
1826    clf.fit(X_train, y_train)
1827
1828    for X, y in zip((X_train, X_test), (y_train, y_test)):
1829        scores = [
1830            top_k_accuracy_score(y, clf.predict_proba(X), k=k) for k in range(2, 10)
1831        ]
1832
1833        assert np.all(np.diff(scores) > 0)
1834
1835
1836@pytest.mark.parametrize(
1837    "y_true, k, true_score",
1838    [
1839        ([0, 1, 2, 3], 1, 0.25),
1840        ([0, 1, 2, 3], 2, 0.5),
1841        ([0, 1, 2, 3], 3, 1),
1842    ],
1843)
1844def test_top_k_accuracy_score_ties(y_true, k, true_score):
1845    # Make sure highest indices labels are chosen first in case of ties
1846    y_score = np.array(
1847        [
1848            [5, 5, 7, 0],
1849            [1, 5, 5, 5],
1850            [0, 0, 3, 3],
1851            [1, 1, 1, 1],
1852        ]
1853    )
1854    assert top_k_accuracy_score(y_true, y_score, k=k) == pytest.approx(true_score)
1855
1856
1857@pytest.mark.parametrize(
1858    "y_true, k",
1859    [
1860        ([0, 1, 2, 3], 4),
1861        ([0, 1, 2, 3], 5),
1862    ],
1863)
1864def test_top_k_accuracy_score_warning(y_true, k):
1865    y_score = np.array(
1866        [
1867            [0.4, 0.3, 0.2, 0.1],
1868            [0.1, 0.4, 0.3, 0.2],
1869            [0.2, 0.1, 0.4, 0.3],
1870            [0.3, 0.2, 0.1, 0.4],
1871        ]
1872    )
1873    expected_message = (
1874        r"'k' \(\d+\) greater than or equal to 'n_classes' \(\d+\) will result in a "
1875        "perfect score and is therefore meaningless."
1876    )
1877    with pytest.warns(UndefinedMetricWarning, match=expected_message):
1878        score = top_k_accuracy_score(y_true, y_score, k=k)
1879    assert score == 1
1880
1881
1882@pytest.mark.parametrize(
1883    "y_true, labels, msg",
1884    [
1885        (
1886            [0, 0.57, 1, 2],
1887            None,
1888            "y type must be 'binary' or 'multiclass', got 'continuous'",
1889        ),
1890        (
1891            [0, 1, 2, 3],
1892            None,
1893            r"Number of classes in 'y_true' \(4\) not equal to the number of "
1894            r"classes in 'y_score' \(3\).",
1895        ),
1896        (
1897            ["c", "c", "a", "b"],
1898            ["a", "b", "c", "c"],
1899            "Parameter 'labels' must be unique.",
1900        ),
1901        (["c", "c", "a", "b"], ["a", "c", "b"], "Parameter 'labels' must be ordered."),
1902        (
1903            [0, 0, 1, 2],
1904            [0, 1, 2, 3],
1905            r"Number of given labels \(4\) not equal to the number of classes in "
1906            r"'y_score' \(3\).",
1907        ),
1908        (
1909            [0, 0, 1, 2],
1910            [0, 1, 3],
1911            "'y_true' contains labels not in parameter 'labels'.",
1912        ),
1913    ],
1914)
1915def test_top_k_accuracy_score_error(y_true, labels, msg):
1916    y_score = np.array(
1917        [
1918            [0.2, 0.1, 0.7],
1919            [0.4, 0.3, 0.3],
1920            [0.3, 0.4, 0.3],
1921            [0.4, 0.5, 0.1],
1922        ]
1923    )
1924    with pytest.raises(ValueError, match=msg):
1925        top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
1926