1"""Tests for input validation functions"""
2
3import numbers
4import warnings
5import os
6import re
7
8from tempfile import NamedTemporaryFile
9from itertools import product
10from operator import itemgetter
11
12import pytest
13from pytest import importorskip
14import numpy as np
15import scipy.sparse as sp
16
17from sklearn.utils._testing import assert_no_warnings
18from sklearn.utils._testing import ignore_warnings
19from sklearn.utils._testing import SkipTest
20from sklearn.utils._testing import assert_array_equal
21from sklearn.utils._testing import assert_allclose_dense_sparse
22from sklearn.utils._testing import assert_allclose
23from sklearn.utils._testing import _convert_container
24from sklearn.utils import as_float_array, check_array, check_symmetric
25from sklearn.utils import check_X_y
26from sklearn.utils import deprecated
27from sklearn.utils._mocking import MockDataFrame
28from sklearn.utils.fixes import parse_version
29from sklearn.utils.estimator_checks import _NotAnArray
30from sklearn.random_projection import _sparse_random_matrix
31from sklearn.linear_model import ARDRegression
32from sklearn.neighbors import KNeighborsClassifier
33from sklearn.ensemble import RandomForestRegressor
34from sklearn.svm import SVR
35from sklearn.datasets import make_blobs
36from sklearn.utils import _safe_indexing
37from sklearn.utils.validation import (
38    has_fit_parameter,
39    check_is_fitted,
40    check_consistent_length,
41    assert_all_finite,
42    check_memory,
43    check_non_negative,
44    _num_samples,
45    check_scalar,
46    _check_psd_eigenvalues,
47    _check_y,
48    _deprecate_positional_args,
49    _check_sample_weight,
50    _allclose_dense_sparse,
51    _num_features,
52    FLOAT_DTYPES,
53    _get_feature_names,
54    _check_feature_names_in,
55)
56from sklearn.utils.validation import _check_fit_params
57from sklearn.base import BaseEstimator
58import sklearn
59
60from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
61
62from sklearn.utils._testing import TempMemmap
63
64
# TODO: Remove np.matrix usage in 1.2
@pytest.mark.filterwarnings("ignore:np.matrix usage is deprecated in 1.0:FutureWarning")
@pytest.mark.filterwarnings("ignore:the matrix subclass:PendingDeprecationWarning")
def test_as_float_array():
    """Check the output dtype and the copy semantics of ``as_float_array``."""
    # int32 input becomes float32 when no copy is requested.
    arr = np.ones((3, 10), dtype=np.int32) + np.arange(10, dtype=np.int32)
    assert as_float_array(arr, copy=False).dtype == np.float32

    # int64 input becomes float64; the conversion always yields a new object,
    # even with copy=False.
    arr = arr.astype(np.int64)
    converted = as_float_array(arr, copy=True)
    assert as_float_array(arr, copy=False) is not arr
    assert converted.dtype == np.float64

    # bool and integer dtypes of at most 32 bits map to float32.
    small_int_dtypes = [
        bool,
        np.int8,
        np.int16,
        np.int32,
        np.uint8,
        np.uint16,
        np.uint32,
    ]
    for small_dtype in small_int_dtypes:
        arr = arr.astype(small_dtype)
        assert as_float_array(arr).dtype == np.float32

    # object arrays map to float64.
    arr = arr.astype(object)
    assert as_float_array(arr, copy=True).dtype == np.float64

    # Input that already has an accepted float dtype is returned untouched.
    arr = np.ones((3, 2), dtype=np.float32)
    assert as_float_array(arr, copy=False) is arr
    # A Fortran-ordered input stays Fortran-ordered after a copy.
    arr = np.asfortranarray(arr)
    assert np.isfortran(as_float_array(arr, copy=True))

    # With copy=True, writing into the result must never touch the input.
    candidates = [
        np.matrix(np.arange(5)),
        sp.csc_matrix(np.arange(5)).toarray(),
        _sparse_random_matrix(10, 10, density=0.10).toarray(),
    ]
    for original in candidates:
        result = as_float_array(original, copy=True)
        result[0, 0] = np.nan
        assert not np.isnan(original).any()
109
110
@pytest.mark.parametrize("X", [(np.random.random((10, 2))), (sp.rand(10, 2).tocsr())])
def test_as_float_array_nan(X):
    """``as_float_array`` keeps NaN values when ``force_all_finite="allow-nan"``."""
    # The parametrized arrays are built once at collection time and are shared
    # objects; copy before writing NaNs so the parameters are not mutated.
    X = X.copy()
    X[5, 0] = np.nan
    X[6, 1] = np.nan
    X_converted = as_float_array(X, force_all_finite="allow-nan")
    assert_allclose_dense_sparse(X_converted, X)
117
118
# TODO: Remove np.matrix usage in 1.2
@pytest.mark.filterwarnings("ignore:np.matrix usage is deprecated in 1.0:FutureWarning")
@pytest.mark.filterwarnings("ignore:the matrix subclass:PendingDeprecationWarning")
def test_np_matrix():
    """Input validation must never hand back an ``np.matrix``."""
    base = np.arange(12).reshape(3, 4)
    # Plain ndarray, np.matrix and a sparse matrix built from the same data.
    for candidate in (base, np.matrix(base), sp.csc_matrix(base)):
        assert not isinstance(as_float_array(candidate), np.matrix)
129
130
def test_memmap():
    """Validation helpers must not copy memory-mapped arrays."""

    def as_float_no_copy(x):
        return as_float_array(x, copy=False)

    with NamedTemporaryFile(prefix="sklearn-test") as tmp:
        mmap = np.memmap(tmp, shape=(10, 10), dtype=np.float32)
        mmap[:] = 0

        for validate in (check_array, np.asarray, as_float_no_copy):
            view = validate(mmap)
            # Writes through the returned object must show up in the memmap,
            # proving that no copy was made.
            view[:] = 1
            assert_array_equal(view.ravel(), mmap.ravel())
            view[:] = 0
145
146
def test_ordering():
    """``check_array`` enforces the requested memory layout.

    Each validation utility has to be checked separately, because a copy
    made without ``order="K"`` loses the original ordering.
    """
    base = np.ones((10, 5))
    for arr in (base, base.T):
        for copy in (True, False):
            for order, flag in (("C", "C_CONTIGUOUS"), ("F", "F_CONTIGUOUS")):
                out = check_array(arr, order=order, copy=copy)
                assert out.flags[flag]
            # ``out`` is the order="F" result here.
            if copy:
                assert arr is not out

    # NOTE(review): the sparse setup below only asserts its own precondition
    # (non-contiguous ``data``); it is never passed to a validation utility.
    sparse = sp.csr_matrix(base)
    sparse.data = sparse.data[::-1]
    assert not sparse.data.flags["C_CONTIGUOUS"]
164
165
@pytest.mark.parametrize(
    "value, force_all_finite", [(np.inf, False), (np.nan, "allow-nan"), (np.nan, False)]
)
@pytest.mark.parametrize("retype", [np.asarray, sp.csr_matrix])
def test_check_array_force_all_finite_valid(value, force_all_finite, retype):
    """Non-finite values pass through when ``force_all_finite`` permits them."""
    data = np.arange(4).reshape(2, 2).astype(float)
    X = retype(data)
    X[0, 0] = value
    checked = check_array(X, force_all_finite=force_all_finite, accept_sparse=True)
    assert_allclose_dense_sparse(X, checked)
175
176
@pytest.mark.parametrize(
    "value, force_all_finite, match_msg",
    [
        (np.inf, True, "Input contains NaN, infinity"),
        (np.inf, "allow-nan", "Input contains infinity"),
        (np.nan, True, "Input contains NaN, infinity"),
        (np.nan, "allow-inf", 'force_all_finite should be a bool or "allow-nan"'),
        (np.nan, 1, "Input contains NaN, infinity"),
    ],
)
@pytest.mark.parametrize("retype", [np.asarray, sp.csr_matrix])
def test_check_array_force_all_finiteinvalid(
    value, force_all_finite, match_msg, retype
):
    """Disallowed non-finite values make ``check_array`` raise ValueError.

    The ``"allow-inf"`` case checks that an unsupported option value is
    itself reported as an error.
    """
    data = np.arange(4).reshape(2, 2).astype(float)
    X = retype(data)
    X[0, 0] = value
    with pytest.raises(ValueError, match=match_msg):
        check_array(X, force_all_finite=force_all_finite, accept_sparse=True)
