1import os
2from os.path import exists
3from os.path import join
4from os import environ
5import warnings
6
7from sklearn.utils import IS_PYPY
8from sklearn.utils._testing import SkipTest
9from sklearn.utils._testing import check_skip_network
10from sklearn.utils.fixes import parse_version
11from sklearn.datasets import get_data_home
12from sklearn.datasets._base import _pkl_filepath
13from sklearn.datasets._twenty_newsgroups import CACHE_NAME
14
15
def setup_labeled_faces():
    """Skip the labeled-faces doctests when the LFW data is not on disk.

    Looks for an ``lfw_home`` directory inside scikit-learn's data home and
    raises ``SkipTest`` if it does not exist.
    """
    lfw_home = join(get_data_home(), "lfw_home")
    if exists(lfw_home):
        return
    raise SkipTest("Skipping dataset loading doctests")
20
21
def setup_rcv1():
    """Skip the rcv1.rst doctests unless the RCV1 data is already downloaded.

    ``check_skip_network`` is consulted first (presumably it skips when
    network tests are disabled); then the test is skipped if no ``RCV1``
    directory exists under scikit-learn's data home.
    """
    check_skip_network()
    # the rcv1.rst examples need the dataset to be present locally
    if exists(join(get_data_home(), "RCV1")):
        return
    raise SkipTest("Download RCV1 dataset to run this test.")
28
29
def setup_twenty_newsgroups():
    """Skip the 20 newsgroups doctests when its pickled cache is missing.

    The cache location is built from scikit-learn's data home and the
    ``CACHE_NAME`` used by the 20 newsgroups loader.
    """
    pickle_cache = _pkl_filepath(get_data_home(), CACHE_NAME)
    if exists(pickle_cache):
        return
    raise SkipTest("Skipping dataset loading doctests")
34
35
def setup_working_with_text_data():
    """Skip the text-analytics tutorial doctests when they cannot run.

    The tutorial is skipped when running under PyPy on CI (the ``CI``
    environment variable is set and non-empty), when ``check_skip_network``
    skips, or when the 20 newsgroups cache has not been created yet.
    """
    if IS_PYPY and os.environ.get("CI"):
        raise SkipTest("Skipping too slow test with PyPy on CI")
    check_skip_network()
    # The cache requirement is identical to the twenty_newsgroups.rst one
    # (same path, same skip message), so reuse that check.
    setup_twenty_newsgroups()
43
44
def setup_loading_other_datasets():
    """Skip loading_other_datasets.rst unless pandas and network access are ok.

    Network use is opt-in here: the doctests run only when the environment
    variable ``SKLEARN_SKIP_NETWORK_TESTS`` is exactly ``"0"``. An unset
    variable, or any other value, skips them.
    """
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping loading_other_datasets.rst, pandas not installed")

    # an unset variable falls back to "1", i.e. skip by default
    skip_network = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") != "0"
    if skip_network:
        raise SkipTest(
            "Skipping loading_other_datasets.rst, tests can be "
            "enabled by setting SKLEARN_SKIP_NETWORK_TESTS=0"
        )
58
59
def setup_compose():
    """Raise ``SkipTest`` for compose.rst when pandas cannot be imported."""
    reason = "Skipping compose.rst, pandas not installed"
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest(reason)
65
66
def setup_impute():
    """Raise ``SkipTest`` for impute.rst when pandas cannot be imported."""
    reason = "Skipping impute.rst, pandas not installed"
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest(reason)
72
73
def setup_grid_search():
    """Raise ``SkipTest`` for grid_search.rst when pandas cannot be imported."""
    reason = "Skipping grid_search.rst, pandas not installed"
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest(reason)
79
80
def setup_preprocessing():
    """Skip preprocessing.rst unless pandas 1.1.0 or newer is installed."""
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping preprocessing.rst, pandas not installed")

    # only the import belongs in the try; the version gate is ordinary logic
    too_old = parse_version(pandas.__version__) < parse_version("1.1.0")
    if too_old:
        raise SkipTest("Skipping preprocessing.rst, pandas version < 1.1.0")
