import os
import shutil
import tempfile
import warnings
from pickle import loads
from pickle import dumps
from functools import partial
from importlib import resources

import pytest

import numpy as np
from sklearn.datasets import get_data_home
from sklearn.datasets import clear_data_home
from sklearn.datasets import load_files
from sklearn.datasets import load_sample_images
from sklearn.datasets import load_sample_image
from sklearn.datasets import load_digits
from sklearn.datasets import load_diabetes
from sklearn.datasets import load_linnerud
from sklearn.datasets import load_iris
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_boston
from sklearn.datasets import load_wine
from sklearn.datasets._base import (
    load_csv_data,
    load_gzip_compressed_csv_data,
)
from sklearn.utils import Bunch
from sklearn.utils._testing import SkipTest
from sklearn.datasets.tests.test_common import check_as_frame

from sklearn.externals._pilutil import pillow_installed

from sklearn.utils import IS_PYPY


def _remove_dir(path):
    if os.path.isdir(path):
        shutil.rmtree(path)


@pytest.fixture(scope="module")
def data_home(tmpdir_factory):
    tmp_file = str(tmpdir_factory.mktemp("scikit_learn_data_home_test"))
    yield tmp_file
    _remove_dir(tmp_file)


@pytest.fixture(scope="module")
def load_files_root(tmpdir_factory):
    tmp_file = str(tmpdir_factory.mktemp("scikit_learn_load_files_test"))
    yield tmp_file
    _remove_dir(tmp_file)


@pytest.fixture
def test_category_dir_1(load_files_root):
    test_category_dir1 = tempfile.mkdtemp(dir=load_files_root)
    sample_file = tempfile.NamedTemporaryFile(dir=test_category_dir1, delete=False)
    sample_file.write(b"Hello World!\n")
    sample_file.close()
    yield str(test_category_dir1)
    _remove_dir(test_category_dir1)


@pytest.fixture
def test_category_dir_2(load_files_root):
    test_category_dir2 = tempfile.mkdtemp(dir=load_files_root)
    yield str(test_category_dir2)
    _remove_dir(test_category_dir2)

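# The two fixtures above build the on-disk layout consumed by the
# ``load_files`` tests: ``load_files_root`` holds one sub-directory per
# category, and only the first category directory contains a sample file.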

def test_data_home(data_home):
    # get_data_home will point to a pre-existing folder
    data_home_path = get_data_home(data_home=data_home)
    assert data_home_path == data_home
    assert os.path.exists(data_home_path)

    # clear_data_home will delete both the content and the folder itself
    clear_data_home(data_home=data_home_path)
    assert not os.path.exists(data_home_path)

    # if the folder is missing, it will be created again
    data_home_path = get_data_home(data_home=data_home)
    assert os.path.exists(data_home_path)


def test_default_empty_load_files(load_files_root):
    res = load_files(load_files_root)
    assert len(res.filenames) == 0
    assert len(res.target_names) == 0
    assert res.DESCR is None


def test_default_load_files(test_category_dir_1, test_category_dir_2, load_files_root):
    if IS_PYPY:
        pytest.xfail("[PyPy] fails due to string containing NUL characters")
    res = load_files(load_files_root)
    assert len(res.filenames) == 1
    assert len(res.target_names) == 2
    assert res.DESCR is None
    assert res.data == [b"Hello World!\n"]


def test_load_files_w_categories_desc_and_encoding(
    test_category_dir_1, test_category_dir_2, load_files_root
):
    if IS_PYPY:
        pytest.xfail("[PyPy] fails due to string containing NUL characters")
    category = os.path.basename(os.path.abspath(test_category_dir_1))
    res = load_files(
        load_files_root, description="test", categories=category, encoding="utf-8"
    )
    assert len(res.filenames) == 1
    assert len(res.target_names) == 1
    assert res.DESCR == "test"
    assert res.data == ["Hello World!\n"]


def test_load_files_wo_load_content(
    test_category_dir_1, test_category_dir_2, load_files_root
):
    res = load_files(load_files_root, load_content=False)
    assert len(res.filenames) == 1
    assert len(res.target_names) == 2
    assert res.DESCR is None
    assert res.get("data") is None


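# ``load_csv_data`` reads CSV files shipped inside the ``sklearn.datasets``
# package resources, so the tests below run without any network access.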
@pytest.mark.parametrize(
    "filename, expected_n_samples, expected_n_features, expected_target_names",
    [
        ("wine_data.csv", 178, 13, ["class_0", "class_1", "class_2"]),
        ("iris.csv", 150, 4, ["setosa", "versicolor", "virginica"]),
        ("breast_cancer.csv", 569, 30, ["malignant", "benign"]),
    ],
)
def test_load_csv_data(
    filename, expected_n_samples, expected_n_features, expected_target_names
):
    actual_data, actual_target, actual_target_names = load_csv_data(filename)
    assert actual_data.shape[0] == expected_n_samples
    assert actual_data.shape[1] == expected_n_features
    assert actual_target.shape[0] == expected_n_samples
    np.testing.assert_array_equal(actual_target_names, expected_target_names)


def test_load_csv_data_with_descr():
    data_file_name = "iris.csv"
    descr_file_name = "iris.rst"

    res_without_descr = load_csv_data(data_file_name=data_file_name)
    res_with_descr = load_csv_data(
        data_file_name=data_file_name, descr_file_name=descr_file_name
    )
    assert len(res_with_descr) == 4
    assert len(res_without_descr) == 3

    np.testing.assert_array_equal(res_with_descr[0], res_without_descr[0])
    np.testing.assert_array_equal(res_with_descr[1], res_without_descr[1])
    np.testing.assert_array_equal(res_with_descr[2], res_without_descr[2])

    assert res_with_descr[-1].startswith(".. _iris_dataset:")


@pytest.mark.parametrize(
    "filename, kwargs, expected_shape",
    [
        ("diabetes_data.csv.gz", {}, [442, 10]),
        ("diabetes_target.csv.gz", {}, [442]),
        ("digits.csv.gz", {"delimiter": ","}, [1797, 65]),
    ],
)
def test_load_gzip_compressed_csv_data(filename, kwargs, expected_shape):
    actual_data = load_gzip_compressed_csv_data(filename, **kwargs)
    assert actual_data.shape == tuple(expected_shape)


def test_load_gzip_compressed_csv_data_with_descr():
    data_file_name = "diabetes_target.csv.gz"
    descr_file_name = "diabetes.rst"

    expected_data = load_gzip_compressed_csv_data(data_file_name=data_file_name)
    actual_data, descr = load_gzip_compressed_csv_data(
        data_file_name=data_file_name,
        descr_file_name=descr_file_name,
    )

    np.testing.assert_array_equal(actual_data, expected_data)
    assert descr.startswith(".. _diabetes_dataset:")


def test_load_sample_images():
    try:
        res = load_sample_images()
        assert len(res.images) == 2
        assert len(res.filenames) == 2
        images = res.images

        # assert is china image
        assert np.all(images[0][0, 0, :] == np.array([174, 201, 231], dtype=np.uint8))
        # assert is flower image
        assert np.all(images[1][0, 0, :] == np.array([2, 19, 13], dtype=np.uint8))
        assert res.DESCR
    except ImportError:
        warnings.warn("Could not load sample images, PIL is not available.")


def test_load_sample_image():
    try:
        china = load_sample_image("china.jpg")
        assert china.dtype == "uint8"
        assert china.shape == (427, 640, 3)
    except ImportError:
        warnings.warn("Could not load sample images, PIL is not available.")


def test_load_missing_sample_image_error():
    if pillow_installed:
        with pytest.raises(AttributeError):
            load_sample_image("blop.jpg")
    else:
        warnings.warn("Could not load sample images, PIL is not available.")


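# Smoke-test every toy loader below: shape of ``data`` and ``target``, the
# number of target names, the presence of a ``DESCR`` and, when filenames are
# listed, that they resolve to importlib resources of ``data_module``.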
@pytest.mark.filterwarnings("ignore:Function load_boston is deprecated")
@pytest.mark.parametrize(
    "loader_func, data_shape, target_shape, n_target, has_descr, filenames",
    [
        (load_breast_cancer, (569, 30), (569,), 2, True, ["filename"]),
        (load_wine, (178, 13), (178,), 3, True, []),
        (load_iris, (150, 4), (150,), 3, True, ["filename"]),
        (
            load_linnerud,
            (20, 3),
            (20, 3),
            3,
            True,
            ["data_filename", "target_filename"],
        ),
        (load_diabetes, (442, 10), (442,), None, True, []),
        (load_digits, (1797, 64), (1797,), 10, True, []),
        (partial(load_digits, n_class=9), (1617, 64), (1617,), 10, True, []),
        (load_boston, (506, 13), (506,), None, True, ["filename"]),
    ],
)
def test_loader(loader_func, data_shape, target_shape, n_target, has_descr, filenames):
    bunch = loader_func()

    assert isinstance(bunch, Bunch)
    assert bunch.data.shape == data_shape
    assert bunch.target.shape == target_shape
    if hasattr(bunch, "feature_names"):
        assert len(bunch.feature_names) == data_shape[1]
    if n_target is not None:
        assert len(bunch.target_names) == n_target
    if has_descr:
        assert bunch.DESCR
    if filenames:
        assert "data_module" in bunch
        assert all(
            [
                f in bunch and resources.is_resource(bunch["data_module"], bunch[f])
                for f in filenames
            ]
        )


@pytest.mark.parametrize(
    "loader_func, data_dtype, target_dtype",
    [
        (load_breast_cancer, np.float64, int),
        (load_diabetes, np.float64, np.float64),
        (load_digits, np.float64, int),
        (load_iris, np.float64, int),
        (load_linnerud, np.float64, np.float64),
        (load_wine, np.float64, int),
    ],
)
def test_toy_dataset_frame_dtype(loader_func, data_dtype, target_dtype):
    default_result = loader_func()
    check_as_frame(
        default_result,
        loader_func,
        expected_data_dtype=data_dtype,
        expected_target_dtype=target_dtype,
    )


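# A minimal ``as_frame`` usage sketch: it assumes pandas is importable and
# relies only on the documented ``as_frame=True`` behaviour of the toy
# loaders (``frame`` and ``data`` as DataFrames, ``target`` as a Series).
def test_load_iris_as_frame_smoke():
    pd = pytest.importorskip("pandas")
    res = load_iris(as_frame=True)
    assert isinstance(res.frame, pd.DataFrame)
    assert isinstance(res.data, pd.DataFrame)
    assert isinstance(res.target, pd.Series)

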
def test_loads_dumps_bunch():
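    """Check that attribute and key access stay in sync after a pickle round-trip."""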
    bunch = Bunch(x="x")
    bunch_from_pkl = loads(dumps(bunch))
    bunch_from_pkl.x = "y"
    assert bunch_from_pkl["x"] == bunch_from_pkl.x


def test_bunch_pickle_generated_with_0_16_and_read_with_0_17():
    bunch = Bunch(key="original")
    # This reproduces a problem when Bunch pickles have been created
    # with scikit-learn 0.16 and are read with 0.17. Basically there
    # is a surprising behaviour because reading bunch.key uses
    # bunch.__dict__ (which is non-empty for 0.16 Bunch objects)
    # whereas assigning into bunch.key uses bunch.__setattr__. See
    # https://github.com/scikit-learn/scikit-learn/issues/6196 for
    # more details.
    bunch.__dict__["key"] = "set from __dict__"
    bunch_from_pkl = loads(dumps(bunch))
    # After loading from pickle the __dict__ should have been ignored
    assert bunch_from_pkl.key == "original"
    assert bunch_from_pkl["key"] == "original"
    # Making sure that changing the attr does change the value
    # associated with __getitem__ as well
    bunch_from_pkl.key = "changed"
    assert bunch_from_pkl.key == "changed"
    assert bunch_from_pkl["key"] == "changed"


def test_bunch_dir():
    # check that dir (important for autocomplete) shows attributes
    data = load_iris()
    assert "data" in dir(data)


# FIXME: to be removed in 1.2
def test_load_boston_warning():
    """Check that the ethical warning is raised when calling `load_boston`."""
    warn_msg = "The Boston housing prices dataset has an ethical problem"
    with pytest.warns(FutureWarning, match=warn_msg):
        load_boston()


@pytest.mark.filterwarnings("ignore:Function load_boston is deprecated")
def test_load_boston_alternative():
    pd = pytest.importorskip("pandas")
    if os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "1":
        raise SkipTest(
            "This test requires an internet connection to fetch the dataset."
        )

    boston_sklearn = load_boston()

    data_url = "http://lib.stat.cmu.edu/datasets/boston"
    try:
        raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
    except ConnectionError as e:
        pytest.xfail(f"The dataset can't be downloaded. Got exception: {e}")
    data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
    target = raw_df.values[1::2, 2]

    np.testing.assert_allclose(data, boston_sklearn.data)
    np.testing.assert_allclose(target, boston_sklearn.target)