195
196
def test_check_array_force_all_finite_object():
    """Object arrays containing NaN follow the ``force_all_finite`` setting."""
    X = np.array([["a", "b", np.nan]], dtype=object).T

    # Permissive settings hand back the very same object.
    for permissive in ("allow-nan", False):
        assert check_array(X, dtype=None, force_all_finite=permissive) is X

    with pytest.raises(ValueError, match="Input contains NaN"):
        check_array(X, dtype=None, force_all_finite=True)
208
209
@pytest.mark.parametrize(
    "X, err_msg",
    [
        # The float-NaN case used to be listed twice; one entry is enough.
        (
            np.array([[1, np.nan]]),
            "Input contains NaN, infinity or a value too large for.*int",
        ),
        (
            np.array([[1, np.inf]]),
            "Input contains NaN, infinity or a value too large for.*int",
        ),
        (np.array([[1, np.nan]], dtype=object), "cannot convert float NaN to integer"),
    ],
)
@pytest.mark.parametrize("force_all_finite", [True, False])
def test_check_array_force_all_finite_object_unsafe_casting(
    X, err_msg, force_all_finite
):
    """Casting NaN or inf to an integer dtype always raises.

    The error is expected irrespective of the ``force_all_finite`` parameter.
    """
    with pytest.raises(ValueError, match=err_msg):
        check_array(X, dtype=int, force_all_finite=force_all_finite)
236
237
@ignore_warnings
def test_check_array():
    """Exercise the main ``check_array`` options on dense, sparse and list input.

    Covers:

    * rejection of sparse input by default;
    * ``ensure_2d``;
    * ``allow_nd``;
    * dtype/order/copy enforcement on dense arrays;
    * ``accept_sparse`` format conversion on sparse matrices;
    * conversion of nested lists and array-like objects.
    """
    # accept_sparse == False
    # raise error on sparse inputs
    X = [[1, 2], [3, 4]]
    X_csr = sp.csr_matrix(X)
    with pytest.raises(TypeError):
        check_array(X_csr)

    # ensure_2d=False
    X_array = check_array([0, 1, 2], ensure_2d=False)
    assert X_array.ndim == 1
    # ensure_2d=True with 1d array
    with pytest.raises(ValueError, match="Expected 2D array, got 1D array instead"):
        check_array([0, 1, 2], ensure_2d=True)

    # ensure_2d=True with scalar array
    with pytest.raises(ValueError, match="Expected 2D array, got scalar array instead"):
        check_array(10, ensure_2d=True)

    # A 3-d array is already rejected unless allow_nd=True.
    X_ndim = np.arange(8).reshape(2, 2, 2)
    with pytest.raises(ValueError):
        check_array(X_ndim)
    check_array(X_ndim, allow_nd=True)  # doesn't raise

    # dtype and order enforcement.
    X_C = np.arange(4).reshape(2, 2).copy("C")
    X_F = X_C.copy("F")
    X_int = X_C.astype(int)
    X_float = X_C.astype(float)
    Xs = [X_C, X_F, X_int, X_float]
    dtypes = [np.int32, int, float, np.float32, None, bool, object]
    orders = ["C", "F", None]
    copys = [True, False]

    # Try every combination of input, target dtype, target order and copy flag.
    for X, dtype, order, copy in product(Xs, dtypes, orders, copys):
        X_checked = check_array(X, dtype=dtype, order=order, copy=copy)
        # dtype=None keeps the input dtype.
        if dtype is not None:
            assert X_checked.dtype == dtype
        else:
            assert X_checked.dtype == X.dtype
        # An explicit order yields exactly that contiguity; order=None is
        # not checked here.
        if order == "C":
            assert X_checked.flags["C_CONTIGUOUS"]
            assert not X_checked.flags["F_CONTIGUOUS"]
        elif order == "F":
            assert X_checked.flags["F_CONTIGUOUS"]
            assert not X_checked.flags["C_CONTIGUOUS"]
        if copy:
            assert X is not X_checked
        else:
            # doesn't copy if it was already good
            if (
                X.dtype == X_checked.dtype
                and X_checked.flags["C_CONTIGUOUS"] == X.flags["C_CONTIGUOUS"]
                and X_checked.flags["F_CONTIGUOUS"] == X.flags["F_CONTIGUOUS"]
            ):
                assert X is X_checked

    # allowed sparse != None
    X_csc = sp.csc_matrix(X_C)
    X_coo = X_csc.tocoo()
    X_dok = X_csc.todok()
    X_int = X_csc.astype(int)
    X_float = X_csc.astype(float)

    Xs = [X_csc, X_coo, X_dok, X_int, X_float]
    accept_sparses = [["csr", "coo"], ["coo", "dok"]]
    for X, dtype, accept_sparse, copy in product(Xs, dtypes, accept_sparses, copys):
        # NOTE(review): the @ignore_warnings decorator on this test may keep
        # warnings from being recorded in ``w``. If so, the ``len(w)`` checks
        # below are vacuous, which matches the "unreached code" note. Confirm
        # before relying on them.
        with warnings.catch_warnings(record=True) as w:
            X_checked = check_array(
                X, dtype=dtype, accept_sparse=accept_sparse, copy=copy
            )
        if (dtype is object or sp.isspmatrix_dok(X)) and len(w):
            # XXX unreached code as of v0.22
            message = str(w[0].message)
            messages = [
                "object dtype is not supported by sparse matrices",
                "Can't check dok sparse matrix for nan or inf.",
            ]
            assert message in messages
        else:
            assert len(w) == 0
        if dtype is not None:
            assert X_checked.dtype == dtype
        else:
            assert X_checked.dtype == X.dtype
        if X.format in accept_sparse:
            # no change if allowed
            assert X.format == X_checked.format
        else:
            # got converted
            assert X_checked.format == accept_sparse[0]
        if copy:
            assert X is not X_checked
        else:
            # doesn't copy if it was already good
            if X.dtype == X_checked.dtype and X.format == X_checked.format:
                assert X is X_checked

    # other input formats
    # convert lists to arrays
    X_dense = check_array([[1, 2], [3, 4]])
    assert isinstance(X_dense, np.ndarray)
    # raise on too deep lists
    with pytest.raises(ValueError):
        check_array(X_ndim.tolist())
    check_array(X_ndim.tolist(), allow_nd=True)  # doesn't raise

    # convert weird stuff to arrays
    X_no_array = _NotAnArray(X_dense)
    result = check_array(X_no_array)
    assert isinstance(result, np.ndarray)
351
352
# TODO: Check for error in 1.1 when implicit conversion is removed
@pytest.mark.parametrize(
    "X",
    [
        [["1", "2"], ["3", "4"]],
        np.array([["1", "2"], ["3", "4"]], dtype="U"),
        np.array([["1", "2"], ["3", "4"]], dtype="S"),
        [[b"1", b"2"], [b"3", b"4"]],
        np.array([[b"1", b"2"], [b"3", b"4"]], dtype="V1"),
    ],
)
def test_check_array_numeric_warns(X):
    """``dtype="numeric"`` on bytes/string input emits a FutureWarning."""
    pattern = (
        r"Arrays of bytes/strings is being converted to decimal .*"
        r"deprecated in 0.24 and will be removed in 1.1"
    )
    with pytest.warns(FutureWarning, match=pattern):
        check_array(X, dtype="numeric")
373
374
# TODO: remove in 1.1
@ignore_warnings(category=FutureWarning)
@pytest.mark.parametrize(
    "X",
    [
        [["11", "12"], ["13", "xx"]],
        np.array([["11", "12"], ["13", "xx"]], dtype="U"),
        np.array([["11", "12"], ["13", "xx"]], dtype="S"),
        [[b"a", b"b"], [b"c", b"d"]],
    ],
)
def test_check_array_dtype_numeric_errors(X):
    """Error when a string-like array cannot be converted to numeric."""
    # FutureWarnings are silenced by the decorator; presumably this is the
    # bytes/strings conversion deprecation warning raised before the error.
    expected_err_msg = "Unable to convert array of bytes/strings"
    with pytest.raises(ValueError, match=expected_err_msg):
        check_array(X, dtype="numeric")
