"""Pytest hooks that skip documentation doctests whose prerequisites are missing.

Each ``setup_*`` function guards one ``.rst`` file: it raises ``SkipTest``
when a dataset cache, an optional dependency or an environment setting needed
by that file is not available. ``pytest_runtest_setup`` picks the guard to run
from the path of the collected item.
"""

import os
from os.path import exists
from os.path import join
from os import environ
import warnings

from sklearn.utils import IS_PYPY
from sklearn.utils._testing import SkipTest
from sklearn.utils._testing import check_skip_network
from sklearn.utils.fixes import parse_version
from sklearn.datasets import get_data_home
from sklearn.datasets._base import _pkl_filepath
from sklearn.datasets._twenty_newsgroups import CACHE_NAME


def _skip_if_pandas_not_installed(rst_name):
    """Return the imported ``pandas`` module or skip the doctests of a file.

    Parameters
    ----------
    rst_name : str
        Name of the documentation file, used verbatim in the skip message.

    Raises
    ------
    SkipTest
        If ``pandas`` cannot be imported.
    """
    try:
        import pandas
    except ImportError:
        raise SkipTest(f"Skipping {rst_name}, pandas not installed")
    return pandas


def setup_labeled_faces():
    """Skip unless an ``lfw_home`` entry exists in the scikit-learn data home."""
    data_home = get_data_home()
    if not exists(join(data_home, "lfw_home")):
        raise SkipTest("Skipping dataset loading doctests")


def setup_rcv1():
    """Skip unless an ``RCV1`` entry exists in the scikit-learn data home."""
    # NOTE(review): check_skip_network is defined elsewhere; presumably it
    # raises SkipTest when network tests are disabled -- confirm.
    check_skip_network()
    # skip the test in rcv1.rst if the dataset is not already loaded
    rcv1_dir = join(get_data_home(), "RCV1")
    if not exists(rcv1_dir):
        raise SkipTest("Download RCV1 dataset to run this test.")


def setup_twenty_newsgroups():
    """Skip unless the ``CACHE_NAME`` pickle path exists in the data home."""
    cache_path = _pkl_filepath(get_data_home(), CACHE_NAME)
    if not exists(cache_path):
        raise SkipTest("Skipping dataset loading doctests")


def setup_working_with_text_data():
    """Skip the text tutorial on PyPy CI runs or when its cache is missing."""
    if IS_PYPY and os.environ.get("CI", None):
        raise SkipTest("Skipping too slow test with PyPy on CI")
    check_skip_network()
    cache_path = _pkl_filepath(get_data_home(), CACHE_NAME)
    if not exists(cache_path):
        raise SkipTest("Skipping dataset loading doctests")


def setup_loading_other_datasets():
    """Skip unless pandas is installed and network tests are explicitly enabled."""
    _skip_if_pandas_not_installed("loading_other_datasets.rst")

    # checks SKLEARN_SKIP_NETWORK_TESTS to see if test should run; an unset
    # variable is treated as "1", so these doctests are skipped by default
    run_network_tests = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
    if not run_network_tests:
        raise SkipTest(
            "Skipping loading_other_datasets.rst, tests can be "
            "enabled by setting SKLEARN_SKIP_NETWORK_TESTS=0"
        )


def setup_compose():
    """Skip the doctests of compose.rst when pandas is not installed."""
    _skip_if_pandas_not_installed("compose.rst")


def setup_impute():
    """Skip the doctests of impute.rst when pandas is not installed."""
    _skip_if_pandas_not_installed("impute.rst")


def setup_grid_search():
    """Skip the doctests of grid_search.rst when pandas is not installed."""
    _skip_if_pandas_not_installed("grid_search.rst")


def setup_preprocessing():
    """Skip preprocessing.rst when pandas is missing or older than 1.1.0."""
    pandas = _skip_if_pandas_not_installed("preprocessing.rst")
    # The version check lives outside the import guard so that only the import
    # itself is covered by the ImportError handler.
    if parse_version(pandas.__version__) < parse_version("1.1.0"):
        raise SkipTest("Skipping preprocessing.rst, pandas version < 1.1.0")


def setup_unsupervised_learning():
    """Skip unless scikit-image is installed, then silence a known warning."""
    try:
        import skimage  # noqa
    except ImportError:
        raise SkipTest("Skipping unsupervised_learning.rst, scikit-image not installed")
    # ignore deprecation warnings from scipy.misc.face
    warnings.filterwarnings(
        "ignore", "The binary mode of fromstring", DeprecationWarning
    )


def skip_if_matplotlib_not_installed(fname):
    """Skip the doctests of ``fname`` when matplotlib cannot be imported.

    Parameters
    ----------
    fname : str
        Path of the documentation file; only its basename is reported.
    """
    try:
        import matplotlib  # noqa
    except ImportError:
        basename = os.path.basename(fname)
        raise SkipTest(f"Skipping doctests for {basename}, matplotlib not installed")


def pytest_runtest_setup(item):
    """Run the prerequisite check matching the file of the collected item.

    Parameters
    ----------
    item : pytest item
        Collected test item; its ``fspath.strpath`` is matched against known
        documentation file suffixes.

    Raises
    ------
    SkipTest
        Propagated from the selected setup function when a prerequisite of
        the file is missing.
    """
    fname = item.fspath.strpath
    # normalise filename to use forward slashes on Windows for easier handling
    # later
    fname = fname.replace(os.sep, "/")

    # At most one branch runs. datasets/index.rst is only paired with the
    # labeled-faces check: the previous chain repeated "or is_index" on the
    # following branches, but those operands could never be reached.
    # NOTE(review): if the index page was meant to run every dataset check,
    # that has to be implemented explicitly -- confirm the intent.
    if fname.endswith(("datasets/labeled_faces.rst", "datasets/index.rst")):
        setup_labeled_faces()
    elif fname.endswith("datasets/rcv1.rst"):
        setup_rcv1()
    elif fname.endswith("datasets/twenty_newsgroups.rst"):
        setup_twenty_newsgroups()
    elif fname.endswith("tutorial/text_analytics/working_with_text_data.rst"):
        setup_working_with_text_data()
    elif fname.endswith("modules/compose.rst"):
        setup_compose()
    elif IS_PYPY and fname.endswith("modules/feature_extraction.rst"):
        raise SkipTest("FeatureHasher is not compatible with PyPy")
    elif fname.endswith("datasets/loading_other_datasets.rst"):
        setup_loading_other_datasets()
    elif fname.endswith("modules/impute.rst"):
        setup_impute()
    elif fname.endswith("modules/grid_search.rst"):
        setup_grid_search()
    elif fname.endswith("modules/preprocessing.rst"):
        setup_preprocessing()
    elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
        setup_unsupervised_learning()

    # The matplotlib requirement is checked independently of the chain above.
    rst_files_requiring_matplotlib = (
        "modules/partial_dependence.rst",
        "modules/tree.rst",
        "tutorial/statistical_inference/settings.rst",
        "tutorial/statistical_inference/supervised_learning.rst",
    )
    if fname.endswith(rst_files_requiring_matplotlib):
        skip_if_matplotlib_not_installed(fname)


def pytest_configure(config):
    """Select the matplotlib ``agg`` backend when matplotlib is available."""
    # Use matplotlib agg backend during the tests including doctests
    try:
        import matplotlib

        matplotlib.use("agg")
    except ImportError:
        pass