"""Tests for the toy-dataset loaders and data-home helpers in sklearn.datasets."""
import os
import shutil
import tempfile
import warnings
from pickle import loads
from pickle import dumps
from functools import partial
from importlib import resources

import pytest

import numpy as np
from sklearn.datasets import get_data_home
from sklearn.datasets import clear_data_home
from sklearn.datasets import load_files
from sklearn.datasets import load_sample_images
from sklearn.datasets import load_sample_image
from sklearn.datasets import load_digits
from sklearn.datasets import load_diabetes
from sklearn.datasets import load_linnerud
from sklearn.datasets import load_iris
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_boston
from sklearn.datasets import load_wine
from sklearn.datasets._base import (
    load_csv_data,
    load_gzip_compressed_csv_data,
)
from sklearn.utils import Bunch
from sklearn.utils._testing import SkipTest
from sklearn.datasets.tests.test_common import check_as_frame

from sklearn.externals._pilutil import pillow_installed

from sklearn.utils import IS_PYPY


def _remove_dir(path):
    """Delete *path* recursively if it is an existing directory (no-op otherwise)."""
    if os.path.isdir(path):
        shutil.rmtree(path)


@pytest.fixture(scope="module")
def data_home(tmpdir_factory):
    """Module-scoped temporary directory used as the scikit-learn data home."""
    tmp_file = str(tmpdir_factory.mktemp("scikit_learn_data_home_test"))
    yield tmp_file
    _remove_dir(tmp_file)


@pytest.fixture(scope="module")
def load_files_root(tmpdir_factory):
    """Module-scoped temporary root directory for the load_files tests."""
    tmp_file = str(tmpdir_factory.mktemp("scikit_learn_load_files_test"))
    yield tmp_file
    _remove_dir(tmp_file)


@pytest.fixture
def test_category_dir_1(load_files_root):
    """Category subdirectory containing one sample file with known content."""
    test_category_dir1 = tempfile.mkdtemp(dir=load_files_root)
    # delete=False so the file survives close() and is visible to load_files.
    sample_file = tempfile.NamedTemporaryFile(dir=test_category_dir1, delete=False)
    sample_file.write(b"Hello World!\n")
    sample_file.close()
    yield str(test_category_dir1)
    _remove_dir(test_category_dir1)


@pytest.fixture
def test_category_dir_2(load_files_root):
    """Empty category subdirectory (contributes a target name but no samples)."""
    test_category_dir2 = tempfile.mkdtemp(dir=load_files_root)
    yield str(test_category_dir2)
    _remove_dir(test_category_dir2)


def test_data_home(data_home):
    """get_data_home honors an existing folder and recreates it after clearing."""
    # get_data_home will point to a pre-existing folder
    result = get_data_home(data_home=data_home)
    # BUGFIX: the original asserted `data_home == data_home` (a tautology)
    # because the result was bound to the same name as the fixture argument.
    assert result == data_home
    assert os.path.exists(result)

    # clear_data_home will delete both the content and the folder itself
    clear_data_home(data_home=result)
    assert not os.path.exists(result)

    # if the folder is missing it will be created again
    result = get_data_home(data_home=data_home)
    assert os.path.exists(result)


def test_default_empty_load_files(load_files_root):
    """load_files on an empty root yields no filenames, targets, or DESCR."""
    res = load_files(load_files_root)
    assert len(res.filenames) == 0
    assert len(res.target_names) == 0
    assert res.DESCR is None


def test_default_load_files(test_category_dir_1, test_category_dir_2, load_files_root):
    """Default load_files finds the one sample and both category directories."""
    if IS_PYPY:
        pytest.xfail("[PyPy] fails due to string containing NUL characters")
    res = load_files(load_files_root)
    assert len(res.filenames) == 1
    assert len(res.target_names) == 2
    assert res.DESCR is None
    assert res.data == [b"Hello World!\n"]


def test_load_files_w_categories_desc_and_encoding(
    test_category_dir_1, test_category_dir_2, load_files_root
):
    """Restricting categories and setting encoding/description are honored."""
    if IS_PYPY:
        pytest.xfail("[PyPy] fails due to string containing NUL characters")
    # BUGFIX: the original used `.split("/").pop()`, which is not portable to
    # Windows path separators; os.path.basename is the equivalent, portable form.
    category = os.path.basename(os.path.abspath(test_category_dir_1))
    res = load_files(
        load_files_root, description="test", categories=category, encoding="utf-8"
    )
    assert len(res.filenames) == 1
    assert len(res.target_names) == 1
    assert res.DESCR == "test"
    # encoding="utf-8" decodes the raw bytes into str.
    assert res.data == ["Hello World!\n"]