391
392
@pytest.mark.parametrize("pd_dtype", ["Int8", "Int16", "UInt8", "UInt16"])
@pytest.mark.parametrize(
    "dtype, expected_dtype",
    [
        ([np.float32, np.float64], np.float32),
        (np.float64, np.float64),
        ("numeric", np.float64),
    ],
)
def test_check_array_pandas_na_support(pd_dtype, dtype, expected_dtype):
    """pandas nullable integer columns holding ``pd.NA`` become float NaN."""
    pd = pytest.importorskip("pandas", minversion="1.0")

    expected = np.array(
        [[1, 2, 3, np.nan, np.nan], [np.nan, np.nan, 8, 4, 6], [1, 2, 3, 4, 5]]
    ).T

    # Build the frame with the nullable integer dtype, so missing entries are
    # pd.NA; column "c" has no missing values and is turned into plain float.
    X = pd.DataFrame(expected, dtype=pd_dtype, columns=["a", "b", "c"])
    X["c"] = X["c"].astype("float")

    # NaN is tolerated both with "allow-nan" and with the check disabled.
    for finite_mode in ("allow-nan", False):
        X_checked = check_array(X, force_all_finite=finite_mode, dtype=dtype)
        assert_allclose(X_checked, expected)
        assert X_checked.dtype == expected_dtype

    with pytest.raises(ValueError, match="Input contains NaN, infinity"):
        check_array(X, force_all_finite=True)
425
426
# TODO: remove test in 1.1 once this behavior is deprecated
def test_check_array_pandas_dtype_object_conversion():
    """Object-dtype dataframe-likes are converted to float with a FutureWarning."""
    X_df = MockDataFrame(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=object))

    for kwargs in ({}, {"ensure_2d": False}):
        with pytest.warns(FutureWarning):
            assert check_array(X_df, **kwargs).dtype.kind == "f"

    # Smoke test: a frame exposing a "dtype" attribute (e.g. a column of that
    # name) must still be converted.
    X_df.dtype = "Hans"
    with pytest.warns(FutureWarning):
        assert check_array(X_df, ensure_2d=False).dtype.kind == "f"
441
442
def test_check_array_pandas_dtype_casting():
    """DataFrames with near-homogeneous dtypes are not needlessly upcast."""
    pd = pytest.importorskip("pandas")

    def checked_dtypes(frame):
        # Output dtype without a dtype argument, then with dtype=FLOAT_DTYPES.
        return (
            check_array(frame).dtype,
            check_array(frame, dtype=FLOAT_DTYPES).dtype,
        )

    values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
    df = pd.DataFrame(values)
    assert checked_dtypes(df) == (np.float32, np.float32)

    df.iloc[:, 0] = df.iloc[:, 0].astype(np.float16)
    assert_array_equal(df.dtypes, (np.float16, np.float32, np.float32))
    assert checked_dtypes(df) == (np.float32, np.float32)

    # float16, int16, float32 casts to float32
    df.iloc[:, 1] = df.iloc[:, 1].astype(np.int16)
    assert checked_dtypes(df) == (np.float32, np.float32)

    # float16, int16, float16 casts to float32
    df.iloc[:, 2] = df.iloc[:, 2].astype(np.float16)
    assert checked_dtypes(df) == (np.float32, np.float32)

    # An all-int16 frame stays int16 by default. With FLOAT_DTYPES it goes to
    # float64, because upcasting rules are not used yet to pick the target.
    df = df.astype(np.int16)
    assert checked_dtypes(df) == (np.int16, np.float64)

    # pandas extension dtypes are handled in a semi-reasonable way. This is
    # tricky, since the integer nature is only known after conversion.
    categorical = pd.DataFrame({"cat_col": pd.Categorical([1, 2, 3])})
    assert checked_dtypes(categorical) == (np.int64, np.float64)
478
479
def test_check_array_on_mock_dataframe():
    """A MockDataFrame keeps its dtype unless another one is requested."""
    values = np.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]])
    frame = MockDataFrame(values)
    assert check_array(frame).dtype == values.dtype
    assert check_array(frame, dtype=np.float32).dtype == np.dtype(np.float32)
487
488
def test_check_array_dtype_stability():
    """Nested lists of ints keep an integer dtype instead of becoming floats."""
    int_lists = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    for kwargs in ({}, {"ensure_2d": False}):
        assert check_array(int_lists, **kwargs).dtype.kind == "i"
494
495
def test_check_array_dtype_warning():
    """Requesting dtypes never warns, and accepted float32 input is preserved."""
    int_rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    dense_f32 = np.asarray(int_rows, dtype=np.float32)
    dense_i64 = np.asarray(int_rows, dtype=np.int64)
    csr_f32 = sp.csr_matrix(dense_f32)
    csc_f32 = sp.csc_matrix(dense_f32)
    csc_i32 = sp.csc_matrix(dense_i64, dtype=np.int32)
    float_dtypes = [np.float64, np.float32]

    # Integer input is silently converted to the requested float64.
    for X in (dense_i64, csc_i32):
        converted = assert_no_warnings(
            check_array, X, dtype=np.float64, accept_sparse=True
        )
        assert converted.dtype == np.float64

    # float32 is in the accepted list, so it is kept. The input object itself
    # is returned without copy; copy=True yields a new object.
    for X in (dense_f32, csr_f32, csc_f32):
        kept = assert_no_warnings(
            check_array, X, dtype=float_dtypes, accept_sparse=True
        )
        assert kept.dtype == np.float32
        assert kept is X

        copied = assert_no_warnings(
            check_array,
            X,
            dtype=float_dtypes,
            accept_sparse=["csr", "dok"],
            copy=True,
        )
        assert copied.dtype == np.float32
        assert copied is not X

    # CSC is not among the accepted formats, so even with copy=False it is
    # converted to the first accepted format, CSR.
    converted = assert_no_warnings(
        check_array,
        csc_f32,
        dtype=float_dtypes,
        accept_sparse=["csr", "dok"],
        copy=False,
    )
    assert converted.dtype == np.float32
    assert converted is not csc_f32
    assert converted.format == "csr"
538
539
def test_check_array_accept_sparse_type_exception():
    """Invalid ``accept_sparse`` values produce informative errors."""
    X_csr = sp.csr_matrix([[1, 2], [3, 4]])
    invalid_type = SVR()

    dense_required = (
        "A sparse matrix was passed, but dense data is required. "
        r"Use X.toarray\(\) to convert to a dense numpy array."
    )
    with pytest.raises(TypeError, match=dense_required):
        check_array(X_csr, accept_sparse=False)

    bad_type = (
        "Parameter 'accept_sparse' should be a string, "
        "boolean or list of strings. You provided 'accept_sparse=.*'."
    )
    with pytest.raises(ValueError, match=bad_type):
        check_array(X_csr, accept_sparse=invalid_type)

    empty_sequence = (
        "When providing 'accept_sparse' as a tuple or list, "
        "it must contain at least one string value."
    )
    for empty in ([], ()):
        with pytest.raises(ValueError, match=empty_sequence):
            check_array(X_csr, accept_sparse=empty)

    # A non-string element inside the list is reported by its repr.
    with pytest.raises(TypeError, match="SVR"):
        check_array(X_csr, accept_sparse=[invalid_type])
569
570
def test_check_array_accept_sparse_no_exception():
    """Every supported spelling of "accept CSR" accepts a CSR matrix."""
    X_csr = sp.csr_matrix([[1, 2], [3, 4]])
    for accept_sparse in (True, "csr", ["csr"], ("csr",)):
        check_array(X_csr, accept_sparse=accept_sparse)
579
580
@pytest.fixture(params=["csr", "csc", "coo", "bsr"])
def X_64bit(request):
    """Random sparse matrix in the requested format with int64 index arrays."""
    X = sp.rand(20, 10, format=request.param)
    # Only the index attributes the given format actually has are cast.
    present = [name for name in ("indices", "indptr", "row", "col") if hasattr(X, name)]
    for name in present:
        setattr(X, name, getattr(X, name).astype("int64"))
    yield X
588
589
def test_check_array_accept_large_sparse_no_exception(X_64bit):
    """int64-indexed sparse input passes when large sparse is allowed."""
    check_array(X_64bit, accept_sparse=True, accept_large_sparse=True)
