import numpy as np
import pytest
from scipy import sparse

from numpy.testing import assert_array_equal
from numpy.testing import assert_allclose

from sklearn.datasets import load_iris
from sklearn.utils import check_array
from sklearn.utils import _safe_indexing
from sklearn.utils._testing import _convert_container

from sklearn.utils._mocking import CheckingClassifier


@pytest.fixture
def iris():
    return load_iris(return_X_y=True)


# Trivial check hooks: ``_success`` accepts any input, ``_fail`` rejects any.
def _success(x):
    return True


def _fail(x):
    return False
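

# ``fit`` should succeed whenever every provided check accepts the data.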
@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"check_X": _success},
        {"check_y": _success},
        {"check_X": _success, "check_y": _success},
    ],
)
def test_check_on_fit_success(iris, kwargs):
    X, y = iris
    CheckingClassifier(**kwargs).fit(X, y)
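

# ``fit`` should raise an ``AssertionError`` as soon as any check rejects the
# data.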
@pytest.mark.parametrize(
    "kwargs",
    [
        {"check_X": _fail},
        {"check_y": _fail},
        {"check_X": _success, "check_y": _fail},
        {"check_X": _fail, "check_y": _success},
        {"check_X": _fail, "check_y": _fail},
    ],
)
def test_check_on_fit_fail(iris, kwargs):
    X, y = iris
    clf = CheckingClassifier(**kwargs)
    with pytest.raises(AssertionError):
        clf.fit(X, y)
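

# ``check_X`` is also run on the input of the prediction methods.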
@pytest.mark.parametrize(
    "pred_func", ["predict", "predict_proba", "decision_function", "score"]
)
def test_check_X_on_predict_success(iris, pred_func):
    X, y = iris
    clf = CheckingClassifier(check_X=_success).fit(X, y)
    getattr(clf, pred_func)(X)
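

# ``set_params`` takes effect immediately: a failing check swapped in after
# ``fit`` must make the prediction methods raise.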
@pytest.mark.parametrize(
    "pred_func", ["predict", "predict_proba", "decision_function", "score"]
)
def test_check_X_on_predict_fail(iris, pred_func):
    X, y = iris
    clf = CheckingClassifier(check_X=_success).fit(X, y)
    clf.set_params(check_X=_fail)
    with pytest.raises(AssertionError):
        getattr(clf, pred_func)(X)
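

# CheckingClassifier should accept list, array, sparse and dataframe inputs
# and produce the same deterministic dummy outputs for each of them.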
@pytest.mark.parametrize("input_type", ["list", "array", "sparse", "dataframe"])
def test_checking_classifier(iris, input_type):
    # Check that the CheckingClassifier outputs what we expect
    X, y = iris
    X = _convert_container(X, input_type)
    clf = CheckingClassifier()
    clf.fit(X, y)

    assert_array_equal(clf.classes_, np.unique(y))
    assert len(clf.classes_) == 3
    assert clf.n_features_in_ == 4

    y_pred = clf.predict(X)
    assert_array_equal(y_pred, np.zeros(y_pred.size, dtype=int))

    assert clf.score(X) == pytest.approx(0)
    clf.set_params(foo_param=10)
    assert clf.fit(X, y).score(X) == pytest.approx(1)

    y_proba = clf.predict_proba(X)
    assert y_proba.shape == (150, 3)
    assert_allclose(y_proba[:, 0], 1)
    assert_allclose(y_proba[:, 1:], 0)

    y_decision = clf.decision_function(X)
    assert y_decision.shape == (150, 3)
    assert_allclose(y_decision[:, 0], 1)
    assert_allclose(y_decision[:, 1:], 0)

    # check the shape in case of binary classification
    first_2_classes = np.logical_or(y == 0, y == 1)
    X = _safe_indexing(X, first_2_classes)
    y = _safe_indexing(y, first_2_classes)
    clf.fit(X, y)

    y_proba = clf.predict_proba(X)
    assert y_proba.shape == (100, 2)
    assert_allclose(y_proba[:, 0], 1)
    assert_allclose(y_proba[:, 1], 0)

    y_decision = clf.decision_function(X)
    assert y_decision.shape == (100,)
    assert_allclose(y_decision, 0)
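

# ``check_X`` can be any predicate, and ``check_X_params`` are forwarded to it
# as keyword arguments.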
def test_checking_classifier_with_params(iris):
    X, y = iris
    X_sparse = sparse.csr_matrix(X)

    clf = CheckingClassifier(check_X=sparse.issparse)
    with pytest.raises(AssertionError):
        clf.fit(X, y)
    clf.fit(X_sparse, y)

    clf = CheckingClassifier(
        check_X=check_array, check_X_params={"accept_sparse": False}
    )
    clf.fit(X, y)
    with pytest.raises(TypeError, match="A sparse matrix was passed"):
        clf.fit(X_sparse, y)


def test_checking_classifier_fit_params(iris):
    # check the error raised when a fit parameter does not have the expected
    # number of samples
    X, y = iris
    clf = CheckingClassifier(expected_fit_params=["sample_weight"])
    sample_weight = np.ones(len(X) // 2)

    with pytest.raises(AssertionError, match="Fit parameter sample_weight"):
        clf.fit(X, y, sample_weight=sample_weight)
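

# Declaring an expected fit parameter and then omitting it should also raise.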
def test_checking_classifier_missing_fit_params(iris):
    X, y = iris
    clf = CheckingClassifier(expected_fit_params=["sample_weight"])
    with pytest.raises(AssertionError, match="Expected fit parameter"):
        clf.fit(X, y)


@pytest.mark.parametrize(
    "methods_to_check",
    [["predict"], ["predict", "predict_proba"]],
)
@pytest.mark.parametrize(
    "predict_method", ["predict", "predict_proba", "decision_function", "score"]
)
def test_checking_classifier_methods_to_check(iris, methods_to_check, predict_method):
    # check that ``methods_to_check`` restricts the checks to the listed
    # methods and bypasses them everywhere else
    X, y = iris

    clf = CheckingClassifier(
        check_X=sparse.issparse,
        methods_to_check=methods_to_check,
    )

    # ``fit`` is never listed in ``methods_to_check``, so the (failing)
    # sparsity check is skipped here
    clf.fit(X, y)
    if predict_method in methods_to_check:
        with pytest.raises(AssertionError):
            getattr(clf, predict_method)(X)
    else:
        getattr(clf, predict_method)(X)
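

# A minimal sketch (not part of the original suite) of how CheckingClassifier
# is typically used elsewhere: asserting a property of the data that reaches
# the final step of a Pipeline. The checked property (the feature count) is an
# illustrative assumption, not something the tests above rely on.
def test_checking_classifier_in_pipeline(iris):
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    X, y = iris
    # StandardScaler preserves the number of features, so this check passes
    clf = CheckingClassifier(check_X=lambda X: X.shape[1] == 4)
    pipe = make_pipeline(StandardScaler(), clf)
    pipe.fit(X, y)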