def test_load_files_wo_load_content(
    test_category_dir_1, test_category_dir_2, load_files_root
):
    """With load_content=False the Bunch carries filenames but no data key."""
    res = load_files(load_files_root, load_content=False)
    assert len(res.filenames) == 1
    assert len(res.target_names) == 2
    assert res.DESCR is None
    assert res.get("data") is None


@pytest.mark.parametrize(
    "filename, expected_n_samples, expected_n_features, expected_target_names",
    [
        ("wine_data.csv", 178, 13, ["class_0", "class_1", "class_2"]),
        ("iris.csv", 150, 4, ["setosa", "versicolor", "virginica"]),
        ("breast_cancer.csv", 569, 30, ["malignant", "benign"]),
    ],
)
def test_load_csv_data(
    filename, expected_n_samples, expected_n_features, expected_target_names
):
    """load_csv_data returns arrays with the documented shapes and target names."""
    actual_data, actual_target, actual_target_names = load_csv_data(filename)
    assert actual_data.shape[0] == expected_n_samples
    assert actual_data.shape[1] == expected_n_features
    assert actual_target.shape[0] == expected_n_samples
    np.testing.assert_array_equal(actual_target_names, expected_target_names)


def test_load_csv_data_with_descr():
    """Passing descr_file_name appends the description as a fourth return value."""
    data_file_name = "iris.csv"
    descr_file_name = "iris.rst"

    res_without_descr = load_csv_data(data_file_name=data_file_name)
    res_with_descr = load_csv_data(
        data_file_name=data_file_name, descr_file_name=descr_file_name
    )
    assert len(res_with_descr) == 4
    assert len(res_without_descr) == 3

    # The data/target/target_names triple must be unaffected by the descr flag.
    np.testing.assert_array_equal(res_with_descr[0], res_without_descr[0])
    np.testing.assert_array_equal(res_with_descr[1], res_without_descr[1])
    np.testing.assert_array_equal(res_with_descr[2], res_without_descr[2])

    assert res_with_descr[-1].startswith(".. _iris_dataset:")


@pytest.mark.parametrize(
    "filename, kwargs, expected_shape",
    [
        ("diabetes_data.csv.gz", {}, [442, 10]),
        ("diabetes_target.csv.gz", {}, [442]),
        ("digits.csv.gz", {"delimiter": ","}, [1797, 65]),
    ],
)
def test_load_gzip_compressed_csv_data(filename, kwargs, expected_shape):
    """Gzip-compressed CSV loaders return arrays of the expected shape."""
    actual_data = load_gzip_compressed_csv_data(filename, **kwargs)
    assert actual_data.shape == tuple(expected_shape)


def test_load_gzip_compressed_csv_data_with_descr():
    """Requesting the descr does not change the loaded data and returns the text."""
    data_file_name = "diabetes_target.csv.gz"
    descr_file_name = "diabetes.rst"

    expected_data = load_gzip_compressed_csv_data(data_file_name=data_file_name)
    actual_data, descr = load_gzip_compressed_csv_data(
        data_file_name=data_file_name,
        descr_file_name=descr_file_name,
    )

    np.testing.assert_array_equal(actual_data, expected_data)
    assert descr.startswith(".. _diabetes_dataset:")


def test_load_sample_images():
    """load_sample_images returns both bundled images with known corner pixels."""
    try:
        res = load_sample_images()
        assert len(res.images) == 2
        assert len(res.filenames) == 2
        images = res.images

        # assert is china image
        assert np.all(images[0][0, 0, :] == np.array([174, 201, 231], dtype=np.uint8))
        # assert is flower image
        assert np.all(images[1][0, 0, :] == np.array([2, 19, 13], dtype=np.uint8))
        assert res.DESCR
    except ImportError:
        warnings.warn("Could not load sample images, PIL is not available.")


def test_load_sample_image():
    """load_sample_image returns the china image as a uint8 HxWx3 array."""
    try:
        china = load_sample_image("china.jpg")
        assert china.dtype == "uint8"
        assert china.shape == (427, 640, 3)
    except ImportError:
        warnings.warn("Could not load sample images, PIL is not available.")


def test_load_missing_sample_image_error():
    """Requesting an unknown sample image raises AttributeError."""
    if pillow_installed:
        with pytest.raises(AttributeError):
            load_sample_image("blop.jpg")
    else:
        warnings.warn("Could not load sample images, PIL is not available.")