593
594
def test_check_array_accept_large_sparse_raise_exception(X_64bit):
    """int64-indexed sparse input is rejected when large sparse is disallowed."""
    expected = (
        "Only sparse matrices with 32-bit integer indices "
        "are accepted. Got int64 indices."
    )
    with pytest.raises(ValueError, match=expected):
        check_array(X_64bit, accept_large_sparse=False, accept_sparse=True)
603
604
def test_check_array_min_samples_and_features_messages():
    """Check the messages raised for too few samples or too few features."""
    # empty list is considered 2D by default:
    msg = r"0 feature\(s\) \(shape=\(1, 0\)\) while a minimum of 1 is required."
    with pytest.raises(ValueError, match=msg):
        check_array([[]])

    # If considered a 1D collection when ensure_2d=False, then the minimum
    # number of samples will break:
    msg = r"0 sample\(s\) \(shape=\(0,\)\) while a minimum of 1 is required."
    with pytest.raises(ValueError, match=msg):
        check_array([], ensure_2d=False)

    # Invalid edge case when checking the default minimum sample of a scalar
    msg = r"Singleton array array\(42\) cannot be considered a valid collection."
    with pytest.raises(TypeError, match=msg):
        check_array(42, ensure_2d=False)

    # Simulate a model that would need at least 2 samples to be well defined
    X = np.ones((1, 10))
    y = np.ones(1)
    msg = r"1 sample\(s\) \(shape=\(1, 10\)\) while a minimum of 2 is required."
    with pytest.raises(ValueError, match=msg):
        check_X_y(X, y, ensure_min_samples=2)

    # The same message is raised if the data has 2 dimensions even if this is
    # not mandatory
    with pytest.raises(ValueError, match=msg):
        check_X_y(X, y, ensure_min_samples=2, ensure_2d=False)

    # Simulate a model that would require at least 3 features (e.g. SelectKBest
    # with k=3). y has one entry per sample, so only the feature check can
    # fail. Previously y had 2 entries and the test relied on X being checked
    # before the length consistency.
    X = np.ones((10, 2))
    y = np.ones(10)
    msg = r"2 feature\(s\) \(shape=\(10, 2\)\) while a minimum of 3 is required."
    with pytest.raises(ValueError, match=msg):
        check_X_y(X, y, ensure_min_features=3)

    # Only the feature check is enabled whenever the number of dimensions is 2
    # even if allow_nd is enabled:
    with pytest.raises(ValueError, match=msg):
        check_X_y(X, y, ensure_min_features=3, allow_nd=True)

    # Simulate a case where a pipeline stage has trimmed all the features of a
    # 2D dataset.
    X = np.empty(0).reshape(10, 0)
    y = np.ones(10)
    msg = r"0 feature\(s\) \(shape=\(10, 0\)\) while a minimum of 1 is required."
    with pytest.raises(ValueError, match=msg):
        check_X_y(X, y)

    # nd-data is not checked for any minimum number of features by default:
    X = np.ones((10, 0, 28, 28))
    y = np.ones(10)
    X_checked, y_checked = check_X_y(X, y, allow_nd=True)
    assert_array_equal(X, X_checked)
    assert_array_equal(y, y_checked)
661
662
def test_check_array_complex_data_error():
    """Complex-valued input is rejected for every supported container type."""
    rows = [[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]]
    containers = [
        np.array(rows),  # ndarray
        rows,  # list of lists
        tuple(tuple(row) for row in rows),  # tuple of tuples
        [np.array(row) for row in rows],  # list of np arrays
        tuple(np.array(row) for row in rows),  # tuple of np arrays
        MockDataFrame(np.array(rows)),  # dataframe
        sp.coo_matrix([[0, 1 + 2j], [0, 0]]),  # sparse matrix
    ]
    for X in containers:
        with pytest.raises(ValueError, match="Complex data not supported"):
            check_array(X)

    # The target does not always go through check_array, but complex y must
    # be rejected as well.
    y = np.array([1 + 2j, 3 + 4j, 5 + 7j, 2 + 3j, 4 + 5j, 6 + 7j])
    with pytest.raises(ValueError, match="Complex data not supported"):
        _check_y(y)
703
704
def test_has_fit_parameter():
    """``has_fit_parameter`` reports whether ``fit`` accepts a given argument."""
    assert not has_fit_parameter(KNeighborsClassifier, "sample_weight")
    # Works on estimator classes as well as on instances.
    for estimator in (RandomForestRegressor, SVR, SVR()):
        assert has_fit_parameter(estimator, "sample_weight")

    class TestClassWithDeprecatedFitMethod:
        @deprecated("Deprecated for the purpose of testing has_fit_parameter")
        def fit(self, X, y, sample_weight=None):
            pass

    # Wrapping ``fit`` with @deprecated must not change the answer.
    assert has_fit_parameter(
        TestClassWithDeprecatedFitMethod, "sample_weight"
    ), "has_fit_parameter fails for class with deprecated fit method."
719
720
def test_check_symmetric():
    """``check_symmetric`` validates its input and symmetrizes asymmetric data."""
    expected = np.array([[0, 1], [1, 2]])
    not_square = np.ones(2)
    asymmetric = np.array([[0, 2], [0, 2]])

    sparse_ctors = {
        "dok": sp.dok_matrix,
        "csr": sp.csr_matrix,
        "csc": sp.csc_matrix,
        "coo": sp.coo_matrix,
        "lil": sp.lil_matrix,
        "bsr": sp.bsr_matrix,
    }
    test_arrays = {
        "dense": asymmetric,
        **{fmt: ctor(asymmetric) for fmt, ctor in sparse_ctors.items()},
    }

    # A 1-d array is rejected outright.
    with pytest.raises(ValueError):
        check_symmetric(not_square)

    for fmt, arr in test_arrays.items():
        # Asymmetry warns by default and raises with raise_exception=True.
        with pytest.warns(UserWarning):
            check_symmetric(arr)
        with pytest.raises(ValueError):
            check_symmetric(arr, raise_exception=True)

        # The symmetrized output keeps the sparse format of its input.
        output = check_symmetric(arr, raise_warning=False)
        if sp.issparse(output):
            assert output.format == fmt
            output = output.toarray()
        assert_array_equal(output, expected)
754
755
def test_check_is_fitted_with_is_fitted():
    """``check_is_fitted`` defers to an estimator's ``__sklearn_is_fitted__``."""

    class Estimator(BaseEstimator):
        def fit(self, **kwargs):
            self._is_fitted = True
            return self

        def __sklearn_is_fitted__(self):
            # False until fit() has set the flag.
            return getattr(self, "_is_fitted", False)

    unfitted = Estimator()
    with pytest.raises(NotFittedError):
        check_is_fitted(unfitted)
    check_is_fitted(Estimator().fit())
768
769
def test_check_is_fitted():
    """``check_is_fitted`` validates its argument and reports unfitted models."""
    # Check is TypeError raised when non estimator instance passed
    with pytest.raises(TypeError):
        check_is_fitted(ARDRegression)
    with pytest.raises(TypeError):
        check_is_fitted("SVR")

    ard = ARDRegression()
    svr = SVR()

    # Unfitted estimators must raise NotFittedError specifically. A plain
    # ValueError would escape pytest.raises and hit the except clause.
    try:
        with pytest.raises(NotFittedError):
            check_is_fitted(ard)
        with pytest.raises(NotFittedError):
            check_is_fitted(svr)
    except ValueError:
        assert False, "check_is_fitted failed with ValueError"

    # NotFittedError is a subclass of both ValueError and AttributeError, so
    # it can be caught as either. The custom message is %-formatted with the
    # estimator name. pytest.raises fails the test when nothing is raised,
    # which the former bare try/except silently allowed.
    with pytest.raises(ValueError) as exc_info:
        check_is_fitted(ard, msg="Random message %(name)s, %(name)s")
    assert str(exc_info.value) == "Random message ARDRegression, ARDRegression"

    with pytest.raises(AttributeError) as exc_info:
        check_is_fitted(svr, msg="Another message %(name)s, %(name)s")
    assert str(exc_info.value) == "Another message SVR, SVR"

    ard.fit(*make_blobs())
    svr.fit(*make_blobs())

    assert check_is_fitted(ard) is None
    assert check_is_fitted(svr) is None
