import re
import pytest
import numpy as np
import warnings
from scipy.sparse import csr_matrix

from sklearn import datasets
from sklearn import svm

from sklearn.utils.extmath import softmax
from sklearn.datasets import make_multilabel_classification
from sklearn.random_projection import _sparse_random_matrix
from sklearn.utils.validation import check_array, check_consistent_length
from sklearn.utils.validation import check_random_state

from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal

from sklearn.metrics import accuracy_score
from sklearn.metrics import auc
from sklearn.metrics import average_precision_score
from sklearn.metrics import coverage_error
from sklearn.metrics import det_curve
from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import label_ranking_loss
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics._ranking import _ndcg_sample_scores, _dcg_sample_scores
from sklearn.metrics import ndcg_score, dcg_score
from sklearn.metrics import top_k_accuracy_score

from sklearn.exceptions import UndefinedMetricWarning
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression


###############################################################################
# Utilities for testing

CURVE_FUNCS = [
    det_curve,
    precision_recall_curve,
    roc_curve,
]


def make_prediction(dataset=None, binary=False):
    """Make some classification predictions on a toy dataset using an SVC.

    If binary is True, restrict to a binary classification problem instead
    of a multiclass classification problem.
    """

    if dataset is None:
        # import some data to play with
        dataset = datasets.load_iris()

    X = dataset.data
    y = dataset.target

    if binary:
        # restrict to a binary classification task
        X, y = X[y < 2], y[y < 2]

    n_samples, n_features = X.shape
    p = np.arange(n_samples)

    rng = check_random_state(37)
    rng.shuffle(p)
    X, y = X[p], y[p]
    half = int(n_samples / 2)

    # add noisy features to make the problem harder and avoid perfect results
    rng = np.random.RandomState(0)
    X = np.c_[X, rng.randn(n_samples, 200 * n_features)]

    # run classifier, get class probabilities and label predictions
    clf = svm.SVC(kernel="linear", probability=True, random_state=0)
    y_score = clf.fit(X[:half], y[:half]).predict_proba(X[half:])

    if binary:
        # only interested in probabilities of the positive case
        # XXX: do we really want a special API for the binary case?
        y_score = y_score[:, 1]

    y_pred = clf.predict(X[half:])
    y_true = y[half:]
    return y_true, y_pred, y_score

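
# Illustrative usage of the helper above (a sketch, not exercised by the
# suite): for the binary iris subset, the 100 samples are split in half, so
# all returned arrays have 50 entries.
#
#     y_true, y_pred, y_score = make_prediction(binary=True)
#     assert y_true.shape == y_pred.shape == y_score.shape == (50,)
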

###############################################################################
# Tests


def _auc(y_true, y_score):
    """Alternative implementation to check for correctness of
    `roc_auc_score`."""
    pos_label = np.unique(y_true)[1]

    # Count the number of times positive samples are correctly ranked above
    # negative samples.
    pos = y_score[y_true == pos_label]
    neg = y_score[y_true != pos_label]
    diff_matrix = pos.reshape(1, -1) - neg.reshape(-1, 1)
    n_correct = np.sum(diff_matrix > 0)

    return n_correct / float(len(pos) * len(neg))


def _average_precision(y_true, y_score):
    """Alternative implementation to check for correctness of
    `average_precision_score`.

    Note that this implementation fails on some edge cases.
    For example, for constant predictions e.g. [0.5, 0.5, 0.5],
    y_true = [1, 0, 0] returns an average precision of 0.33...
    but y_true = [0, 0, 1] returns 1.0.
    """
    pos_label = np.unique(y_true)[1]
    n_pos = np.sum(y_true == pos_label)
    order = np.argsort(y_score)[::-1]
    y_score = y_score[order]
    y_true = y_true[order]

    score = 0
    for i in range(len(y_score)):
        if y_true[i] == pos_label:
            # Compute precision up to document i,
            # i.e., the percentage of relevant documents up to document i.
            prec = 0
            for j in range(0, i + 1):
                if y_true[j] == pos_label:
                    prec += 1.0
            prec /= i + 1.0
            score += prec

    return score / n_pos


def _average_precision_slow(y_true, y_score):
    """A second alternative implementation of average precision that closely
    follows the Wikipedia article's definition (see References). This should
    give identical results to `average_precision_score` for all inputs.

    References
    ----------
    .. [1] `Wikipedia entry for the Average precision
       <https://en.wikipedia.org/wiki/Average_precision>`_
    """
    precision, recall, threshold = precision_recall_curve(y_true, y_score)
    precision = list(reversed(precision))
    recall = list(reversed(recall))
    average_precision = 0
    for i in range(1, len(precision)):
        average_precision += precision[i] * (recall[i] - recall[i - 1])
    return average_precision

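
# Illustrative sanity check for the reference implementation above (a minimal
# sketch added for exposition; the test name and toy values are hypothetical
# and not part of the original suite): the step-wise integral over the
# precision-recall curve should agree with the library metric.
def test_average_precision_slow_matches_sklearn_on_toy_input():
    y_true = np.array([0, 1, 1, 0, 1])
    y_score = np.array([0.1, 0.8, 0.35, 0.5, 0.7])
    assert_almost_equal(
        _average_precision_slow(y_true, y_score),
        average_precision_score(y_true, y_score),
    )
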
165 """ 166 167 def _partial_roc(y_true, y_predict, max_fpr): 168 fpr, tpr, _ = roc_curve(y_true, y_predict) 169 new_fpr = fpr[fpr <= max_fpr] 170 new_fpr = np.append(new_fpr, max_fpr) 171 new_tpr = tpr[fpr <= max_fpr] 172 idx_out = np.argmax(fpr > max_fpr) 173 idx_in = idx_out - 1 174 x_interp = [fpr[idx_in], fpr[idx_out]] 175 y_interp = [tpr[idx_in], tpr[idx_out]] 176 new_tpr = np.append(new_tpr, np.interp(max_fpr, x_interp, y_interp)) 177 return (new_fpr, new_tpr) 178 179 new_fpr, new_tpr = _partial_roc(y_true, y_predict, max_fpr) 180 partial_auc = auc(new_fpr, new_tpr) 181 182 # Formula (5) from McClish 1989 183 fpr1 = 0 184 fpr2 = max_fpr 185 min_area = 0.5 * (fpr2 - fpr1) * (fpr2 + fpr1) 186 max_area = fpr2 - fpr1 187 return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) 188 189 190@pytest.mark.parametrize("drop", [True, False]) 191def test_roc_curve(drop): 192 # Test Area under Receiver Operating Characteristic (ROC) curve 193 y_true, _, y_score = make_prediction(binary=True) 194 expected_auc = _auc(y_true, y_score) 195 196 fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=drop) 197 roc_auc = auc(fpr, tpr) 198 assert_array_almost_equal(roc_auc, expected_auc, decimal=2) 199 assert_almost_equal(roc_auc, roc_auc_score(y_true, y_score)) 200 assert fpr.shape == tpr.shape 201 assert fpr.shape == thresholds.shape 202 203 204def test_roc_curve_end_points(): 205 # Make sure that roc_curve returns a curve start at 0 and ending and 206 # 1 even in corner cases 207 rng = np.random.RandomState(0) 208 y_true = np.array([0] * 50 + [1] * 50) 209 y_pred = rng.randint(3, size=100) 210 fpr, tpr, thr = roc_curve(y_true, y_pred, drop_intermediate=True) 211 assert fpr[0] == 0 212 assert fpr[-1] == 1 213 assert fpr.shape == tpr.shape 214 assert fpr.shape == thr.shape 215 216 217def test_roc_returns_consistency(): 218 # Test whether the returned threshold matches up with tpr 219 # make small toy dataset 220 y_true, _, y_score = make_prediction(binary=True) 221 fpr, tpr, thresholds = roc_curve(y_true, y_score) 222 223 # use the given thresholds to determine the tpr 224 tpr_correct = [] 225 for t in thresholds: 226 tp = np.sum((y_score >= t) & y_true) 227 p = np.sum(y_true) 228 tpr_correct.append(1.0 * tp / p) 229 230 # compare tpr and tpr_correct to see if the thresholds' order was correct 231 assert_array_almost_equal(tpr, tpr_correct, decimal=2) 232 assert fpr.shape == tpr.shape 233 assert fpr.shape == thresholds.shape 234 235 236def test_roc_curve_multi(): 237 # roc_curve not applicable for multi-class problems 238 y_true, _, y_score = make_prediction(binary=False) 239 240 with pytest.raises(ValueError): 241 roc_curve(y_true, y_score) 242 243 244def test_roc_curve_confidence(): 245 # roc_curve for confidence scores 246 y_true, _, y_score = make_prediction(binary=True) 247 248 fpr, tpr, thresholds = roc_curve(y_true, y_score - 0.5) 249 roc_auc = auc(fpr, tpr) 250 assert_array_almost_equal(roc_auc, 0.90, decimal=2) 251 assert fpr.shape == tpr.shape 252 assert fpr.shape == thresholds.shape 253 254 255def test_roc_curve_hard(): 256 # roc_curve for hard decisions 257 y_true, pred, y_score = make_prediction(binary=True) 258 259 # always predict one 260 trivial_pred = np.ones(y_true.shape) 261 fpr, tpr, thresholds = roc_curve(y_true, trivial_pred) 262 roc_auc = auc(fpr, tpr) 263 assert_array_almost_equal(roc_auc, 0.50, decimal=2) 264 assert fpr.shape == tpr.shape 265 assert fpr.shape == thresholds.shape 266 267 # always predict zero 268 trivial_pred = np.zeros(y_true.shape) 269 

@pytest.mark.parametrize("drop", [True, False])
def test_roc_curve(drop):
    # Test Area under Receiver Operating Characteristic (ROC) curve
    y_true, _, y_score = make_prediction(binary=True)
    expected_auc = _auc(y_true, y_score)

    fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=drop)
    roc_auc = auc(fpr, tpr)
    assert_array_almost_equal(roc_auc, expected_auc, decimal=2)
    assert_almost_equal(roc_auc, roc_auc_score(y_true, y_score))
    assert fpr.shape == tpr.shape
    assert fpr.shape == thresholds.shape


def test_roc_curve_end_points():
    # Make sure that roc_curve returns a curve that starts at 0 and ends at 1,
    # even in corner cases
    rng = np.random.RandomState(0)
    y_true = np.array([0] * 50 + [1] * 50)
    y_pred = rng.randint(3, size=100)
    fpr, tpr, thr = roc_curve(y_true, y_pred, drop_intermediate=True)
    assert fpr[0] == 0
    assert fpr[-1] == 1
    assert fpr.shape == tpr.shape
    assert fpr.shape == thr.shape


def test_roc_returns_consistency():
    # Test whether the returned thresholds match up with the tpr
    # make small toy dataset
    y_true, _, y_score = make_prediction(binary=True)
    fpr, tpr, thresholds = roc_curve(y_true, y_score)

    # use the given thresholds to determine the tpr
    tpr_correct = []
    for t in thresholds:
        tp = np.sum((y_score >= t) & y_true)
        p = np.sum(y_true)
        tpr_correct.append(1.0 * tp / p)

    # compare tpr and tpr_correct to see if the thresholds' order was correct
    assert_array_almost_equal(tpr, tpr_correct, decimal=2)
    assert fpr.shape == tpr.shape
    assert fpr.shape == thresholds.shape


def test_roc_curve_multi():
    # roc_curve is not applicable for multi-class problems
    y_true, _, y_score = make_prediction(binary=False)

    with pytest.raises(ValueError):
        roc_curve(y_true, y_score)


def test_roc_curve_confidence():
    # roc_curve for confidence scores
    y_true, _, y_score = make_prediction(binary=True)

    fpr, tpr, thresholds = roc_curve(y_true, y_score - 0.5)
    roc_auc = auc(fpr, tpr)
    assert_array_almost_equal(roc_auc, 0.90, decimal=2)
    assert fpr.shape == tpr.shape
    assert fpr.shape == thresholds.shape


def test_roc_curve_hard():
    # roc_curve for hard decisions
    y_true, pred, y_score = make_prediction(binary=True)

    # always predict one
    trivial_pred = np.ones(y_true.shape)
    fpr, tpr, thresholds = roc_curve(y_true, trivial_pred)
    roc_auc = auc(fpr, tpr)
    assert_array_almost_equal(roc_auc, 0.50, decimal=2)
    assert fpr.shape == tpr.shape
    assert fpr.shape == thresholds.shape

    # always predict zero
    trivial_pred = np.zeros(y_true.shape)
    fpr, tpr, thresholds = roc_curve(y_true, trivial_pred)
    roc_auc = auc(fpr, tpr)
    assert_array_almost_equal(roc_auc, 0.50, decimal=2)
    assert fpr.shape == tpr.shape
    assert fpr.shape == thresholds.shape

    # hard decisions
    fpr, tpr, thresholds = roc_curve(y_true, pred)
    roc_auc = auc(fpr, tpr)
    assert_array_almost_equal(roc_auc, 0.78, decimal=2)
    assert fpr.shape == tpr.shape
    assert fpr.shape == thresholds.shape


def test_roc_curve_one_label():
    y_true = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    y_pred = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
    # assert there are warnings
    expected_message = (
        "No negative samples in y_true, false positive value should be meaningless"
    )
    with pytest.warns(UndefinedMetricWarning, match=expected_message):
        fpr, tpr, thresholds = roc_curve(y_true, y_pred)

    # all true labels, all fpr should be nan
    assert_array_equal(fpr, np.full(len(thresholds), np.nan))
    assert fpr.shape == tpr.shape
    assert fpr.shape == thresholds.shape

    # assert there are warnings
    expected_message = (
        "No positive samples in y_true, true positive value should be meaningless"
    )
    with pytest.warns(UndefinedMetricWarning, match=expected_message):
        fpr, tpr, thresholds = roc_curve([1 - x for x in y_true], y_pred)
    # all negative labels, all tpr should be nan
    assert_array_equal(tpr, np.full(len(thresholds), np.nan))
    assert fpr.shape == tpr.shape
    assert fpr.shape == thresholds.shape


def test_roc_curve_toydata():
    # Binary classification
    y_true = [0, 1]
    y_score = [0, 1]
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = roc_auc_score(y_true, y_score)
    assert_array_almost_equal(fpr, [0, 0, 1])
    assert_array_almost_equal(tpr, [0, 1, 1])
    assert_almost_equal(roc_auc, 1.0)

    y_true = [0, 1]
    y_score = [1, 0]
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = roc_auc_score(y_true, y_score)
    assert_array_almost_equal(fpr, [0, 1, 1])
    assert_array_almost_equal(tpr, [0, 0, 1])
    assert_almost_equal(roc_auc, 0.0)

    y_true = [1, 0]
    y_score = [1, 1]
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = roc_auc_score(y_true, y_score)
    assert_array_almost_equal(fpr, [0, 1])
    assert_array_almost_equal(tpr, [0, 1])
    assert_almost_equal(roc_auc, 0.5)

    y_true = [1, 0]
    y_score = [1, 0]
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = roc_auc_score(y_true, y_score)
    assert_array_almost_equal(fpr, [0, 0, 1])
    assert_array_almost_equal(tpr, [0, 1, 1])
    assert_almost_equal(roc_auc, 1.0)

    y_true = [1, 0]
    y_score = [0.5, 0.5]
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = roc_auc_score(y_true, y_score)
    assert_array_almost_equal(fpr, [0, 1])
    assert_array_almost_equal(tpr, [0, 1])
    assert_almost_equal(roc_auc, 0.5)

    y_true = [0, 0]
    y_score = [0.25, 0.75]
    # assert UndefinedMetricWarning because of no positive sample in y_true
    expected_message = (
        "No positive samples in y_true, true positive value should be meaningless"
    )
    with pytest.warns(UndefinedMetricWarning, match=expected_message):
        fpr, tpr, _ = roc_curve(y_true, y_score)

    with pytest.raises(ValueError):
        roc_auc_score(y_true, y_score)
    assert_array_almost_equal(fpr, [0.0, 0.5, 1.0])
    assert_array_almost_equal(tpr, [np.nan, np.nan, np.nan])

    y_true = [1, 1]
    y_score = [0.25, 0.75]
    # assert UndefinedMetricWarning because of no negative sample in y_true
    expected_message = (
        "No negative samples in y_true, false positive value should be meaningless"
    )
    with pytest.warns(UndefinedMetricWarning, match=expected_message):
        fpr, tpr, _ = roc_curve(y_true, y_score)

    with pytest.raises(ValueError):
        roc_auc_score(y_true, y_score)
    assert_array_almost_equal(fpr, [np.nan, np.nan, np.nan])
    assert_array_almost_equal(tpr, [0.0, 0.5, 1.0])

    # Multi-label classification task
    y_true = np.array([[0, 1], [0, 1]])
    y_score = np.array([[0, 1], [0, 1]])
    with pytest.raises(ValueError):
        roc_auc_score(y_true, y_score, average="macro")
    with pytest.raises(ValueError):
        roc_auc_score(y_true, y_score, average="weighted")
    assert_almost_equal(roc_auc_score(y_true, y_score, average="samples"), 1.0)
    assert_almost_equal(roc_auc_score(y_true, y_score, average="micro"), 1.0)

    y_true = np.array([[0, 1], [0, 1]])
    y_score = np.array([[0, 1], [1, 0]])
    with pytest.raises(ValueError):
        roc_auc_score(y_true, y_score, average="macro")
    with pytest.raises(ValueError):
        roc_auc_score(y_true, y_score, average="weighted")
    assert_almost_equal(roc_auc_score(y_true, y_score, average="samples"), 0.5)
    assert_almost_equal(roc_auc_score(y_true, y_score, average="micro"), 0.5)

    y_true = np.array([[1, 0], [0, 1]])
    y_score = np.array([[0, 1], [1, 0]])
    assert_almost_equal(roc_auc_score(y_true, y_score, average="macro"), 0)
    assert_almost_equal(roc_auc_score(y_true, y_score, average="weighted"), 0)
    assert_almost_equal(roc_auc_score(y_true, y_score, average="samples"), 0)
    assert_almost_equal(roc_auc_score(y_true, y_score, average="micro"), 0)

    y_true = np.array([[1, 0], [0, 1]])
    y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
    assert_almost_equal(roc_auc_score(y_true, y_score, average="macro"), 0.5)
    assert_almost_equal(roc_auc_score(y_true, y_score, average="weighted"), 0.5)
    assert_almost_equal(roc_auc_score(y_true, y_score, average="samples"), 0.5)
    assert_almost_equal(roc_auc_score(y_true, y_score, average="micro"), 0.5)


def test_roc_curve_drop_intermediate():
    # Test that drop_intermediate drops the correct thresholds
    y_true = [0, 0, 0, 0, 1, 1]
    y_score = [0.0, 0.2, 0.5, 0.6, 0.7, 1.0]
    fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
    assert_array_almost_equal(thresholds, [2.0, 1.0, 0.7, 0.0])

    # Test dropping thresholds with repeating scores
    y_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    y_score = [0.0, 0.1, 0.6, 0.6, 0.7, 0.8, 0.9, 0.6, 0.7, 0.8, 0.9, 0.9, 1.0]
    fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=True)
    assert_array_almost_equal(thresholds, [2.0, 1.0, 0.9, 0.7, 0.6, 0.0])

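
# For contrast, a sketch (illustrative; exact values assumed for this version
# of scikit-learn) of the first toy case above without dropping intermediate
# thresholds: every distinct score would then appear, preceded by the usual
# `max(y_score) + 1` sentinel.
#
#     fpr, tpr, thresholds = roc_curve(
#         [0, 0, 0, 0, 1, 1],
#         [0.0, 0.2, 0.5, 0.6, 0.7, 1.0],
#         drop_intermediate=False,
#     )
#     # expected: thresholds == [2.0, 1.0, 0.7, 0.6, 0.5, 0.2, 0.0]
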

def test_roc_curve_fpr_tpr_increasing():
    # Ensure that the fpr and tpr returned by roc_curve are non-decreasing.
    # Construct an edge case with float y_score and sample_weight
    # where some adjacent values of fpr and tpr are actually the same.
    y_true = [0, 0, 1, 1, 1]
    y_score = [0.1, 0.7, 0.3, 0.4, 0.5]
    sample_weight = np.repeat(0.2, 5)
    fpr, tpr, _ = roc_curve(y_true, y_score, sample_weight=sample_weight)
    assert (np.diff(fpr) < 0).sum() == 0
    assert (np.diff(tpr) < 0).sum() == 0


def test_auc():
    # Test Area Under Curve (AUC) computation
    x = [0, 1]
    y = [0, 1]
    assert_array_almost_equal(auc(x, y), 0.5)
    x = [1, 0]
    y = [0, 1]
    assert_array_almost_equal(auc(x, y), 0.5)
    x = [1, 0, 0]
    y = [0, 1, 1]
    assert_array_almost_equal(auc(x, y), 0.5)
    x = [0, 1]
    y = [1, 1]
    assert_array_almost_equal(auc(x, y), 1)
    x = [0, 0.5, 1]
    y = [0, 0.5, 1]
    assert_array_almost_equal(auc(x, y), 0.5)


def test_auc_errors():
    # Incompatible shapes
    with pytest.raises(ValueError):
        auc([0.0, 0.5, 1.0], [0.1, 0.2])

    # Too few x values
    with pytest.raises(ValueError):
        auc([0.0], [0.1])

    # x is not in order
    x = [2, 1, 3, 4]
    y = [5, 6, 7, 8]
    error_message = "x is neither increasing nor decreasing : {}".format(np.array(x))
    with pytest.raises(ValueError, match=re.escape(error_message)):
        auc(x, y)


@pytest.mark.parametrize(
    "y_true, labels",
    [
        (np.array([0, 1, 0, 2]), [0, 1, 2]),
        (np.array([0, 1, 0, 2]), None),
        (["a", "b", "a", "c"], ["a", "b", "c"]),
        (["a", "b", "a", "c"], None),
    ],
)
def test_multiclass_ovo_roc_auc_toydata(y_true, labels):
    # Tests the one-vs-one multiclass ROC AUC algorithm
    # on a small example, representative of an expected use case.
    y_scores = np.array(
        [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.35, 0.5, 0.15], [0, 0.2, 0.8]]
    )

    # Used to compute the expected output.
    # Consider labels 0 and 1:
    # positive label is 0, negative label is 1
    score_01 = roc_auc_score([1, 0, 1], [0.1, 0.3, 0.35])
    # positive label is 1, negative label is 0
    score_10 = roc_auc_score([0, 1, 0], [0.8, 0.4, 0.5])
    average_score_01 = (score_01 + score_10) / 2

    # Consider labels 0 and 2:
    score_02 = roc_auc_score([1, 1, 0], [0.1, 0.35, 0])
    score_20 = roc_auc_score([0, 0, 1], [0.1, 0.15, 0.8])
    average_score_02 = (score_02 + score_20) / 2

    # Consider labels 1 and 2:
    score_12 = roc_auc_score([1, 0], [0.4, 0.2])
    score_21 = roc_auc_score([0, 1], [0.3, 0.8])
    average_score_12 = (score_12 + score_21) / 2

    # Unweighted, one-vs-one multiclass ROC AUC algorithm
    ovo_unweighted_score = (average_score_01 + average_score_02 + average_score_12) / 3
    assert_almost_equal(
        roc_auc_score(y_true, y_scores, labels=labels, multi_class="ovo"),
        ovo_unweighted_score,
    )

    # Weighted, one-vs-one multiclass ROC AUC algorithm
    # Each term is weighted by the prevalence for the positive label.
    pair_scores = [average_score_01, average_score_02, average_score_12]
    prevalence = [0.75, 0.75, 0.50]
    ovo_weighted_score = np.average(pair_scores, weights=prevalence)
    assert_almost_equal(
        roc_auc_score(
            y_true, y_scores, labels=labels, multi_class="ovo", average="weighted"
        ),
        ovo_weighted_score,
    )

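
# Note on the hard-coded prevalence above (illustrative arithmetic): for
# y_true = [0, 1, 0, 2], the prevalence of a class pair is the fraction of
# samples labeled with either class of the pair, so pair (0, 1) covers 3 of 4
# samples (0.75), pair (0, 2) covers 3 of 4 (0.75), and pair (1, 2) covers
# 2 of 4 (0.50).
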

@pytest.mark.parametrize(
    "y_true, labels",
    [
        (np.array([0, 2, 0, 2]), [0, 1, 2]),
        (np.array(["a", "d", "a", "d"]), ["a", "b", "d"]),
    ],
)
def test_multiclass_ovo_roc_auc_toydata_binary(y_true, labels):
    # Tests the one-vs-one multiclass ROC AUC algorithm for binary y_true
    # on a small example, representative of an expected use case.
    y_scores = np.array(
        [[0.2, 0.0, 0.8], [0.6, 0.0, 0.4], [0.55, 0.0, 0.45], [0.4, 0.0, 0.6]]
    )

    # Used to compute the expected output.
    # Consider labels 0 and 1:
    # positive label is 0, negative label is 1
    score_01 = roc_auc_score([1, 0, 1, 0], [0.2, 0.6, 0.55, 0.4])
    # positive label is 1, negative label is 0
    score_10 = roc_auc_score([0, 1, 0, 1], [0.8, 0.4, 0.45, 0.6])
    ovo_score = (score_01 + score_10) / 2

    assert_almost_equal(
        roc_auc_score(y_true, y_scores, labels=labels, multi_class="ovo"), ovo_score
    )

    # Weighted, one-vs-one multiclass ROC AUC algorithm
    assert_almost_equal(
        roc_auc_score(
            y_true, y_scores, labels=labels, multi_class="ovo", average="weighted"
        ),
        ovo_score,
    )


@pytest.mark.parametrize(
    "y_true, labels",
    [
        (np.array([0, 1, 2, 2]), None),
        (["a", "b", "c", "c"], None),
        ([0, 1, 2, 2], [0, 1, 2]),
        (["a", "b", "c", "c"], ["a", "b", "c"]),
    ],
)
def test_multiclass_ovr_roc_auc_toydata(y_true, labels):
    # Tests the unweighted, one-vs-rest multiclass ROC AUC algorithm
    # on a small example, representative of an expected use case.
    y_scores = np.array(
        [[1.0, 0.0, 0.0], [0.1, 0.5, 0.4], [0.1, 0.1, 0.8], [0.3, 0.3, 0.4]]
    )
    # Compute the expected result by individually computing the 'one-vs-rest'
    # ROC AUC scores for classes 0, 1, and 2.
    out_0 = roc_auc_score([1, 0, 0, 0], y_scores[:, 0])
    out_1 = roc_auc_score([0, 1, 0, 0], y_scores[:, 1])
    out_2 = roc_auc_score([0, 0, 1, 1], y_scores[:, 2])
    result_unweighted = (out_0 + out_1 + out_2) / 3.0

    assert_almost_equal(
        roc_auc_score(y_true, y_scores, multi_class="ovr", labels=labels),
        result_unweighted,
    )

    # Tests the weighted, one-vs-rest multiclass ROC AUC algorithm
    # on the same input (Provost & Domingos, 2000)
    result_weighted = out_0 * 0.25 + out_1 * 0.25 + out_2 * 0.5
    assert_almost_equal(
        roc_auc_score(
            y_true, y_scores, multi_class="ovr", labels=labels, average="weighted"
        ),
        result_weighted,
    )

'labels'", 660 np.array([0, 1, 2, 3]), 661 [0, 1, 2], 662 ), 663 ], 664) 665@pytest.mark.parametrize("multi_class", ["ovo", "ovr"]) 666def test_roc_auc_score_multiclass_labels_error(msg, y_true, labels, multi_class): 667 y_scores = np.array( 668 [[0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.35, 0.5, 0.15], [0, 0.2, 0.8]] 669 ) 670 671 with pytest.raises(ValueError, match=msg): 672 roc_auc_score(y_true, y_scores, labels=labels, multi_class=multi_class) 673 674 675@pytest.mark.parametrize( 676 "msg, kwargs", 677 [ 678 ( 679 ( 680 r"average must be one of \('macro', 'weighted'\) for " 681 r"multiclass problems" 682 ), 683 {"average": "samples", "multi_class": "ovo"}, 684 ), 685 ( 686 ( 687 r"average must be one of \('macro', 'weighted'\) for " 688 r"multiclass problems" 689 ), 690 {"average": "micro", "multi_class": "ovr"}, 691 ), 692 ( 693 ( 694 r"sample_weight is not supported for multiclass one-vs-one " 695 r"ROC AUC, 'sample_weight' must be None in this case" 696 ), 697 {"multi_class": "ovo", "sample_weight": []}, 698 ), 699 ( 700 ( 701 r"Partial AUC computation not available in multiclass setting, " 702 r"'max_fpr' must be set to `None`, received `max_fpr=0.5` " 703 r"instead" 704 ), 705 {"multi_class": "ovo", "max_fpr": 0.5}, 706 ), 707 ( 708 ( 709 r"multi_class='ovp' is not supported for multiclass ROC AUC, " 710 r"multi_class must be in \('ovo', 'ovr'\)" 711 ), 712 {"multi_class": "ovp"}, 713 ), 714 (r"multi_class must be in \('ovo', 'ovr'\)", {}), 715 ], 716) 717def test_roc_auc_score_multiclass_error(msg, kwargs): 718 # Test that roc_auc_score function returns an error when trying 719 # to compute multiclass AUC for parameters where an output 720 # is not defined. 721 rng = check_random_state(404) 722 y_score = rng.rand(20, 3) 723 y_prob = softmax(y_score) 724 y_true = rng.randint(0, 3, size=20) 725 with pytest.raises(ValueError, match=msg): 726 roc_auc_score(y_true, y_prob, **kwargs) 727 728 729def test_auc_score_non_binary_class(): 730 # Test that roc_auc_score function returns an error when trying 731 # to compute AUC for non-binary class values. 

def test_auc_score_non_binary_class():
    # Test that roc_auc_score function returns an error when trying
    # to compute AUC for non-binary class values.
    rng = check_random_state(404)
    y_pred = rng.rand(10)
    # y_true contains only one class value
    y_true = np.zeros(10, dtype="int")
    err_msg = "ROC AUC score is not defined"
    with pytest.raises(ValueError, match=err_msg):
        roc_auc_score(y_true, y_pred)
    y_true = np.ones(10, dtype="int")
    with pytest.raises(ValueError, match=err_msg):
        roc_auc_score(y_true, y_pred)
    y_true = np.full(10, -1, dtype="int")
    with pytest.raises(ValueError, match=err_msg):
        roc_auc_score(y_true, y_pred)

    with warnings.catch_warnings(record=True):
        rng = check_random_state(404)
        y_pred = rng.rand(10)
        # y_true contains only one class value
        y_true = np.zeros(10, dtype="int")
        with pytest.raises(ValueError, match=err_msg):
            roc_auc_score(y_true, y_pred)
        y_true = np.ones(10, dtype="int")
        with pytest.raises(ValueError, match=err_msg):
            roc_auc_score(y_true, y_pred)
        y_true = np.full(10, -1, dtype="int")
        with pytest.raises(ValueError, match=err_msg):
            roc_auc_score(y_true, y_pred)


@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
def test_binary_clf_curve_multiclass_error(curve_func):
    rng = check_random_state(404)
    y_true = rng.randint(0, 3, size=10)
    y_pred = rng.rand(10)
    msg = "multiclass format is not supported"
    with pytest.raises(ValueError, match=msg):
        curve_func(y_true, y_pred)


@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
def test_binary_clf_curve_implicit_pos_label(curve_func):
    # Check that using string class labels raises an informative
    # error for any supported string dtype:
    msg = (
        "y_true takes value in {'a', 'b'} and pos_label is "
        "not specified: either make y_true take "
        "value in {0, 1} or {-1, 1} or pass pos_label "
        "explicitly."
    )
    with pytest.raises(ValueError, match=msg):
        curve_func(np.array(["a", "b"], dtype="<U1"), [0.0, 1.0])

    with pytest.raises(ValueError, match=msg):
        curve_func(np.array(["a", "b"], dtype=object), [0.0, 1.0])

    # The error message is slightly different for bytes-encoded
    # class labels, but otherwise the behavior is the same:
    msg = (
        "y_true takes value in {b'a', b'b'} and pos_label is "
        "not specified: either make y_true take "
        "value in {0, 1} or {-1, 1} or pass pos_label "
        "explicitly."
    )
    with pytest.raises(ValueError, match=msg):
        curve_func(np.array([b"a", b"b"], dtype="<S1"), [0.0, 1.0])

    # Check that it is possible to use floating point class labels
    # that are interpreted similarly to integer class labels:
    y_pred = [0.0, 1.0, 0.2, 0.42]
    int_curve = curve_func([0, 1, 1, 0], y_pred)
    float_curve = curve_func([0.0, 1.0, 1.0, 0.0], y_pred)
    for int_curve_part, float_curve_part in zip(int_curve, float_curve):
        np.testing.assert_allclose(int_curve_part, float_curve_part)


@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
def test_binary_clf_curve_zero_sample_weight(curve_func):
    y_true = [0, 0, 1, 1, 1]
    y_score = [0.1, 0.2, 0.3, 0.4, 0.5]
    sample_weight = [1, 1, 1, 0.5, 0]

    result_1 = curve_func(y_true, y_score, sample_weight=sample_weight)
    result_2 = curve_func(y_true[:-1], y_score[:-1], sample_weight=sample_weight[:-1])

    for arr_1, arr_2 in zip(result_1, result_2):
        assert_allclose(arr_1, arr_2)


def test_precision_recall_curve():
    y_true, _, y_score = make_prediction(binary=True)
    _test_precision_recall_curve(y_true, y_score)

    # Use {-1, 1} for labels; make sure original labels aren't modified
    y_true[np.where(y_true == 0)] = -1
    y_true_copy = y_true.copy()
    _test_precision_recall_curve(y_true, y_score)
    assert_array_equal(y_true_copy, y_true)

    labels = [1, 0, 0, 1]
    predict_probas = [1, 2, 3, 4]
    p, r, t = precision_recall_curve(labels, predict_probas)
    assert_array_almost_equal(p, np.array([0.5, 0.33333333, 0.5, 1.0, 1.0]))
    assert_array_almost_equal(r, np.array([1.0, 0.5, 0.5, 0.5, 0.0]))
    assert_array_almost_equal(t, np.array([1, 2, 3, 4]))
    assert p.size == r.size
    assert p.size == t.size + 1


def _test_precision_recall_curve(y_true, y_score):
    # Test Precision-Recall and area under PR curve
    p, r, thresholds = precision_recall_curve(y_true, y_score)
    precision_recall_auc = _average_precision_slow(y_true, y_score)
    assert_array_almost_equal(precision_recall_auc, 0.859, 3)
    assert_array_almost_equal(
        precision_recall_auc, average_precision_score(y_true, y_score)
    )
    # `_average_precision` is not very precise in case of 0.5 ties: be tolerant
    assert_almost_equal(
        _average_precision(y_true, y_score), precision_recall_auc, decimal=2
    )
    assert p.size == r.size
    assert p.size == thresholds.size + 1
    # Smoke test in the case of proba having only one value
    p, r, thresholds = precision_recall_curve(y_true, np.zeros_like(y_score))
    assert p.size == r.size
    assert p.size == thresholds.size + 1


def test_precision_recall_curve_toydata():
    with np.errstate(all="raise"):
        # Binary classification
        y_true = [0, 1]
        y_score = [0, 1]
        p, r, _ = precision_recall_curve(y_true, y_score)
        auc_prc = average_precision_score(y_true, y_score)
        assert_array_almost_equal(p, [1, 1])
        assert_array_almost_equal(r, [1, 0])
        assert_almost_equal(auc_prc, 1.0)

        y_true = [0, 1]
        y_score = [1, 0]
        p, r, _ = precision_recall_curve(y_true, y_score)
        auc_prc = average_precision_score(y_true, y_score)
        assert_array_almost_equal(p, [0.5, 0.0, 1.0])
        assert_array_almost_equal(r, [1.0, 0.0, 0.0])
        # Here we are doing a terrible prediction: we are always getting
        # it wrong, hence the average_precision_score is the accuracy at
        # chance: 50%
        assert_almost_equal(auc_prc, 0.5)

        y_true = [1, 0]
        y_score = [1, 1]
        p, r, _ = precision_recall_curve(y_true, y_score)
        auc_prc = average_precision_score(y_true, y_score)
        assert_array_almost_equal(p, [0.5, 1])
        assert_array_almost_equal(r, [1.0, 0])
        assert_almost_equal(auc_prc, 0.5)

        y_true = [1, 0]
        y_score = [1, 0]
        p, r, _ = precision_recall_curve(y_true, y_score)
        auc_prc = average_precision_score(y_true, y_score)
        assert_array_almost_equal(p, [1, 1])
        assert_array_almost_equal(r, [1, 0])
        assert_almost_equal(auc_prc, 1.0)

        y_true = [1, 0]
        y_score = [0.5, 0.5]
        p, r, _ = precision_recall_curve(y_true, y_score)
        auc_prc = average_precision_score(y_true, y_score)
        assert_array_almost_equal(p, [0.5, 1])
        assert_array_almost_equal(r, [1, 0.0])
        assert_almost_equal(auc_prc, 0.5)

        y_true = [0, 0]
        y_score = [0.25, 0.75]
        with pytest.raises(Exception):
            precision_recall_curve(y_true, y_score)
        with pytest.raises(Exception):
            average_precision_score(y_true, y_score)

        y_true = [1, 1]
        y_score = [0.25, 0.75]
        p, r, _ = precision_recall_curve(y_true, y_score)
        assert_almost_equal(average_precision_score(y_true, y_score), 1.0)
        assert_array_almost_equal(p, [1.0, 1.0, 1.0])
        assert_array_almost_equal(r, [1, 0.5, 0.0])

        # Multi-label classification task
        y_true = np.array([[0, 1], [0, 1]])
        y_score = np.array([[0, 1], [0, 1]])
        with pytest.raises(Exception):
            average_precision_score(y_true, y_score, average="macro")
        with pytest.raises(Exception):
            average_precision_score(y_true, y_score, average="weighted")
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="samples"), 1.0
        )
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="micro"), 1.0
        )

        y_true = np.array([[0, 1], [0, 1]])
        y_score = np.array([[0, 1], [1, 0]])
        with pytest.raises(Exception):
            average_precision_score(y_true, y_score, average="macro")
        with pytest.raises(Exception):
            average_precision_score(y_true, y_score, average="weighted")
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="samples"), 0.75
        )
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="micro"), 0.5
        )

        y_true = np.array([[1, 0], [0, 1]])
        y_score = np.array([[0, 1], [1, 0]])
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="macro"), 0.5
        )
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="weighted"), 0.5
        )
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="samples"), 0.5
        )
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="micro"), 0.5
        )

        y_true = np.array([[1, 0], [0, 1]])
        y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="macro"), 0.5
        )
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="weighted"), 0.5
        )
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="samples"), 0.5
        )
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="micro"), 0.5
        )

    with np.errstate(all="ignore"):
        # if one class is never present weighted should not be NaN
        y_true = np.array([[0, 0], [0, 1]])
        y_score = np.array([[0, 0], [0, 1]])
        assert_almost_equal(
            average_precision_score(y_true, y_score, average="weighted"), 1
        )

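
# Worked example for the second multi-label block above (illustrative
# arithmetic): with y_true = [[0, 1], [0, 1]] and y_score = [[0, 1], [1, 0]],
# the first sample ranks its relevant label first (AP = 1) while the second
# ranks it last (AP = 1/2), so average="samples" gives (1 + 0.5) / 2 = 0.75,
# whereas average="micro" pools all four label scores and gives 0.5.
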

def test_average_precision_constant_values():
    # Check that the average_precision_score of a constant predictor is
    # the prevalence of the positive class

    # Generate a dataset with 25% of positives
    y_true = np.zeros(100, dtype=int)
    y_true[::4] = 1
    # And a constant score
    y_score = np.ones(100)
    # The precision is then the fraction of positives whatever the recall
    # is, as there is only one threshold:
    assert average_precision_score(y_true, y_score) == 0.25


def test_average_precision_score_pos_label_errors():
    # Raise an error when pos_label is not in binary y_true
    y_true = np.array([0, 1])
    y_pred = np.array([0, 1])
    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
    with pytest.raises(ValueError, match=err_msg):
        average_precision_score(y_true, y_pred, pos_label=2)
    # Raise an error for multilabel-indicator y_true with
    # pos_label other than 1
    y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
    y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
    err_msg = (
        "Parameter pos_label is fixed to 1 for multilabel-indicator y_true. "
        "Do not set pos_label or set pos_label to 1."
    )
    with pytest.raises(ValueError, match=err_msg):
        average_precision_score(y_true, y_pred, pos_label=0)


def test_score_scale_invariance():
    # Test that average_precision_score and roc_auc_score are invariant to
    # the scaling or shifting of probabilities.
    # This test was expanded (added scaled_down) in response to github
    # issue #3864 (and others), where overly aggressive rounding was causing
    # problems for users with very small y_score values
    y_true, _, y_score = make_prediction(binary=True)

    roc_auc = roc_auc_score(y_true, y_score)
    roc_auc_scaled_up = roc_auc_score(y_true, 100 * y_score)
    roc_auc_scaled_down = roc_auc_score(y_true, 1e-6 * y_score)
    roc_auc_shifted = roc_auc_score(y_true, y_score - 10)
    assert roc_auc == roc_auc_scaled_up
    assert roc_auc == roc_auc_scaled_down
    assert roc_auc == roc_auc_shifted

    pr_auc = average_precision_score(y_true, y_score)
    pr_auc_scaled_up = average_precision_score(y_true, 100 * y_score)
    pr_auc_scaled_down = average_precision_score(y_true, 1e-6 * y_score)
    pr_auc_shifted = average_precision_score(y_true, y_score - 10)
    assert pr_auc == pr_auc_scaled_up
    assert pr_auc == pr_auc_scaled_down
    assert pr_auc == pr_auc_shifted


@pytest.mark.parametrize(
    "y_true,y_score,expected_fpr,expected_fnr",
    [
        ([0, 0, 1], [0, 0.5, 1], [0], [0]),
        ([0, 0, 1], [0, 0.25, 0.5], [0], [0]),
        ([0, 0, 1], [0.5, 0.75, 1], [0], [0]),
        ([0, 0, 1], [0.25, 0.5, 0.75], [0], [0]),
        ([0, 1, 0], [0, 0.5, 1], [0.5], [0]),
        ([0, 1, 0], [0, 0.25, 0.5], [0.5], [0]),
        ([0, 1, 0], [0.5, 0.75, 1], [0.5], [0]),
        ([0, 1, 0], [0.25, 0.5, 0.75], [0.5], [0]),
        ([0, 1, 1], [0, 0.5, 1], [0.0], [0]),
        ([0, 1, 1], [0, 0.25, 0.5], [0], [0]),
        ([0, 1, 1], [0.5, 0.75, 1], [0], [0]),
        ([0, 1, 1], [0.25, 0.5, 0.75], [0], [0]),
        ([1, 0, 0], [0, 0.5, 1], [1, 1, 0.5], [0, 1, 1]),
        ([1, 0, 0], [0, 0.25, 0.5], [1, 1, 0.5], [0, 1, 1]),
        ([1, 0, 0], [0.5, 0.75, 1], [1, 1, 0.5], [0, 1, 1]),
        ([1, 0, 0], [0.25, 0.5, 0.75], [1, 1, 0.5], [0, 1, 1]),
        ([1, 0, 1], [0, 0.5, 1], [1, 1, 0], [0, 0.5, 0.5]),
        ([1, 0, 1], [0, 0.25, 0.5], [1, 1, 0], [0, 0.5, 0.5]),
        ([1, 0, 1], [0.5, 0.75, 1], [1, 1, 0], [0, 0.5, 0.5]),
        ([1, 0, 1], [0.25, 0.5, 0.75], [1, 1, 0], [0, 0.5, 0.5]),
    ],
)
def test_det_curve_toydata(y_true, y_score, expected_fpr, expected_fnr):
    # Check on a batch of small examples.
    fpr, fnr, _ = det_curve(y_true, y_score)

    assert_allclose(fpr, expected_fpr)
    assert_allclose(fnr, expected_fnr)


@pytest.mark.parametrize(
    "y_true,y_score,expected_fpr,expected_fnr",
    [
        ([1, 0], [0.5, 0.5], [1], [0]),
        ([0, 1], [0.5, 0.5], [1], [0]),
        ([0, 0, 1], [0.25, 0.5, 0.5], [0.5], [0]),
        ([0, 1, 0], [0.25, 0.5, 0.5], [0.5], [0]),
        ([0, 1, 1], [0.25, 0.5, 0.5], [0], [0]),
        ([1, 0, 0], [0.25, 0.5, 0.5], [1], [0]),
        ([1, 0, 1], [0.25, 0.5, 0.5], [1], [0]),
        ([1, 1, 0], [0.25, 0.5, 0.5], [1], [0]),
    ],
)
def test_det_curve_tie_handling(y_true, y_score, expected_fpr, expected_fnr):
    fpr, fnr, _ = det_curve(y_true, y_score)

    assert_allclose(fpr, expected_fpr)
    assert_allclose(fnr, expected_fnr)


def test_det_curve_sanity_check():
    # Exactly duplicated inputs yield the same result.
    assert_allclose(
        det_curve([0, 0, 1], [0, 0.5, 1]),
        det_curve([0, 0, 0, 0, 1, 1], [0, 0, 0.5, 0.5, 1, 1]),
    )


@pytest.mark.parametrize("y_score", [(0), (0.25), (0.5), (0.75), (1)])
def test_det_curve_constant_scores(y_score):
    fpr, fnr, threshold = det_curve(
        y_true=[0, 1, 0, 1, 0, 1], y_score=np.full(6, y_score)
    )

    assert_allclose(fpr, [1])
    assert_allclose(fnr, [0])
    assert_allclose(threshold, [y_score])


@pytest.mark.parametrize(
    "y_true",
    [
        ([0, 0, 0, 0, 0, 1]),
        ([0, 0, 0, 0, 1, 1]),
        ([0, 0, 0, 1, 1, 1]),
        ([0, 0, 1, 1, 1, 1]),
        ([0, 1, 1, 1, 1, 1]),
    ],
)
def test_det_curve_perfect_scores(y_true):
    fpr, fnr, _ = det_curve(y_true=y_true, y_score=y_true)

    assert_allclose(fpr, [0])
    assert_allclose(fnr, [0])


@pytest.mark.parametrize(
    "y_true, y_pred, err_msg",
    [
        ([0, 1], [0, 0.5, 1], "inconsistent numbers of samples"),
        ([0, 1, 1], [0, 0.5], "inconsistent numbers of samples"),
        ([0, 0, 0], [0, 0.5, 1], "Only one class present in y_true"),
        ([1, 1, 1], [0, 0.5, 1], "Only one class present in y_true"),
        (
            ["cancer", "cancer", "not cancer"],
            [0.2, 0.3, 0.8],
            "pos_label is not specified",
        ),
    ],
)
def test_det_curve_bad_input(y_true, y_pred, err_msg):
    # input variables with inconsistent numbers of samples
    with pytest.raises(ValueError, match=err_msg):
        det_curve(y_true, y_pred)


def test_det_curve_pos_label():
    y_true = ["cancer"] * 3 + ["not cancer"] * 7
    y_pred_pos_not_cancer = np.array([0.1, 0.4, 0.6, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9])
    y_pred_pos_cancer = 1 - y_pred_pos_not_cancer

    fpr_pos_cancer, fnr_pos_cancer, th_pos_cancer = det_curve(
        y_true,
        y_pred_pos_cancer,
        pos_label="cancer",
    )
    fpr_pos_not_cancer, fnr_pos_not_cancer, th_pos_not_cancer = det_curve(
        y_true,
        y_pred_pos_not_cancer,
        pos_label="not cancer",
    )

    # check that the first threshold changes depending on which label is
    # considered positive
    assert th_pos_cancer[0] == pytest.approx(0.4)
    assert th_pos_not_cancer[0] == pytest.approx(0.2)

    # check the symmetry of the fpr and fnr
    assert_allclose(fpr_pos_cancer, fnr_pos_not_cancer[::-1])
    assert_allclose(fnr_pos_cancer, fpr_pos_not_cancer[::-1])

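
# Worked example of label ranking average precision before the toy checks
# below (illustrative arithmetic): for y_true = [[1, 0, 1]] and
# y_score = [[0.75, 0.5, 0.25]], the two relevant labels are ranked 1 and 3,
# contributing precisions 1/1 and 2/3, so the score is (1 + 2/3) / 2.
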

def check_lrap_toy(lrap_score):
    # Check on several small examples that it works
    assert_almost_equal(lrap_score([[0, 1]], [[0.25, 0.75]]), 1)
    assert_almost_equal(lrap_score([[0, 1]], [[0.75, 0.25]]), 1 / 2)
    assert_almost_equal(lrap_score([[1, 1]], [[0.75, 0.25]]), 1)

    assert_almost_equal(lrap_score([[0, 0, 1]], [[0.25, 0.5, 0.75]]), 1)
    assert_almost_equal(lrap_score([[0, 1, 0]], [[0.25, 0.5, 0.75]]), 1 / 2)
    assert_almost_equal(lrap_score([[0, 1, 1]], [[0.25, 0.5, 0.75]]), 1)
    assert_almost_equal(lrap_score([[1, 0, 0]], [[0.25, 0.5, 0.75]]), 1 / 3)
    assert_almost_equal(
        lrap_score([[1, 0, 1]], [[0.25, 0.5, 0.75]]), (2 / 3 + 1 / 1) / 2
    )
    assert_almost_equal(
        lrap_score([[1, 1, 0]], [[0.25, 0.5, 0.75]]), (2 / 3 + 1 / 2) / 2
    )

    assert_almost_equal(lrap_score([[0, 0, 1]], [[0.75, 0.5, 0.25]]), 1 / 3)
    assert_almost_equal(lrap_score([[0, 1, 0]], [[0.75, 0.5, 0.25]]), 1 / 2)
    assert_almost_equal(
        lrap_score([[0, 1, 1]], [[0.75, 0.5, 0.25]]), (1 / 2 + 2 / 3) / 2
    )
    assert_almost_equal(lrap_score([[1, 0, 0]], [[0.75, 0.5, 0.25]]), 1)
    assert_almost_equal(lrap_score([[1, 0, 1]], [[0.75, 0.5, 0.25]]), (1 + 2 / 3) / 2)
    assert_almost_equal(lrap_score([[1, 1, 0]], [[0.75, 0.5, 0.25]]), 1)
    assert_almost_equal(lrap_score([[1, 1, 1]], [[0.75, 0.5, 0.25]]), 1)

    assert_almost_equal(lrap_score([[0, 0, 1]], [[0.5, 0.75, 0.25]]), 1 / 3)
    assert_almost_equal(lrap_score([[0, 1, 0]], [[0.5, 0.75, 0.25]]), 1)
    assert_almost_equal(lrap_score([[0, 1, 1]], [[0.5, 0.75, 0.25]]), (1 + 2 / 3) / 2)
    assert_almost_equal(lrap_score([[1, 0, 0]], [[0.5, 0.75, 0.25]]), 1 / 2)
    assert_almost_equal(
        lrap_score([[1, 0, 1]], [[0.5, 0.75, 0.25]]), (1 / 2 + 2 / 3) / 2
    )
    assert_almost_equal(lrap_score([[1, 1, 0]], [[0.5, 0.75, 0.25]]), 1)
    assert_almost_equal(lrap_score([[1, 1, 1]], [[0.5, 0.75, 0.25]]), 1)

    # Tie handling
    assert_almost_equal(lrap_score([[1, 0]], [[0.5, 0.5]]), 0.5)
    assert_almost_equal(lrap_score([[0, 1]], [[0.5, 0.5]]), 0.5)
    assert_almost_equal(lrap_score([[1, 1]], [[0.5, 0.5]]), 1)

    assert_almost_equal(lrap_score([[0, 0, 1]], [[0.25, 0.5, 0.5]]), 0.5)
    assert_almost_equal(lrap_score([[0, 1, 0]], [[0.25, 0.5, 0.5]]), 0.5)
    assert_almost_equal(lrap_score([[0, 1, 1]], [[0.25, 0.5, 0.5]]), 1)
    assert_almost_equal(lrap_score([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 1 / 3)
    assert_almost_equal(
        lrap_score([[1, 0, 1]], [[0.25, 0.5, 0.5]]), (2 / 3 + 1 / 2) / 2
    )
    assert_almost_equal(
        lrap_score([[1, 1, 0]], [[0.25, 0.5, 0.5]]), (2 / 3 + 1 / 2) / 2
    )
    assert_almost_equal(lrap_score([[1, 1, 1]], [[0.25, 0.5, 0.5]]), 1)

    assert_almost_equal(lrap_score([[1, 1, 0]], [[0.5, 0.5, 0.5]]), 2 / 3)

    assert_almost_equal(lrap_score([[1, 1, 1, 0]], [[0.5, 0.5, 0.5, 0.5]]), 3 / 4)


def check_zero_or_all_relevant_labels(lrap_score):
    random_state = check_random_state(0)

    for n_labels in range(2, 5):
        y_score = random_state.uniform(size=(1, n_labels))
        y_score_ties = np.zeros_like(y_score)

        # No relevant labels
        y_true = np.zeros((1, n_labels))
        assert lrap_score(y_true, y_score) == 1.0
        assert lrap_score(y_true, y_score_ties) == 1.0

        # Only relevant labels
        y_true = np.ones((1, n_labels))
        assert lrap_score(y_true, y_score) == 1.0
        assert lrap_score(y_true, y_score_ties) == 1.0

    # Degenerate case: only one label
    assert_almost_equal(
        lrap_score([[1], [0], [1], [0]], [[0.5], [0.5], [0.5], [0.5]]), 1.0
    )


def check_lrap_error_raised(lrap_score):
    # Raise a ValueError if the input is not in an appropriate format
    with pytest.raises(ValueError):
        lrap_score([0, 1, 0], [0.25, 0.3, 0.2])
    with pytest.raises(ValueError):
        lrap_score([0, 1, 2], [[0.25, 0.75, 0.0], [0.7, 0.3, 0.0], [0.8, 0.2, 0.0]])
    with pytest.raises(ValueError):
        lrap_score(
            [(0), (1), (2)], [[0.25, 0.75, 0.0], [0.7, 0.3, 0.0], [0.8, 0.2, 0.0]]
        )

    # Check that y_true.shape != y_score.shape raises the proper exception
    with pytest.raises(ValueError):
        lrap_score([[0, 1], [0, 1]], [0, 1])
    with pytest.raises(ValueError):
        lrap_score([[0, 1], [0, 1]], [[0, 1]])
    with pytest.raises(ValueError):
        lrap_score([[0, 1], [0, 1]], [[0], [1]])
    with pytest.raises(ValueError):
        lrap_score([[0, 1]], [[0, 1], [0, 1]])
    with pytest.raises(ValueError):
        lrap_score([[0], [1]], [[0, 1], [0, 1]])
    with pytest.raises(ValueError):
        lrap_score([[0, 1], [0, 1]], [[0], [1]])


def check_lrap_only_ties(lrap_score):
    # Check tie handling in score
    # Basic check with only ties and increasing label space
    for n_labels in range(2, 10):
        y_score = np.ones((1, n_labels))

        # Check for a growing number of consecutive relevant labels
        for n_relevant in range(1, n_labels):
            # Check for a bunch of positions
            for pos in range(n_labels - n_relevant):
                y_true = np.zeros((1, n_labels))
                y_true[0, pos : pos + n_relevant] = 1
                assert_almost_equal(lrap_score(y_true, y_score), n_relevant / n_labels)


def check_lrap_without_tie_and_increasing_score(lrap_score):
    # Check that label ranking average precision works for various inputs:
    # basic check with increasing label space size and decreasing scores
    for n_labels in range(2, 10):
        y_score = n_labels - (np.arange(n_labels).reshape((1, n_labels)) + 1)

        # First and last
        y_true = np.zeros((1, n_labels))
        y_true[0, 0] = 1
        y_true[0, -1] = 1
        assert_almost_equal(lrap_score(y_true, y_score), (2 / n_labels + 1) / 2)

        # Check for a growing number of consecutive relevant labels
        for n_relevant in range(1, n_labels):
            # Check for a bunch of positions
            for pos in range(n_labels - n_relevant):
                y_true = np.zeros((1, n_labels))
                y_true[0, pos : pos + n_relevant] = 1
                assert_almost_equal(
                    lrap_score(y_true, y_score),
                    sum(
                        (r + 1) / ((pos + r + 1) * n_relevant)
                        for r in range(n_relevant)
                    ),
                )

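
# Worked instance of the closed form above (illustrative arithmetic): with
# n_labels = 5, pos = 1 and n_relevant = 2, the relevant labels sit at ranks
# 2 and 3, so the expected score is
# 1 / (2 * 2) + 2 / (3 * 2) = (1 / 2 + 2 / 3) / 2 ~= 0.5833.
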

def _my_lrap(y_true, y_score):
    """Simple implementation of label ranking average precision"""
    check_consistent_length(y_true, y_score)
    y_true = check_array(y_true)
    y_score = check_array(y_score)
    n_samples, n_labels = y_true.shape
    score = np.empty((n_samples,))
    for i in range(n_samples):
        # The best rank corresponds to 1. Ranks higher than 1 are worse.
        # The best inverse ranking corresponds to n_labels.
        unique_rank, inv_rank = np.unique(y_score[i], return_inverse=True)
        n_ranks = unique_rank.size
        rank = n_ranks - inv_rank

        # Ranks need to be corrected to take ties into account,
        # e.g., two labels tied for rank 1 are both assigned rank 2.
        corr_rank = np.bincount(rank, minlength=n_ranks + 1).cumsum()
        rank = corr_rank[rank]

        relevant = y_true[i].nonzero()[0]
        if relevant.size == 0 or relevant.size == n_labels:
            score[i] = 1
            continue

        score[i] = 0.0
        for label in relevant:
            # Count the number of relevant labels ranked at least as high,
            # i.e., with a smaller or equal corrected rank.
            n_ranked_above = sum(rank[r] <= rank[label] for r in relevant)

            # Weight by the rank of the actual label
            score[i] += n_ranked_above / rank[label]

        score[i] /= relevant.size

    return score.mean()


def check_alternative_lrap_implementation(
    lrap_score, n_classes=5, n_samples=20, random_state=0
):
    _, y_true = make_multilabel_classification(
        n_features=1,
        allow_unlabeled=False,
        random_state=random_state,
        n_classes=n_classes,
        n_samples=n_samples,
    )

    # Score with ties
    y_score = _sparse_random_matrix(
        n_components=y_true.shape[0],
        n_features=y_true.shape[1],
        random_state=random_state,
    )

    if hasattr(y_score, "toarray"):
        y_score = y_score.toarray()
    score_lrap = label_ranking_average_precision_score(y_true, y_score)
    score_my_lrap = _my_lrap(y_true, y_score)
    assert_almost_equal(score_lrap, score_my_lrap)

    # Uniform score
    random_state = check_random_state(random_state)
    y_score = random_state.uniform(size=(n_samples, n_classes))
    score_lrap = label_ranking_average_precision_score(y_true, y_score)
    score_my_lrap = _my_lrap(y_true, y_score)
    assert_almost_equal(score_lrap, score_my_lrap)


@pytest.mark.parametrize(
    "check",
    (
        check_lrap_toy,
        check_lrap_without_tie_and_increasing_score,
        check_lrap_only_ties,
        check_zero_or_all_relevant_labels,
    ),
)
@pytest.mark.parametrize("func", (label_ranking_average_precision_score, _my_lrap))
def test_label_ranking_avp(check, func):
    check(func)


def test_lrap_error_raised():
    check_lrap_error_raised(label_ranking_average_precision_score)


@pytest.mark.parametrize("n_samples", (1, 2, 8, 20))
@pytest.mark.parametrize("n_classes", (2, 5, 10))
@pytest.mark.parametrize("random_state", range(1))
def test_alternative_lrap_implementation(n_samples, n_classes, random_state):
    check_alternative_lrap_implementation(
        label_ranking_average_precision_score, n_classes, n_samples, random_state
    )


def test_lrap_sample_weighting_zero_labels():
    # Degenerate sample labeling (e.g., zero labels for a sample) is a valid
    # special case for lrap (the sample is considered to achieve perfect
    # precision), but this case is not tested in test_common.
    # For these test samples, the APs are 0.5, 0.75, and 1.0 (default for zero
    # labels).
    y_true = np.array([[1, 0, 0, 0], [1, 0, 0, 1], [0, 0, 0, 0]], dtype=bool)
    y_score = np.array(
        [[0.3, 0.4, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]]
    )
    samplewise_lraps = np.array([0.5, 0.75, 1.0])
    sample_weight = np.array([1.0, 1.0, 0.0])

    assert_almost_equal(
        label_ranking_average_precision_score(
            y_true, y_score, sample_weight=sample_weight
        ),
        np.sum(sample_weight * samplewise_lraps) / np.sum(sample_weight),
    )


def test_coverage_error():
    # Toy case
    assert_almost_equal(coverage_error([[0, 1]], [[0.25, 0.75]]), 1)
    assert_almost_equal(coverage_error([[0, 1]], [[0.75, 0.25]]), 2)
    assert_almost_equal(coverage_error([[1, 1]], [[0.75, 0.25]]), 2)
    assert_almost_equal(coverage_error([[0, 0]], [[0.75, 0.25]]), 0)

    assert_almost_equal(coverage_error([[0, 0, 0]], [[0.25, 0.5, 0.75]]), 0)
    assert_almost_equal(coverage_error([[0, 0, 1]], [[0.25, 0.5, 0.75]]), 1)
    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.25, 0.5, 0.75]]), 2)
    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.25, 0.5, 0.75]]), 2)
    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.25, 0.5, 0.75]]), 3)
    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.25, 0.5, 0.75]]), 3)
    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.25, 0.5, 0.75]]), 3)
    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.25, 0.5, 0.75]]), 3)

    assert_almost_equal(coverage_error([[0, 0, 0]], [[0.75, 0.5, 0.25]]), 0)
    assert_almost_equal(coverage_error([[0, 0, 1]], [[0.75, 0.5, 0.25]]), 3)
    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.75, 0.5, 0.25]]), 2)
    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.75, 0.5, 0.25]]), 3)
    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.75, 0.5, 0.25]]), 1)
    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.75, 0.5, 0.25]]), 3)
    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.75, 0.5, 0.25]]), 2)
    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.75, 0.5, 0.25]]), 3)

    assert_almost_equal(coverage_error([[0, 0, 0]], [[0.5, 0.75, 0.25]]), 0)
    assert_almost_equal(coverage_error([[0, 0, 1]], [[0.5, 0.75, 0.25]]), 3)
    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.5, 0.75, 0.25]]), 1)
    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.5, 0.75, 0.25]]), 3)
    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.5, 0.75, 0.25]]), 2)
    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.5, 0.75, 0.25]]), 3)
    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.5, 0.75, 0.25]]), 2)
    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.5, 0.75, 0.25]]), 3)

    # Non-trivial case
    assert_almost_equal(
        coverage_error([[0, 1, 0], [1, 1, 0]], [[0.1, 10.0, -3], [0, 1, 3]]),
        (1 + 3) / 2.0,
    )

    assert_almost_equal(
        coverage_error(
            [[0, 1, 0], [1, 1, 0], [0, 1, 1]], [[0.1, 10, -3], [0, 1, 3], [0, 2, 0]]
        ),
        (1 + 3 + 3) / 3.0,
    )

    assert_almost_equal(
        coverage_error(
            [[0, 1, 0], [1, 1, 0], [0, 1, 1]], [[0.1, 10, -3], [3, 1, 3], [0, 2, 0]]
        ),
        (1 + 3 + 3) / 3.0,
    )

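
# Worked example for the first non-trivial coverage case above (illustrative
# arithmetic): for y_true = [[0, 1, 0], [1, 1, 0]] with scores
# [[0.1, 10.0, -3], [0, 1, 3]], the first sample's only relevant label has the
# top score (coverage 1), while the second sample's lowest-scored relevant
# label is ranked third (coverage 3), hence (1 + 3) / 2.
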

def test_coverage_tie_handling():
    assert_almost_equal(coverage_error([[0, 0]], [[0.5, 0.5]]), 0)
    assert_almost_equal(coverage_error([[1, 0]], [[0.5, 0.5]]), 2)
    assert_almost_equal(coverage_error([[0, 1]], [[0.5, 0.5]]), 2)
    assert_almost_equal(coverage_error([[1, 1]], [[0.5, 0.5]]), 2)

    assert_almost_equal(coverage_error([[0, 0, 0]], [[0.25, 0.5, 0.5]]), 0)
    assert_almost_equal(coverage_error([[0, 0, 1]], [[0.25, 0.5, 0.5]]), 2)
    assert_almost_equal(coverage_error([[0, 1, 0]], [[0.25, 0.5, 0.5]]), 2)
    assert_almost_equal(coverage_error([[0, 1, 1]], [[0.25, 0.5, 0.5]]), 2)
    assert_almost_equal(coverage_error([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 3)
    assert_almost_equal(coverage_error([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 3)
    assert_almost_equal(coverage_error([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 3)
    assert_almost_equal(coverage_error([[1, 1, 1]], [[0.25, 0.5, 0.5]]), 3)


def test_label_ranking_loss():
    assert_almost_equal(label_ranking_loss([[0, 1]], [[0.25, 0.75]]), 0)
    assert_almost_equal(label_ranking_loss([[0, 1]], [[0.75, 0.25]]), 1)

    assert_almost_equal(label_ranking_loss([[0, 0, 1]], [[0.25, 0.5, 0.75]]), 0)
    assert_almost_equal(label_ranking_loss([[0, 1, 0]], [[0.25, 0.5, 0.75]]), 1 / 2)
    assert_almost_equal(label_ranking_loss([[0, 1, 1]], [[0.25, 0.5, 0.75]]), 0)
    assert_almost_equal(label_ranking_loss([[1, 0, 0]], [[0.25, 0.5, 0.75]]), 2 / 2)
    assert_almost_equal(label_ranking_loss([[1, 0, 1]], [[0.25, 0.5, 0.75]]), 1 / 2)
    assert_almost_equal(label_ranking_loss([[1, 1, 0]], [[0.25, 0.5, 0.75]]), 2 / 2)

    # Undefined metrics - the ranking doesn't matter
    assert_almost_equal(label_ranking_loss([[0, 0]], [[0.75, 0.25]]), 0)
    assert_almost_equal(label_ranking_loss([[1, 1]], [[0.75, 0.25]]), 0)
    assert_almost_equal(label_ranking_loss([[0, 0]], [[0.5, 0.5]]), 0)
    assert_almost_equal(label_ranking_loss([[1, 1]], [[0.5, 0.5]]), 0)

    assert_almost_equal(label_ranking_loss([[0, 0, 0]], [[0.5, 0.75, 0.25]]), 0)
    assert_almost_equal(label_ranking_loss([[1, 1, 1]], [[0.5, 0.75, 0.25]]), 0)
    assert_almost_equal(label_ranking_loss([[0, 0, 0]], [[0.25, 0.5, 0.5]]), 0)
    assert_almost_equal(label_ranking_loss([[1, 1, 1]], [[0.25, 0.5, 0.5]]), 0)

    # Non-trivial case
    assert_almost_equal(
        label_ranking_loss([[0, 1, 0], [1, 1, 0]], [[0.1, 10.0, -3], [0, 1, 3]]),
        (0 + 2 / 2) / 2.0,
    )

    assert_almost_equal(
        label_ranking_loss(
            [[0, 1, 0], [1, 1, 0], [0, 1, 1]], [[0.1, 10, -3], [0, 1, 3], [0, 2, 0]]
        ),
        (0 + 2 / 2 + 1 / 2) / 3.0,
    )

    assert_almost_equal(
        label_ranking_loss(
            [[0, 1, 0], [1, 1, 0], [0, 1, 1]], [[0.1, 10, -3], [3, 1, 3], [0, 2, 0]]
        ),
        (0 + 2 / 2 + 1 / 2) / 3.0,
    )

    # Sparse csr matrices
    assert_almost_equal(
        label_ranking_loss(
            csr_matrix(np.array([[0, 1, 0], [1, 1, 0]])), [[0.1, 10, -3], [3, 1, 3]]
        ),
        (0 + 2 / 2) / 2.0,
    )

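
# Worked example for the first non-trivial ranking-loss case above
# (illustrative arithmetic): in the second sample, both relevant labels
# (scores 0 and 1) are ranked below the only irrelevant label (score 3), so
# both of the 2 relevant-irrelevant pairs are inverted, giving 2/2 = 1;
# averaged with the perfectly ordered first sample this yields (0 + 1) / 2.
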
    assert_almost_equal(label_ranking_loss([[0, 1]], [[0.5, 0.5]]), 1)
    assert_almost_equal(label_ranking_loss([[0, 0, 1]], [[0.25, 0.5, 0.5]]), 1 / 2)
    assert_almost_equal(label_ranking_loss([[0, 1, 0]], [[0.25, 0.5, 0.5]]), 1 / 2)
    assert_almost_equal(label_ranking_loss([[0, 1, 1]], [[0.25, 0.5, 0.5]]), 0)
    assert_almost_equal(label_ranking_loss([[1, 0, 0]], [[0.25, 0.5, 0.5]]), 1)
    assert_almost_equal(label_ranking_loss([[1, 0, 1]], [[0.25, 0.5, 0.5]]), 1)
    assert_almost_equal(label_ranking_loss([[1, 1, 0]], [[0.25, 0.5, 0.5]]), 1)


def test_dcg_score():
    _, y_true = make_multilabel_classification(random_state=0, n_classes=10)
    y_score = -y_true + 1
    _test_dcg_score_for(y_true, y_score)
    y_true, y_score = np.random.RandomState(0).random_sample((2, 100, 10))
    _test_dcg_score_for(y_true, y_score)


def _test_dcg_score_for(y_true, y_score):
    discount = np.log2(np.arange(y_true.shape[1]) + 2)
    ideal = _dcg_sample_scores(y_true, y_true)
    score = _dcg_sample_scores(y_true, y_score)
    assert (score <= ideal).all()
    assert (_dcg_sample_scores(y_true, y_true, k=5) <= ideal).all()
    assert ideal.shape == (y_true.shape[0],)
    assert score.shape == (y_true.shape[0],)
    assert ideal == pytest.approx((np.sort(y_true)[:, ::-1] / discount).sum(axis=1))


def test_dcg_ties():
    y_true = np.asarray([np.arange(5)])
    y_score = np.zeros(y_true.shape)
    dcg = _dcg_sample_scores(y_true, y_score)
    dcg_ignore_ties = _dcg_sample_scores(y_true, y_score, ignore_ties=True)
    discounts = 1 / np.log2(np.arange(2, 7))
    assert dcg == pytest.approx([discounts.sum() * y_true.mean()])
    assert dcg_ignore_ties == pytest.approx([(discounts * y_true[:, ::-1]).sum()])
    y_score[0, 3:] = 1
    dcg = _dcg_sample_scores(y_true, y_score)
    dcg_ignore_ties = _dcg_sample_scores(y_true, y_score, ignore_ties=True)
    assert dcg_ignore_ties == pytest.approx([(discounts * y_true[:, ::-1]).sum()])
    assert dcg == pytest.approx(
        [
            discounts[:2].sum() * y_true[0, 3:].mean()
            + discounts[2:].sum() * y_true[0, :3].mean()
        ]
    )


def test_ndcg_ignore_ties_with_k():
    # Non-regression check: evaluating ndcg_score with both `k` and
    # `ignore_ties=True` should succeed and return a stable value (the call
    # is compared against itself).
    a = np.arange(12).reshape((2, 6))
    assert ndcg_score(a, a, k=3, ignore_ties=True) == pytest.approx(
        ndcg_score(a, a, k=3, ignore_ties=True)
    )


def test_ndcg_invariant():
    y_true = np.arange(70).reshape(7, 10)
    y_score = y_true + np.random.RandomState(0).uniform(-0.2, 0.2, size=y_true.shape)
    ndcg = ndcg_score(y_true, y_score)
    ndcg_no_ties = ndcg_score(y_true, y_score, ignore_ties=True)
    assert ndcg == pytest.approx(ndcg_no_ties)
    assert ndcg == pytest.approx(1.0)
    y_score += 1000
    assert ndcg_score(y_true, y_score) == pytest.approx(1.0)


@pytest.mark.parametrize("ignore_ties", [True, False])
def test_ndcg_toy_examples(ignore_ties):
    y_true = 3 * np.eye(7)[:5]
    y_score = np.tile(np.arange(6, -1, -1), (5, 1))
    y_score_noisy = y_score + np.random.RandomState(0).uniform(
        -0.2, 0.2, size=y_score.shape
    )
    assert _dcg_sample_scores(
        y_true, y_score, ignore_ties=ignore_ties
    ) == pytest.approx(3 / np.log2(np.arange(2, 7)))
    assert _dcg_sample_scores(
        y_true, y_score_noisy, ignore_ties=ignore_ties
    ) == pytest.approx(3 / np.log2(np.arange(2, 7)))
    assert _ndcg_sample_scores(
        y_true, y_score, ignore_ties=ignore_ties
    ) == pytest.approx(1 / np.log2(np.arange(2, 7)))
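    # Each row of y_true holds a single relevant document with gain 3, ranked
    # at position (row index + 1) by y_score, so per-row DCG is
    # 3 / log2(rank + 1), the ideal DCG is 3 / log2(2) = 3, and NDCG is
    # 1 / log2(rank + 1); the log_base=10 check below only changes the
    # discount's logarithm base.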
    assert _dcg_sample_scores(
        y_true, y_score, log_base=10, ignore_ties=ignore_ties
    ) == pytest.approx(3 / np.log10(np.arange(2, 7)))
    assert ndcg_score(y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(
        (1 / np.log2(np.arange(2, 7))).mean()
    )
    assert dcg_score(y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(
        (3 / np.log2(np.arange(2, 7))).mean()
    )
    y_true = 3 * np.ones((5, 7))
    expected_dcg_score = (3 / np.log2(np.arange(2, 9))).sum()
    assert _dcg_sample_scores(
        y_true, y_score, ignore_ties=ignore_ties
    ) == pytest.approx(expected_dcg_score * np.ones(5))
    assert _ndcg_sample_scores(
        y_true, y_score, ignore_ties=ignore_ties
    ) == pytest.approx(np.ones(5))
    assert dcg_score(y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(
        expected_dcg_score
    )
    assert ndcg_score(y_true, y_score, ignore_ties=ignore_ties) == pytest.approx(1.0)


def test_ndcg_score():
    _, y_true = make_multilabel_classification(random_state=0, n_classes=10)
    y_score = -y_true + 1
    _test_ndcg_score_for(y_true, y_score)
    y_true, y_score = np.random.RandomState(0).random_sample((2, 100, 10))
    _test_ndcg_score_for(y_true, y_score)


def _test_ndcg_score_for(y_true, y_score):
    ideal = _ndcg_sample_scores(y_true, y_true)
    score = _ndcg_sample_scores(y_true, y_score)
    assert (score <= ideal).all()
    all_zero = (y_true == 0).all(axis=1)
    assert ideal[~all_zero] == pytest.approx(np.ones((~all_zero).sum()))
    assert ideal[all_zero] == pytest.approx(np.zeros(all_zero.sum()))
    assert score[~all_zero] == pytest.approx(
        _dcg_sample_scores(y_true, y_score)[~all_zero]
        / _dcg_sample_scores(y_true, y_true)[~all_zero]
    )
    assert score[all_zero] == pytest.approx(np.zeros(all_zero.sum()))
    assert ideal.shape == (y_true.shape[0],)
    assert score.shape == (y_true.shape[0],)


def test_partial_roc_auc_score():
    # Check `roc_auc_score` for max_fpr != `None`
    y_true = np.array([0, 0, 1, 1])
    assert roc_auc_score(y_true, y_true, max_fpr=1) == 1
    assert roc_auc_score(y_true, y_true, max_fpr=0.001) == 1
    with pytest.raises(ValueError):
        assert roc_auc_score(y_true, y_true, max_fpr=-0.1)
    with pytest.raises(ValueError):
        assert roc_auc_score(y_true, y_true, max_fpr=1.1)
    with pytest.raises(ValueError):
        assert roc_auc_score(y_true, y_true, max_fpr=0)

    y_scores = np.array([0.1, 0, 0.1, 0.01])
    roc_auc_with_max_fpr_one = roc_auc_score(y_true, y_scores, max_fpr=1)
    unconstrained_roc_auc = roc_auc_score(y_true, y_scores)
    assert roc_auc_with_max_fpr_one == unconstrained_roc_auc
    assert roc_auc_score(y_true, y_scores, max_fpr=0.3) == 0.5

    y_true, y_pred, _ = make_prediction(binary=True)
    for max_fpr in np.linspace(1e-4, 1, 5):
        assert_almost_equal(
            roc_auc_score(y_true, y_pred, max_fpr=max_fpr),
            _partial_roc_auc_score(y_true, y_pred, max_fpr),
        )


@pytest.mark.parametrize(
    "y_true, k, true_score",
    [
        ([0, 1, 2, 3], 1, 0.25),
        ([0, 1, 2, 3], 2, 0.5),
        ([0, 1, 2, 3], 3, 0.75),
    ],
)
def test_top_k_accuracy_score(y_true, k, true_score):
    y_score = np.array(
        [
            [0.4, 0.3, 0.2, 0.1],
            [0.1, 0.3, 0.4, 0.2],
            [0.4, 0.1, 0.2, 0.3],
            [0.3, 0.2, 0.4, 0.1],
        ]
    )
    score = top_k_accuracy_score(y_true, y_score, k=k)
    assert score == pytest.approx(true_score)

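# A minimal reference implementation of top-k accuracy, included only as an
# illustrative cross-check of the parametrized cases above. This sketch and
# the helper name `_naive_top_k_accuracy` are ours, not part of scikit-learn;
# the stable mergesort mirrors the tie-breaking exercised in
# `test_top_k_accuracy_score_ties` below (highest index wins among ties).
def _naive_top_k_accuracy(y_true, y_score, k):
    y_score = np.asarray(y_score)
    hits = 0
    for target, scores in zip(y_true, y_score):
        # Take the indices of the k largest scores; reversing a stable
        # ascending sort puts the higher index first among tied scores.
        top_k = np.argsort(scores, kind="mergesort")[::-1][:k]
        hits += target in top_k
    return hits / len(y_score)


# For instance, with the y_score used in `test_top_k_accuracy_score` above,
# `_naive_top_k_accuracy([0, 1, 2, 3], y_score, k=2)` gives 0.5, matching
# `top_k_accuracy_score`.
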
@pytest.mark.parametrize(
    "y_score, k, true_score",
    [
        (np.array([-1, -1, 1, 1]), 1, 1),
        (np.array([-1, 1, -1, 1]), 1, 0.5),
        (np.array([-1, 1, -1, 1]), 2, 1),
        (np.array([0.2, 0.2, 0.7, 0.7]), 1, 1),
        (np.array([0.2, 0.7, 0.2, 0.7]), 1, 0.5),
        (np.array([0.2, 0.7, 0.2, 0.7]), 2, 1),
    ],
)
def test_top_k_accuracy_score_binary(y_score, k, true_score):
    y_true = [0, 0, 1, 1]

    # Scores within [0, 1] are treated as probabilities of the positive class
    # (threshold 0.5); otherwise they are decision values (threshold 0).
    threshold = 0.5 if y_score.min() >= 0 and y_score.max() <= 1 else 0
    y_pred = (y_score > threshold).astype(np.int64) if k == 1 else y_true

    score = top_k_accuracy_score(y_true, y_score, k=k)
    score_acc = accuracy_score(y_true, y_pred)

    assert score == score_acc == pytest.approx(true_score)


@pytest.mark.parametrize(
    "y_true, true_score, labels",
    [
        (np.array([0, 1, 1, 2]), 0.75, [0, 1, 2, 3]),
        (np.array([0, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
        (np.array([1, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
        (np.array(["a", "e", "e", "a"]), 0.75, ["a", "b", "d", "e"]),
    ],
)
@pytest.mark.parametrize("labels_as_ndarray", [True, False])
def test_top_k_accuracy_score_multiclass_with_labels(
    y_true, true_score, labels, labels_as_ndarray
):
    """Test when labels and y_score are multiclass."""
    if labels_as_ndarray:
        labels = np.asarray(labels)
    y_score = np.array(
        [
            [0.4, 0.3, 0.2, 0.1],
            [0.1, 0.3, 0.4, 0.2],
            [0.4, 0.1, 0.2, 0.3],
            [0.3, 0.2, 0.4, 0.1],
        ]
    )

    score = top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
    assert score == pytest.approx(true_score)


def test_top_k_accuracy_score_increasing():
    # Make sure increasing k leads to a higher score
    X, y = datasets.make_classification(
        n_classes=10, n_samples=1000, n_informative=10, random_state=0
    )

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    clf = LogisticRegression(random_state=0)
    clf.fit(X_train, y_train)

    for X, y in zip((X_train, X_test), (y_train, y_test)):
        scores = [
            top_k_accuracy_score(y, clf.predict_proba(X), k=k) for k in range(2, 10)
        ]

        assert np.all(np.diff(scores) > 0)


@pytest.mark.parametrize(
    "y_true, k, true_score",
    [
        ([0, 1, 2, 3], 1, 0.25),
        ([0, 1, 2, 3], 2, 0.5),
        ([0, 1, 2, 3], 3, 1),
    ],
)
def test_top_k_accuracy_score_ties(y_true, k, true_score):
    # Make sure the labels with the highest indices are chosen first in case
    # of ties
    y_score = np.array(
        [
            [5, 5, 7, 0],
            [1, 5, 5, 5],
            [0, 0, 3, 3],
            [1, 1, 1, 1],
        ]
    )
    assert top_k_accuracy_score(y_true, y_score, k=k) == pytest.approx(true_score)


@pytest.mark.parametrize(
    "y_true, k",
    [
        ([0, 1, 2, 3], 4),
        ([0, 1, 2, 3], 5),
    ],
)
def test_top_k_accuracy_score_warning(y_true, k):
    y_score = np.array(
        [
            [0.4, 0.3, 0.2, 0.1],
            [0.1, 0.4, 0.3, 0.2],
            [0.2, 0.1, 0.4, 0.3],
            [0.3, 0.2, 0.1, 0.4],
        ]
    )
    expected_message = (
        r"'k' \(\d+\) greater than or equal to 'n_classes' \(\d+\) will result in a "
        "perfect score and is therefore meaningless."
    )
    with pytest.warns(UndefinedMetricWarning, match=expected_message):
        score = top_k_accuracy_score(y_true, y_score, k=k)
    assert score == 1


@pytest.mark.parametrize(
    "y_true, labels, msg",
    [
        (
            [0, 0.57, 1, 2],
            None,
            "y type must be 'binary' or 'multiclass', got 'continuous'",
        ),
        (
            [0, 1, 2, 3],
            None,
            r"Number of classes in 'y_true' \(4\) not equal to the number of "
            r"classes in 'y_score' \(3\).",
        ),
        (
            ["c", "c", "a", "b"],
            ["a", "b", "c", "c"],
            "Parameter 'labels' must be unique.",
        ),
        (["c", "c", "a", "b"], ["a", "c", "b"], "Parameter 'labels' must be ordered."),
        (
            [0, 0, 1, 2],
            [0, 1, 2, 3],
            r"Number of given labels \(4\) not equal to the number of classes in "
            r"'y_score' \(3\).",
        ),
        (
            [0, 0, 1, 2],
            [0, 1, 3],
            "'y_true' contains labels not in parameter 'labels'.",
        ),
    ],
)
def test_top_k_accuracy_score_error(y_true, labels, msg):
    y_score = np.array(
        [
            [0.2, 0.1, 0.7],
            [0.4, 0.3, 0.3],
            [0.3, 0.4, 0.3],
            [0.4, 0.5, 0.1],
        ]
    )
    with pytest.raises(ValueError, match=msg):
        top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
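
# Illustrative sketch only, not part of the scikit-learn test suite: a naive
# pairwise implementation of label ranking loss whose tie handling matches
# the convention exercised in `test_ranking_loss_ties_handling` above (a
# relevant and an irrelevant label with equal scores count as wrongly
# ordered). The helper name `_naive_label_ranking_loss` is ours.
def _naive_label_ranking_loss(y_true, y_score):
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    losses = []
    for truth, scores in zip(y_true, y_score):
        relevant = np.flatnonzero(truth == 1)
        irrelevant = np.flatnonzero(truth == 0)
        if len(relevant) == 0 or len(irrelevant) == 0:
            # Degenerate rows (all relevant or all irrelevant) contribute 0,
            # matching the "undefined metrics" cases in test_label_ranking_loss.
            losses.append(0.0)
            continue
        # A (relevant, irrelevant) pair is wrongly ordered when the irrelevant
        # label scores at least as high as the relevant one (ties included).
        wrong = sum(scores[j] >= scores[i] for i in relevant for j in irrelevant)
        losses.append(wrong / (len(relevant) * len(irrelevant)))
    return np.mean(losses)


# For example, `_naive_label_ranking_loss([[0, 0, 1]], [[0.25, 0.5, 0.5]])`
# gives 1 / 2, the value asserted for `label_ranking_loss` above.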