@pytest.mark.filterwarnings("ignore:Function load_boston is deprecated")
@pytest.mark.parametrize(
    "loader_func, data_shape, target_shape, n_target, has_descr, filenames",
    [
        (load_breast_cancer, (569, 30), (569,), 2, True, ["filename"]),
        (load_wine, (178, 13), (178,), 3, True, []),
        (load_iris, (150, 4), (150,), 3, True, ["filename"]),
        (
            load_linnerud,
            (20, 3),
            (20, 3),
            3,
            True,
            ["data_filename", "target_filename"],
        ),
        (load_diabetes, (442, 10), (442,), None, True, []),
        (load_digits, (1797, 64), (1797,), 10, True, []),
        (partial(load_digits, n_class=9), (1617, 64), (1617,), 10, True, []),
        (load_boston, (506, 13), (506,), None, True, ["filename"]),
    ],
)
def test_loader(loader_func, data_shape, target_shape, n_target, has_descr, filenames):
    """Every toy-dataset loader returns a Bunch with the expected shapes/metadata."""
    bunch = loader_func()

    assert isinstance(bunch, Bunch)
    assert bunch.data.shape == data_shape
    assert bunch.target.shape == target_shape
    if hasattr(bunch, "feature_names"):
        assert len(bunch.feature_names) == data_shape[1]
    if n_target is not None:
        assert len(bunch.target_names) == n_target
    if has_descr:
        assert bunch.DESCR
    if filenames:
        # The filename entries must resolve to real resources in the data module.
        # NOTE(review): importlib.resources.is_resource is deprecated from
        # Python 3.11 in favor of files()/Traversable — revisit when the
        # minimum supported Python version moves.
        assert "data_module" in bunch
        assert all(
            [
                f in bunch and resources.is_resource(bunch["data_module"], bunch[f])
                for f in filenames
            ]
        )


@pytest.mark.parametrize(
    "loader_func, data_dtype, target_dtype",
    [
        (load_breast_cancer, np.float64, int),
        (load_diabetes, np.float64, np.float64),
        (load_digits, np.float64, int),
        (load_iris, np.float64, int),
        (load_linnerud, np.float64, np.float64),
        (load_wine, np.float64, int),
    ],
)
def test_toy_dataset_frame_dtype(loader_func, data_dtype, target_dtype):
    """as_frame=True produces DataFrames with the expected column dtypes."""
    default_result = loader_func()
    check_as_frame(
        default_result,
        loader_func,
        expected_data_dtype=data_dtype,
        expected_target_dtype=target_dtype,
    )


def test_loads_dumps_bunch():
    """A Bunch round-trips through pickle with attribute/item access in sync."""
    bunch = Bunch(x="x")
    bunch_from_pkl = loads(dumps(bunch))
    bunch_from_pkl.x = "y"
    assert bunch_from_pkl["x"] == bunch_from_pkl.x


def test_bunch_pickle_generated_with_0_16_and_read_with_0_17():
    """Stale __dict__ entries in old Bunch pickles must not shadow item access."""
    bunch = Bunch(key="original")
    # This reproduces a problem when Bunch pickles have been created
    # with scikit-learn 0.16 and are read with 0.17. Basically there
    # is a surprising behaviour because reading bunch.key uses
    # bunch.__dict__ (which is non empty for 0.16 Bunch objects)
    # whereas assigning into bunch.key uses bunch.__setattr__. See
    # https://github.com/scikit-learn/scikit-learn/issues/6196 for
    # more details
    bunch.__dict__["key"] = "set from __dict__"
    bunch_from_pkl = loads(dumps(bunch))
    # After loading from pickle the __dict__ should have been ignored
    assert bunch_from_pkl.key == "original"
    assert bunch_from_pkl["key"] == "original"
    # Making sure that changing the attr does change the value
    # associated with __getitem__ as well
    bunch_from_pkl.key = "changed"
    assert bunch_from_pkl.key == "changed"
    assert bunch_from_pkl["key"] == "changed"


def test_bunch_dir():
    # check that dir (important for autocomplete) shows attributes
    data = load_iris()
    assert "data" in dir(data)


# FIXME: to be removed in 1.2
def test_load_boston_warning():
    """Check that we raise the ethical warning when loading `load_boston`."""
    warn_msg = "The Boston housing prices dataset has an ethical problem"
    with pytest.warns(FutureWarning, match=warn_msg):
        load_boston()


@pytest.mark.filterwarnings("ignore:Function load_boston is deprecated")
def test_load_boston_alternative():
    """The documented pandas recipe reproduces load_boston's data and target."""
    pd = pytest.importorskip("pandas")
    if os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "1":
        raise SkipTest(
            "This test requires an internet connection to fetch the dataset."
        )

    boston_sklearn = load_boston()

    data_url = "http://lib.stat.cmu.edu/datasets/boston"
    try:
        raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
    except ConnectionError as e:
        pytest.xfail(f"The dataset can't be downloaded. Got exception: {e}")
    # Features are interleaved over pairs of physical rows; the target is the
    # third column of every second row.
    data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
    target = raw_df.values[1::2, 2]

    np.testing.assert_allclose(data, boston_sklearn.data)
    np.testing.assert_allclose(target, boston_sklearn.target)