804
805
def test_check_is_fitted_attributes():
    """``attributes`` and ``all_or_any`` decide which attributes must exist."""

    class MyEstimator:
        def fit(self, X, y):
            return self

    msg = "not fitted"
    est = MyEstimator()

    def assert_not_fitted(**kwargs):
        with pytest.raises(NotFittedError, match=msg):
            check_is_fitted(est, attributes=["a_", "b_"], **kwargs)

    # No attribute set: every mode reports "not fitted".
    assert_not_fitted()
    assert_not_fitted(all_or_any=all)
    assert_not_fitted(all_or_any=any)

    # Only a_ set: enough for ``any``, not for the default mode or ``all``.
    est.a_ = "a"
    assert_not_fitted()
    assert_not_fitted(all_or_any=all)
    check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)

    # Both set: every mode passes.
    est.b_ = "b"
    for mode in ({}, {"all_or_any": all}, {"all_or_any": any}):
        check_is_fitted(est, attributes=["a_", "b_"], **mode)
832
833
@pytest.mark.parametrize(
    "wrap", [itemgetter(0), list, tuple], ids=["single", "list", "tuple"]
)
def test_check_is_fitted_with_attributes(wrap):
    """``attributes`` may be given as a single name, a list or a tuple."""
    expected = "is not fitted yet"
    ard = ARDRegression()
    with pytest.raises(NotFittedError, match=expected):
        check_is_fitted(ard, wrap(["coef_"]))

    ard.fit(*make_blobs())

    # coef_ is present after fitting, so this must not raise.
    check_is_fitted(ard, wrap(["coef_"]))

    # An attribute that fitting never sets is still reported as missing.
    with pytest.raises(NotFittedError, match=expected):
        check_is_fitted(ard, wrap(["coef_bad_"]))
850
851
def test_check_consistent_length():
    # Containers of different kinds but equal length are accepted.
    check_consistent_length([1], [2], [3], [4], [5])
    check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ["a", "b"])
    check_consistent_length([1], (2,), np.array([3]), sp.csr_matrix((1, 2)))

    # Different lengths are a ValueError.
    with pytest.raises(ValueError, match="inconsistent numbers of samples"):
        check_consistent_length([1, 2], [1])

    # Unsized objects are a TypeError whose message names the offending type.
    for unsized, type_pattern in (
        (1, r"got <\w+ 'int'>"),
        (object(), r"got <\w+ 'object'>"),
    ):
        with pytest.raises(TypeError, match=type_pattern):
            check_consistent_length([1, 2], unsized)

    # A 0-d array has no length either.
    with pytest.raises(TypeError):
        check_consistent_length([1, 2], np.array(1))

    # Despite ensembles having __len__ they must raise TypeError
    with pytest.raises(TypeError, match="Expected sequence or array-like"):
        check_consistent_length([1, 2], RandomForestRegressor())
    # XXX: We should have a test with a string, but what is correct behaviour?
870
871
def test_check_dataframe_fit_attribute():
    """A DataFrame with a column named 'fit' is not mistaken for an estimator.

    Regression test for
    https://github.com/scikit-learn/scikit-learn/issues/8415
    """
    # importorskip replaces the old try/except ImportError: only a missing
    # pandas causes a skip, so an ImportError raised by the code under test
    # can no longer be silently turned into a skipped test.
    pd = importorskip("pandas")

    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_df = pd.DataFrame(X, columns=["a", "b", "fit"])
    check_consistent_length(X_df)
883
884
def test_suppress_validation():
    """``assume_finite=True`` disables the check done by assert_all_finite."""
    X = np.array([0, np.inf])
    with pytest.raises(ValueError):
        assert_all_finite(X)

    # config_context restores the previous global configuration on exit, even
    # if the body raises. A bare set_config(True) / set_config(False) pair
    # would leak assume_finite=True into every later test on failure.
    with sklearn.config_context(assume_finite=True):
        assert_all_finite(X)

    with pytest.raises(ValueError):
        assert_all_finite(X)
894
895
def test_check_array_series():
    """Regression test: check_array accepts pandas Series."""
    pd = importorskip("pandas")

    # Plain numeric Series.
    numeric = pd.Series([1, 2, 3])
    assert_array_equal(check_array(numeric, ensure_2d=False), np.array([1, 2, 3]))

    # Categorical dtype is not a numpy dtype (GH12699); with dtype=None the
    # values come back as an object array.
    categorical = pd.Series(["a", "b", "c"]).astype("category")
    converted = check_array(categorical, dtype=None, ensure_2d=False)
    assert_array_equal(converted, np.array(["a", "b", "c"], dtype=object))
906
907
def test_check_dataframe_mixed_float_dtypes():
    # For a frame mixing int, float and bool columns, pandas coerces the bool
    # column to object, which disagrees with np.result_type (float). So
    # check_array must explicitly handle bool dtype in a DataFrame.
    # https://github.com/scikit-learn/scikit-learn/issues/15787
    pd = importorskip("pandas")
    columns = ["int", "float", "bool"]
    data = {"int": [1, 2, 3], "float": [0, 0.1, 2.1], "bool": [True, False, True]}
    df = pd.DataFrame(data, columns=columns)

    expected = np.array(
        [[1.0, 0.0, 1.0], [2.0, 0.1, 0.0], [3.0, 2.1, 1.0]], dtype=float
    )
    result = check_array(df, dtype=(np.float64, np.float32, np.float16))
    assert_allclose_dense_sparse(result, expected)
926
927
class DummyMemory:
    """Memory-like test double whose ``cache`` returns the function untouched."""

    def cache(self, func):
        # No caching at all: the "cached" callable is the original one.
        cached = func
        return cached
931
932
class WrongDummyMemory:
    """Test double lacking a ``cache`` method, i.e. not Memory-like."""
935
936
@pytest.mark.filterwarnings("ignore:The 'cachedir' attribute")
def test_check_memory():
    """check_memory accepts None, a path or a Memory-like object, else raises."""
    memory = check_memory("cache_directory")
    assert memory.cachedir == os.path.join("cache_directory", "joblib")
    memory = check_memory(None)
    assert memory.cachedir is None
    dummy = DummyMemory()
    memory = check_memory(dummy)
    assert memory is dummy

    # The expected messages are literal text (they contain '.', quotes and,
    # for the second one, an object repr), so escape them before passing them
    # to ``match``, which interprets its argument as a regular expression.
    msg = (
        "'memory' should be None, a string or have the same interface as"
        " joblib.Memory. Got memory='1' instead."
    )
    with pytest.raises(ValueError, match=re.escape(msg)):
        check_memory(1)
    dummy = WrongDummyMemory()
    msg = (
        "'memory' should be None, a string or have the same interface as"
        " joblib.Memory. Got memory='{}' instead.".format(dummy)
    )
    with pytest.raises(ValueError, match=re.escape(msg)):
        check_memory(dummy)
960
961
@pytest.mark.parametrize("copy", [True, False])
def test_check_array_memmap(copy):
    """check_array on a read-only memmap copies only when asked to."""
    data = np.ones((4, 4))
    with TempMemmap(data, mmap_mode="r") as readonly_view:
        checked = check_array(readonly_view, copy=copy)
        shares_memory = np.may_share_memory(readonly_view, checked)
        # Without copy the result aliases the read-only buffer; with copy it
        # is a separate, writeable array.
        assert shares_memory == (not copy)
        assert checked.flags["WRITEABLE"] == copy
969
970
@pytest.mark.parametrize(
    "retype",
    [
        np.asarray,
        sp.csr_matrix,
        sp.csc_matrix,
        sp.coo_matrix,
        sp.lil_matrix,
        sp.bsr_matrix,
        sp.dok_matrix,
        sp.dia_matrix,
    ],
)
def test_check_non_negative(retype):
    """check_non_negative accepts non-negative dense/sparse input only."""
    A = np.array([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])

    # Non-negative inputs, including an all-zero matrix, pass silently.
    for valid in (A, [[0, 0], [0, 0]]):
        check_non_negative(retype(valid), "")

    # A single negative entry is enough to trigger the error.
    A[0, 0] = -1
    negative = retype(A)
    with pytest.raises(ValueError, match="Negative "):
        check_non_negative(negative, "")