89
90
def setup_unsupervised_learning():
    """Prepare the unsupervised_learning.rst doctests.

    Raises ``SkipTest`` when scikit-image is not importable. Otherwise it
    adds an "ignore" filter for ``DeprecationWarning`` messages starting with
    "The binary mode of fromstring".
    """
    try:
        import skimage  # noqa
    except ImportError:
        raise SkipTest("Skipping unsupervised_learning.rst, scikit-image not installed")
    # ignore deprecation warnings from scipy.misc.face
    # NOTE: warnings.filterwarnings edits the process-wide filter list and
    # nothing here removes the entry, so it stays active for later tests.
    warnings.filterwarnings(
        action="ignore",
        message="The binary mode of fromstring",
        category=DeprecationWarning,
    )
100
101
def skip_if_matplotlib_not_installed(fname):
    """Raise ``SkipTest`` for the doctests in ``fname`` if matplotlib is missing.

    Parameters
    ----------
    fname : str
        Path of the .rst file being tested; only its final path component
        is used in the skip message.
    """
    try:
        import matplotlib  # noqa
    except ImportError:
        raise SkipTest(
            "Skipping doctests for {}, matplotlib not installed".format(
                os.path.basename(fname)
            )
        )
108
109
def pytest_runtest_setup(item):
    """Skip the doctests of an .rst file when their prerequisites are missing.

    pytest calls this hook before running each test item. The item's source
    file selects which ``setup_*`` check applies, and each check raises
    ``SkipTest`` when its requirement is not met. A further check skips a
    fixed set of pages when matplotlib is missing.

    Parameters
    ----------
    item : pytest.Item
        The test item about to run; only its source path is inspected.
    """
    fname = item.fspath.strpath
    # normalise filename to use forward slashes on Windows so the suffix
    # checks below behave the same on every platform
    fname = fname.replace(os.sep, "/")

    if fname.endswith("datasets/index.rst"):
        # Every one of these checks is tagged as applying to the index page.
        # In a single if/elif chain only the first matching branch runs, so
        # they are called one after another here. The first unmet
        # prerequisite raises SkipTest.
        setup_labeled_faces()
        setup_rcv1()
        setup_twenty_newsgroups()
        setup_working_with_text_data()
        setup_compose()
    elif fname.endswith("datasets/labeled_faces.rst"):
        setup_labeled_faces()
    elif fname.endswith("datasets/rcv1.rst"):
        setup_rcv1()
    elif fname.endswith("datasets/twenty_newsgroups.rst"):
        setup_twenty_newsgroups()
    elif fname.endswith("tutorial/text_analytics/working_with_text_data.rst"):
        setup_working_with_text_data()
    elif fname.endswith("modules/compose.rst"):
        setup_compose()
    elif IS_PYPY and fname.endswith("modules/feature_extraction.rst"):
        raise SkipTest("FeatureHasher is not compatible with PyPy")
    elif fname.endswith("datasets/loading_other_datasets.rst"):
        setup_loading_other_datasets()
    elif fname.endswith("modules/impute.rst"):
        setup_impute()
    elif fname.endswith("modules/grid_search.rst"):
        setup_grid_search()
    elif fname.endswith("modules/preprocessing.rst"):
        setup_preprocessing()
    elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
        setup_unsupervised_learning()

    # str.endswith accepts a tuple, so one call covers every listed page
    rst_files_requiring_matplotlib = (
        "modules/partial_dependence.rst",
        "modules/tree.rst",
        "tutorial/statistical_inference/settings.rst",
        "tutorial/statistical_inference/supervised_learning.rst",
    )
    if fname.endswith(rst_files_requiring_matplotlib):
        skip_if_matplotlib_not_installed(fname)
151
152
def pytest_configure(config):
    """Switch matplotlib to the non-interactive "agg" backend for the test run.

    This applies to all tests, including doctests. If matplotlib is not
    installed, or ``matplotlib.use`` raises ``ImportError``, nothing happens.

    Parameters
    ----------
    config : pytest.Config
        Unused; required by the pytest hook signature.
    """
    try:
        import matplotlib

        matplotlib.use("agg")
    except ImportError:
        # matplotlib is an optional dependency of the docs tests
        return
161