995
996
def test_check_X_y_informative_error():
    """check_X_y reports explicitly that y must not be None."""
    with pytest.raises(ValueError, match="y cannot be None"):
        check_X_y(np.ones((2, 2)), None)
1002
1003
def test_retrieve_samples_from_non_standard_shape():
    """_num_samples agrees with len() when ``shape`` is not numeric."""

    class TestNonNumericShape:
        def __init__(self):
            self.shape = ("not numeric",)

        def __len__(self):
            return 3

    sized = TestNonNumericShape()
    assert _num_samples(sized) == len(sized)

    # Without __len__ there is nothing to fall back on: a clear TypeError.
    class TestNoLenWeirdShape:
        def __init__(self):
            self.shape = ("not numeric",)

    with pytest.raises(TypeError, match="Expected sequence or array-like"):
        _num_samples(TestNoLenWeirdShape())
1022
1023
@pytest.mark.parametrize("x", [2, 3, 2.5, 5])
def test_check_scalar_valid(x):
    """Test that check_scalar returns its input without error or warning when
    valid inputs are provided"""
    # ``pytest.warns(None)`` is deprecated since pytest 7.0; record warnings
    # with the standard library and require that none was emitted.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        scalar = check_scalar(
            x,
            "test_name",
            target_type=numbers.Real,
            min_val=2,
            max_val=5,
            include_boundaries="both",
        )
    assert len(record) == 0
    assert scalar == x
1039
1040
@pytest.mark.parametrize(
    "x, target_name, target_type, min_val, max_val, include_boundaries, err_msg",
    [
        (
            1,
            "test_name1",
            float,
            2,
            4,
            "neither",
            TypeError(
                "test_name1 must be an instance of <class 'float'>, not <class 'int'>."
            ),
        ),
        (
            1,
            "test_name2",
            int,
            2,
            4,
            "neither",
            ValueError("test_name2 == 1, must be > 2."),
        ),
        (
            5,
            "test_name3",
            int,
            2,
            4,
            "neither",
            ValueError("test_name3 == 5, must be < 4."),
        ),
        (
            2,
            "test_name4",
            int,
            2,
            4,
            "right",
            ValueError("test_name4 == 2, must be > 2."),
        ),
        (
            4,
            "test_name5",
            int,
            2,
            4,
            "left",
            ValueError("test_name5 == 4, must be < 4."),
        ),
        (
            4,
            "test_name6",
            int,
            2,
            4,
            "bad parameter value",
            ValueError(
                "Unknown value for `include_boundaries`: 'bad parameter value'. "
                "Possible values are: ('left', 'right', 'both', 'neither')."
            ),
        ),
    ],
)
def test_check_scalar_invalid(
    x, target_name, target_type, min_val, max_val, include_boundaries, err_msg
):
    """Test that check_scalar raises the right error if a wrong input is
    given"""
    bounds = dict(
        min_val=min_val, max_val=max_val, include_boundaries=include_boundaries
    )
    with pytest.raises(Exception) as excinfo:
        check_scalar(x, target_name, target_type=target_type, **bounds)

    raised = excinfo.value
    # Both the full message and the exact exception class must match.
    assert str(raised) == str(err_msg)
    assert type(raised) is type(err_msg)
1121
1122
# Valid inputs for ``_check_psd_eigenvalues``, keyed by pytest id and consumed
# by ``test_check_psd_eigenvalues_valid``. Each value is the tuple
# ``(lambdas, expected_lambdas, w_type, w_msg)``:
#   - lambdas: eigenvalues passed to ``_check_psd_eigenvalues``
#   - expected_lambdas: array the function must return
#   - w_type: warning class expected when warnings are enabled (None: no
#     warning expected)
#   - w_msg: regex the warning message must match ("" matches any message)
_psd_cases_valid = {
    "nominal": ((1, 2), np.array([1, 2]), None, ""),
    "nominal_np_array": (np.array([1, 2]), np.array([1, 2]), None, ""),
    # A small imaginary part: only the real parts are returned, with a warning.
    "insignificant_imag": (
        (5, 5e-5j),
        np.array([5, 0]),
        PositiveSpectrumWarning,
        "There are imaginary parts in eigenvalues \\(1e\\-05 of the maximum real part",
    ),
    # Small negative eigenvalues are returned as 0, with a warning; the
    # float32/float64 cases use dtype-dependent magnitudes.
    "insignificant neg": ((5, -5e-5), np.array([5, 0]), PositiveSpectrumWarning, ""),
    "insignificant neg float32": (
        np.array([1, -1e-6], dtype=np.float32),
        np.array([1, 0], dtype=np.float32),
        PositiveSpectrumWarning,
        "There are negative eigenvalues \\(1e\\-06 of the maximum positive",
    ),
    "insignificant neg float64": (
        np.array([1, -1e-10], dtype=np.float64),
        np.array([1, 0], dtype=np.float64),
        PositiveSpectrumWarning,
        "There are negative eigenvalues \\(1e\\-10 of the maximum positive",
    ),
    # A tiny positive eigenvalue relative to the largest one is returned as 0,
    # with a warning about the 1e12 ratio.
    "insignificant pos": (
        (5, 4e-12),
        np.array([5, 0]),
        PositiveSpectrumWarning,
        "the largest eigenvalue is more than 1e\\+12 times the smallest",
    ),
}
1152
1153
@pytest.mark.parametrize(
    "lambdas, expected_lambdas, w_type, w_msg",
    list(_psd_cases_valid.values()),
    ids=list(_psd_cases_valid.keys()),
)
@pytest.mark.parametrize("enable_warnings", [True, False])
def test_check_psd_eigenvalues_valid(
    lambdas, expected_lambdas, w_type, w_msg, enable_warnings
):
    # Test that ``_check_psd_eigenvalues`` returns the right output for valid
    # input, possibly raising the right warning. When warnings are disabled,
    # or the case expects none, no warning at all may be emitted.

    if not enable_warnings:
        w_type = None

    if w_type is None:
        # ``pytest.warns(None)`` is deprecated since pytest 7.0; record every
        # warning with the standard library and require that none occurred.
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            result = _check_psd_eigenvalues(lambdas, enable_warnings=enable_warnings)
        assert not record
    else:
        with pytest.warns(w_type, match=w_msg):
            result = _check_psd_eigenvalues(lambdas, enable_warnings=enable_warnings)

    assert_array_equal(result, expected_lambdas)
1177
1178
# Invalid inputs for ``_check_psd_eigenvalues``, keyed by pytest id and
# consumed by ``test_check_psd_eigenvalues_invalid``. Each value is the tuple
# ``(lambdas, err_type, err_msg)``: the eigenvalues to check, the exception
# class that must be raised, and a regex its message must match.
_psd_cases_invalid = {
    "significant_imag": (
        (5, 5j),
        ValueError,
        "There are significant imaginary parts in eigenv",
    ),
    "all negative": (
        (-5, -1),
        ValueError,
        "All eigenvalues are negative \\(maximum is -1",
    ),
    "significant neg": (
        (5, -1),
        ValueError,
        "There are significant negative eigenvalues",
    ),
    # Same error at dtype-dependent magnitudes (compare with the
    # "insignificant neg float32/float64" valid cases).
    "significant neg float32": (
        np.array([3e-4, -2e-6], dtype=np.float32),
        ValueError,
        "There are significant negative eigenvalues",
    ),
    "significant neg float64": (
        np.array([1e-5, -2e-10], dtype=np.float64),
        ValueError,
        "There are significant negative eigenvalues",
    ),
}
1206
1207
@pytest.mark.parametrize(
    "lambdas, err_type, err_msg",
    list(_psd_cases_invalid.values()),
    ids=list(_psd_cases_invalid.keys()),
)
def test_check_psd_eigenvalues_invalid(lambdas, err_type, err_msg):
    """``_check_psd_eigenvalues`` rejects invalid spectra with the right error."""
    expected_error = pytest.raises(err_type, match=err_msg)
    with expected_error:
        _check_psd_eigenvalues(lambdas)
1219
1220
def test_check_sample_weight():
    # The output is C-contiguous even when the input is a strided view.
    strided = np.ones(10)[::2]
    assert not strided.flags["C_CONTIGUOUS"]
    checked = _check_sample_weight(strided, X=np.ones((5, 1)))
    assert checked.flags["C_CONTIGUOUS"]

    # None and scalars expand to one weight per sample of X.
    X_dense = np.ones((5, 2))
    assert_allclose(_check_sample_weight(None, X=X_dense), np.ones(5))
    assert_allclose(_check_sample_weight(2.0, X=X_dense), 2 * np.ones(5))

    # 2D weights are rejected.
    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
        _check_sample_weight(np.ones((2, 4)), X=np.ones((2, 2)))

    # Weights whose length differs from n_samples are rejected.
    mismatch_msg = r"sample_weight.shape == \(4,\), expected \(2,\)!"
    with pytest.raises(ValueError, match=mismatch_msg):
        _check_sample_weight(np.ones(4), X=np.ones((2, 2)))

    # float32 weights keep their dtype.
    float32_weights = _check_sample_weight(
        np.ones(5, dtype=np.float32), np.ones((5, 2))
    )
    assert float32_weights.dtype == np.float32

    # Requesting an int dtype still produces float64 weights.
    X_int = np.ones((5, 2), dtype=int)
    int_weights = _check_sample_weight(None, X_int, dtype=X_int.dtype)
    assert int_weights.dtype == np.float64
1255
1256
@pytest.mark.parametrize("toarray", [np.array, sp.csr_matrix, sp.csc_matrix])
def test_allclose_dense_sparse_equals(toarray):
    """Two containers built from identical data compare as close."""
    data = np.arange(9).reshape(3, 3)
    assert _allclose_dense_sparse(toarray(data), toarray(data))
1262
1263
@pytest.mark.parametrize("toarray", [np.array, sp.csr_matrix, sp.csc_matrix])
def test_allclose_dense_sparse_not_equals(toarray):
    """Containers whose entries differ by one do not compare as close."""
    data = np.arange(9).reshape(3, 3)
    shifted = data + 1
    assert not _allclose_dense_sparse(toarray(data), toarray(shifted))
1269
1270
@pytest.mark.parametrize("toarray", [sp.csr_matrix, sp.csc_matrix])
def test_allclose_dense_sparse_raise(toarray):
    """Comparing a dense array with a sparse matrix is an error."""
    dense = np.arange(9).reshape(3, 3)
    sparse = toarray(dense + 1)

    expected = "Can only compare two sparse matrices, not a sparse matrix and an array"
    with pytest.raises(ValueError, match=expected):
        _allclose_dense_sparse(dense, sparse)
1279
1280
def test_deprecate_positional_args_warns_for_function():
    """Keyword-only args passed positionally trigger a FutureWarning."""

    # Keyword-only args with defaults, after two required positional args.
    @_deprecate_positional_args
    def f1(a, b, *, c=1, d=1):
        pass

    for call_args, pattern in (
        ((1, 2, 3), r"Pass c=3 as keyword args"),
        ((1, 2, 3, 4), r"Pass c=3, d=4 as keyword args"),
    ):
        with pytest.warns(FutureWarning, match=pattern):
            f1(*call_args)

    # Every parameter has a default.
    @_deprecate_positional_args
    def f2(a=1, *, b=1, c=1, d=1):
        pass

    with pytest.warns(FutureWarning, match=r"Pass b=2 as keyword args"):
        f2(1, 2)

    # The * is placed before a keyword-only argument without a default value
    @_deprecate_positional_args
    def f3(a, *, b, c=1, d=1):
        pass

    with pytest.warns(FutureWarning, match=r"Pass b=2 as keyword args"):
        f3(1, 2)
1306
1307
def test_deprecate_positional_args_warns_for_function_version():
    """The ``version`` given to the decorator appears in the warning text."""

    @_deprecate_positional_args(version="1.1")
    def f1(a, *, b):
        pass

    expected = r"From version 1.1 passing these as positional"
    with pytest.warns(FutureWarning, match=expected):
        f1(1, 2)
1317
1318
def test_deprecate_positional_args_warns_for_class():
    """A decorated ``__init__`` warns for keyword-only args passed positionally."""

    # Required positional args before the keyword-only ones.
    class A1:
        @_deprecate_positional_args
        def __init__(self, a, b, *, c=1, d=1):
            pass

    # Every argument has a default.
    class A2:
        @_deprecate_positional_args
        def __init__(self, a=1, b=1, *, c=1, d=1):
            pass

    for cls in (A1, A2):
        with pytest.warns(FutureWarning, match=r"Pass c=3 as keyword args"):
            cls(1, 2, 3)

        with pytest.warns(FutureWarning, match=r"Pass c=3, d=4 as keyword args"):
            cls(1, 2, 3, 4)
1341
1342
@pytest.mark.parametrize("indices", [None, [1, 3]])
def test_check_fit_params(indices):
    """_check_fit_params indexes sample-aligned fit params, leaves others as-is."""
    # Only the shape of X matters here. Use a seeded local generator so the
    # test neither depends on nor perturbs the global NumPy random state.
    rng = np.random.RandomState(0)
    X = rng.randn(4, 2)
    fit_params = {
        "list": [1, 2, 3, 4],
        "array": np.array([1, 2, 3, 4]),
        "sparse-col": sp.csc_matrix([1, 2, 3, 4]).T,
        "sparse-row": sp.csc_matrix([1, 2, 3, 4]),
        "scalar-int": 1,
        "scalar-str": "xxx",
        "None": None,
    }
    result = _check_fit_params(X, fit_params, indices)
    indices_ = indices if indices is not None else list(range(X.shape[0]))

    # These values are returned as the very same objects.
    for key in ["sparse-row", "scalar-int", "scalar-str", "None"]:
        assert result[key] is fit_params[key]

    # These values are subset with the requested indices (all rows if None).
    assert result["list"] == _safe_indexing(fit_params["list"], indices_)
    assert_array_equal(result["array"], _safe_indexing(fit_params["array"], indices_))
    assert_allclose_dense_sparse(
        result["sparse-col"], _safe_indexing(fit_params["sparse-col"], indices_)
    )
1365    )
1366
1367
@pytest.mark.parametrize("sp_format", [True, "csr", "csc", "coo", "bsr"])
def test_check_sparse_pandas_sp_format(sp_format):
    # A DataFrame made only of sparse columns is converted by check_array into
    # a scipy sparse matrix of the requested format.
    pd = pytest.importorskip("pandas", minversion="0.25.0")
    sp_mat = _sparse_random_matrix(10, 3)
    sdf = pd.DataFrame.sparse.from_spmatrix(sp_mat)

    result = check_array(sdf, accept_sparse=sp_format)

    # by default pandas converts to coo when accept_sparse is True
    expected_format = "coo" if sp_format is True else sp_format
    assert sp.issparse(result)
    assert result.format == expected_format
    assert_allclose_dense_sparse(sp_mat, result)
1385
1386
@pytest.mark.parametrize(
    "ntype1, ntype2",
    [
        ("longdouble", "float16"),
        ("float16", "float32"),
        ("float32", "double"),
        ("int16", "int32"),
        ("int32", "long"),
        ("byte", "uint16"),
        ("ushort", "uint32"),
        ("uint32", "uint64"),
        ("uint8", "int8"),
    ],
)
def test_check_pandas_sparse_invalid(ntype1, ntype2):
    """Check that a DataFrame whose sparse extension arrays have an unsupported
    mix of dtypes raises an error with pandas below 1.1. pandas 1.1 and above
    fixed this issue, so no error is raised there."""
    pd = pytest.importorskip("pandas", minversion="0.25.0")
    columns = {
        "col1": pd.arrays.SparseArray([0, 1, 0], dtype=ntype1, fill_value=0),
        "col2": pd.arrays.SparseArray([1, 0, 1], dtype=ntype2, fill_value=0),
    }
    df = pd.DataFrame(columns)

    pandas_has_fix = parse_version(pd.__version__) >= parse_version("1.1")
    if pandas_has_fix:
        # From pandas 1.1 on the conversion succeeds.
        check_array(df, accept_sparse=["csr", "csc"])
    else:
        err_msg = "Pandas DataFrame with mixed sparse extension arrays"
        with pytest.raises(ValueError, match=err_msg):
            check_array(df, accept_sparse=["csr", "csc"])
1422
1423
# "longfloat" and "int0" were deprecated NumPy dtype aliases (deprecated in
# NumPy 1.24, gone in 2.0); "longdouble" and "intp" name the same dtypes and
# are supported by every NumPy version.
@pytest.mark.parametrize(
    "ntype1, ntype2, expected_subtype",
    [
        ("longdouble", "longdouble", np.floating),
        ("float16", "half", np.floating),
        ("single", "float32", np.floating),
        ("double", "float64", np.floating),
        ("int8", "byte", np.integer),
        ("short", "int16", np.integer),
        ("intc", "int32", np.integer),
        ("intp", "long", np.integer),
        ("int", "long", np.integer),
        ("int64", "longlong", np.integer),
        ("int_", "intp", np.integer),
        ("ubyte", "uint8", np.unsignedinteger),
        ("uint16", "ushort", np.unsignedinteger),
        ("uintc", "uint32", np.unsignedinteger),
        ("uint", "uint64", np.unsignedinteger),
        ("uintp", "ulonglong", np.unsignedinteger),
    ],
)
def test_check_pandas_sparse_valid(ntype1, ntype2, expected_subtype):
    """Sparse DataFrames mixing dtype aliases that convert safely are supported.

    The resulting array's dtype must be a subtype of ``expected_subtype``.
    """
    pd = pytest.importorskip("pandas", minversion="0.25.0")
    df = pd.DataFrame(
        {
            "col1": pd.arrays.SparseArray([0, 1, 0], dtype=ntype1, fill_value=0),
            "col2": pd.arrays.SparseArray([1, 0, 1], dtype=ntype2, fill_value=0),
        }
    )
    arr = check_array(df, accept_sparse=["csr", "csc"])
    assert np.issubdtype(arr.dtype, expected_subtype)
1457
1458
@pytest.mark.parametrize(
    "constructor_name",
    ["list", "tuple", "array", "dataframe", "sparse_csr", "sparse_csc"],
)
def test_num_features(constructor_name):
    """Check _num_features for array-likes."""
    X = _convert_container([[1, 2, 3], [4, 5, 6]], constructor_name)
    n_features = _num_features(X)
    assert n_features == 3
1468
1469
@pytest.mark.parametrize(
    "X",
    [
        [1, 2, 3],
        ["a", "b", "c"],
        [False, True, False],
        [1.0, 3.4, 4.0],
        [{"a": 1}, {"b": 2}, {"c": 3}],
    ],
    ids=["int", "str", "bool", "float", "dict"],
)
@pytest.mark.parametrize("constructor_name", ["list", "tuple", "array", "series"])
def test_num_features_errors_1d_containers(X, constructor_name):
    """_num_features raises a descriptive TypeError for 1d containers."""
    X = _convert_container(X, constructor_name)

    # Containers with a fully qualified type name in the message; the others
    # are reported by their constructor name.
    qualified_names = {
        "array": "numpy.ndarray",
        "series": "pandas.core.series.Series",
    }
    expected_type_name = qualified_names.get(constructor_name, constructor_name)

    if hasattr(X, "shape"):
        detail = " with shape (3,)"
    elif isinstance(X[0], str):
        detail = " where the samples are of type str"
    elif isinstance(X[0], dict):
        detail = " where the samples are of type dict"
    else:
        detail = ""

    message = (
        f"Unable to find the number of features from X of type {expected_type_name}"
        + detail
    )
    with pytest.raises(TypeError, match=re.escape(message)):
        _num_features(X)
1501
1502
@pytest.mark.parametrize("X", [1, "b", False, 3.0], ids=["int", "str", "bool", "float"])
def test_num_features_errors_scalars(X):
    """_num_features raises a TypeError naming the type of a scalar input."""
    msg = f"Unable to find the number of features from X of type {type(X).__qualname__}"
    # ``match`` takes a regex; escape the literal message, as the 1d-container
    # test does, so a type name with regex metacharacters cannot break it.
    with pytest.raises(TypeError, match=re.escape(msg)):
        _num_features(X)
1508
1509
# TODO: Remove in 1.2
@pytest.mark.filterwarnings("ignore:the matrix subclass:PendingDeprecationWarning")
def test_check_array_deprecated_matrix():
    """Test that matrix support is deprecated in 1.0."""
    # Build the matrix outside of pytest.warns so that only check_array's
    # warnings are recorded there.
    matrix = np.matrix(np.arange(5))
    expected_msg = (
        "np.matrix usage is deprecated in 1.0 and will raise a TypeError "
        "in 1.2. Please convert to a numpy array with np.asarray."
    )
    with pytest.warns(FutureWarning, match=expected_msg):
        check_array(matrix)
1522
1523
@pytest.mark.parametrize(
    "names",
    [list(range(2)), range(2), None],
    ids=["list-int", "range", "default"],
)
def test_get_feature_names_pandas_with_ints_no_warning(names):
    """Get feature names with pandas dataframes with ints without warning"""
    pd = pytest.importorskip("pandas")
    X = pd.DataFrame([[1, 2], [4, 5], [5, 6]], columns=names)

    # ``pytest.warns(None)`` is deprecated since pytest 7.0; record every
    # warning with the standard library and require that none was emitted.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        feature_names = _get_feature_names(X)
    assert not record
    assert feature_names is None
1538
1539
def test_get_feature_names_pandas():
    """Get feature names with pandas dataframes."""
    pd = pytest.importorskip("pandas")
    columns = [f"col_{i}" for i in range(3)]
    frame = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=columns)

    # String column names are returned as the feature names.
    assert_array_equal(_get_feature_names(frame), columns)
1548
1549
def test_get_feature_names_numpy():
    """Get feature names returns None for numpy arrays."""
    assert _get_feature_names(np.array([[1, 2, 3], [4, 5, 6]])) is None
1555
1556
# TODO: Convert to an error in 1.2
@pytest.mark.parametrize(
    "names, dtypes",
    [
        ([["a", "b"], ["c", "d"]], "['tuple']"),
        (["a", 1], "['int', 'str']"),
    ],
    ids=["multi-index", "mixed"],
)
def test_get_feature_names_invalid_dtypes_warns(names, dtypes):
    """Get feature names warns when the feature names have mixed dtypes"""
    pd = pytest.importorskip("pandas")
    frame = pd.DataFrame([[1, 2], [4, 5], [5, 6]], columns=names)

    expected = re.escape(
        "Feature names only support names that are all strings. "
        f"Got feature names with dtypes: {dtypes}. An error will be raised"
    )
    with pytest.warns(FutureWarning, match=expected):
        feature_names = _get_feature_names(frame)
    # Invalid names are discarded.
    assert feature_names is None
1578
1579
class PassthroughTransformer(BaseEstimator):
    """Minimal transformer: validates X in ``fit``, returns it unchanged."""

    def fit(self, X, y=None):
        # reset=True records input metadata (e.g. n_features_in_) that
        # get_feature_names_out relies on.
        self._validate_data(X, reset=True)
        return self

    def transform(self, X):
        # Identity transform.
        return X

    def get_feature_names_out(self, input_features=None):
        # Delegate entirely to the helper under test.
        return _check_feature_names_in(self, input_features)
1590
1591
def test_check_feature_names_in():
    """Check behavior of check_feature_names_in for arrays."""
    est = PassthroughTransformer().fit(np.array([[0.0, 1.0, 2.0]]))

    # Without input_features, generic names are generated.
    assert_array_equal(est.get_feature_names_out(), ["x0", "x1", "x2"])

    # A names list of the wrong length is rejected.
    with pytest.raises(ValueError, match="input_features should have length equal to"):
        est.get_feature_names_out(["x10", "x1"])

    # remove n_features_in_: names can no longer be generated
    del est.n_features_in_
    with pytest.raises(ValueError, match="Unable to generate feature names"):
        est.get_feature_names_out()
1608
1609
def test_check_feature_names_in_pandas():
    """Check behavior of check_feature_names_in for pandas dataframes."""
    pd = pytest.importorskip("pandas")
    columns = ["a", "b", "c"]
    frame = pd.DataFrame([[0.0, 1.0, 2.0]], columns=columns)
    est = PassthroughTransformer().fit(frame)

    # Column names seen in fit become the output feature names.
    assert_array_equal(est.get_feature_names_out(), ["a", "b", "c"])

    # Explicit input_features must match the names seen in fit.
    with pytest.raises(ValueError, match="input_features is not equal to"):
        est.get_feature_names_out(["x1", "x2", "x3"])
1622