1#! /usr/bin/env python
2#
3# Copyright (C) 2007-2009 Cournapeau David <cournape@gmail.com>
4#               2010 Fabian Pedregosa <fabian.pedregosa@inria.fr>
5# License: 3-clause BSD
6
7import sys
8import os
9import platform
10import shutil
11
# We need to import setuptools before distutils because it monkey-patches distutils
13import setuptools  # noqa
14from distutils.command.clean import clean as Clean
15from distutils.command.sdist import sdist
16
17import traceback
18import importlib
19
try:
    import builtins
except ImportError:
    # Python 2 compat: just to be able to declare that Python >=3.7 is needed.
    import __builtin__ as builtins

# This is a bit (!) hackish: we are setting a global variable so that the
# main sklearn __init__ can detect if it is being loaded by the setup
# routine, to avoid attempting to load components that aren't built yet:
# the numpy distutils extensions that are used by scikit-learn to
# recursively build the compiled extensions in sub-packages are based on the
# Python import machinery.
# NOTE: the flag has to be set before the ``import sklearn`` further down.
builtins.__SKLEARN_SETUP__ = True
33
34
# Static distribution metadata passed to setup() by setup_package().
DISTNAME = "scikit-learn"
DESCRIPTION = "A set of python modules for machine learning and data mining"
# Decode the README explicitly as UTF-8: without an ``encoding`` argument
# open() falls back to the locale's preferred encoding, which makes the long
# description depend on the machine the script runs on.
with open("README.rst", encoding="utf-8") as f:
    LONG_DESCRIPTION = f.read()
MAINTAINER = "Andreas Mueller"
MAINTAINER_EMAIL = "amueller@ais.uni-bonn.de"
URL = "http://scikit-learn.org"
DOWNLOAD_URL = "https://pypi.org/project/scikit-learn/#files"
LICENSE = "new BSD"
PROJECT_URLS = {
    "Bug Tracker": "https://github.com/scikit-learn/scikit-learn/issues",
    "Documentation": "https://scikit-learn.org/stable/documentation.html",
    "Source Code": "https://github.com/scikit-learn/scikit-learn",
}
49
# We can actually import a restricted version of sklearn that
# does not need the compiled code
import sklearn  # noqa
import sklearn._min_dependencies as min_deps  # noqa
from sklearn.externals._packaging.version import parse as parse_version  # noqa


# The distribution version is taken from the package itself.
VERSION = sklearn.__version__
58
59
# For some commands, use setuptools: the extra keyword arguments below are
# only handed to setup() when one of these commands appears on the command
# line.
SETUPTOOLS_COMMANDS = {
    "develop",
    "release",
    "bdist_egg",
    "bdist_rpm",
    "bdist_wininst",
    "install_egg_info",
    "build_sphinx",
    "egg_info",
    "easy_install",
    "upload",
    "bdist_wheel",
    "--single-version-externally-managed",
}
extra_setuptools_args = {}
if not SETUPTOOLS_COMMANDS.isdisjoint(sys.argv):
    extra_setuptools_args.update(
        # zip_safe=False: the package must not be run from a zipped .egg file
        zip_safe=False,
        include_package_data=True,
        # One optional-dependency group per tag declared in min_deps.
        extras_require={
            tag: min_deps.tag_to_packages[tag]
            for tag in ("examples", "docs", "tests", "benchmark")
        },
    )
86
87
88# Custom clean command to remove build artifacts
89
90
class CleanCommand(Clean):
    """``clean`` command that also purges build artifacts from the source tree.

    After the stock ``clean`` step has run, this removes the ``build``
    directory and, below ``sklearn``, every compiled extension module
    (``.so``, ``.pyd``, ``.dll``), every ``.pyc`` file and every
    ``__pycache__`` directory. A generated ``.c``/``.cpp`` file lying next to
    a ``.pyx`` file of the same base name is deleted as well, unless a
    ``PKG-INFO`` file exists next to this script (i.e. we are inside an sdist
    package).

    Both ``build`` and ``sklearn`` are resolved against the current working
    directory.
    """

    description = "Remove build artifacts from the source tree"

    def run(self):
        """Run the standard clean step, then delete the extra artifacts."""
        Clean.run(self)
        # Remove c files if we are not within a sdist package
        cwd = os.path.abspath(os.path.dirname(__file__))
        remove_c_files = not os.path.exists(os.path.join(cwd, "PKG-INFO"))
        if remove_c_files:
            print("Will remove generated .c files")
        if os.path.exists("build"):
            shutil.rmtree("build")
        for dirpath, dirnames, filenames in os.walk("sklearn"):
            for filename in filenames:
                if filename.endswith((".so", ".pyd", ".dll", ".pyc")):
                    os.unlink(os.path.join(dirpath, filename))
                    continue
                root, extension = os.path.splitext(filename)
                if remove_c_files and extension in (".c", ".cpp"):
                    # Swap only the trailing extension: a substring replace
                    # would also rewrite ".c" occurring earlier in the name
                    # (e.g. "my.cache.c") and miss the matching .pyx file.
                    pyx_file = root + ".pyx"
                    if os.path.exists(os.path.join(dirpath, pyx_file)):
                        os.unlink(os.path.join(dirpath, filename))
            if "__pycache__" in dirnames:
                shutil.rmtree(os.path.join(dirpath, "__pycache__"))
                # Prune the entry so os.walk does not try to descend into
                # the directory that was just deleted.
                dirnames.remove("__pycache__")
119
120
# Custom command classes handed to setup(); the build_ext and
# wheelhouse-uploader sections below may register further entries.
cmdclass = {"clean": CleanCommand, "sdist": sdist}
122
# Custom build_ext command to set OpenMP compile flags depending on os and
# compiler. Also makes it possible to set the parallelism level via
# an environment variable (useful for the wheel building CI).
# build_ext has to be imported after setuptools
try:
    from numpy.distutils.command.build_ext import build_ext  # noqa

    class build_ext_subclass(build_ext):
        """build_ext variant adding OpenMP flags and env-driven parallelism."""

        def finalize_options(self):
            """Use SKLEARN_BUILD_PARALLEL when no parallel level was given."""
            super().finalize_options()
            # A level already defined by a command-line flag (--parallel or
            # -j) is never overridden by the environment variable.
            requested = os.environ.get("SKLEARN_BUILD_PARALLEL")
            if self.parallel is None and requested:
                self.parallel = int(requested)
            if self.parallel:
                print("setting parallel=%d " % self.parallel)

        def build_extensions(self):
            """Append the OpenMP flag to every extension, then build them."""
            from sklearn._build_utils.openmp_helpers import get_openmp_flag

            if sklearn._OPENMP_SUPPORTED:
                openmp_flag = get_openmp_flag(self.compiler)

                # The same flag is needed at both compile and link time.
                for extension in self.extensions:
                    extension.extra_compile_args += openmp_flag
                    extension.extra_link_args += openmp_flag

            super().build_extensions()

    cmdclass["build_ext"] = build_ext_subclass

except ImportError:
    # Numpy should not be a dependency just to be able to introspect
    # that python 3.7 is required.
    pass
161
162
# Optional wheelhouse-uploader features
# To automate release of binary packages for scikit-learn we need a tool
# to download the packages generated by travis and appveyor workers (with
# version number matching the current release) and upload them all at once
# to PyPI at release time.
# The URLs of the artifact repositories are configured in the setup.cfg file.

WHEELHOUSE_UPLOADER_COMMANDS = {"fetch_artifacts", "upload_all"}
if not WHEELHOUSE_UPLOADER_COMMANDS.isdisjoint(sys.argv):
    # Imported only on demand, when one of its commands is requested.
    import wheelhouse_uploader.cmd

    # Register every name of the module's namespace as a command class.
    cmdclass.update(wheelhouse_uploader.cmd.__dict__)
175
176
def configuration(parent_package="", top_path=None):
    """Create the top-level numpy.distutils configuration of the project.

    Any ``MANIFEST`` file found in the current directory is deleted first.
    The Cython version is then checked before the ``sklearn`` sub-package is
    registered.

    Parameters
    ----------
    parent_package : str, default=""
        Forwarded unchanged to ``Configuration``.
    top_path : str or None, default=None
        Forwarded unchanged to ``Configuration``.

    Returns
    -------
    config : numpy.distutils.misc_util.Configuration
        The configuration with the ``sklearn`` sub-package added.
    """
    manifest = "MANIFEST"
    if os.path.exists(manifest):
        os.remove(manifest)

    from numpy.distutils.misc_util import Configuration
    from sklearn._build_utils import _check_cython_version

    # Avoid useless msg:
    # "Ignoring attempt to set 'name' (from ... "
    options = {
        "ignore_setup_xxx_py": True,
        "assume_default_configuration": True,
        "delegate_options_to_subpackages": True,
        "quiet": True,
    }
    config = Configuration(None, parent_package, top_path)
    config.set_options(**options)

    # Cython is required by config.add_subpackage for templated extensions
    # that need the tempita sub-submodule. So check that we have the correct
    # version of Cython so as to be able to raise a more informative error
    # message from the start if it's not the case.
    _check_cython_version()

    config.add_subpackage("sklearn")
    return config
204
205
def check_package_status(package, min_version):
    """Ensure ``package`` is importable and not older than ``min_version``.

    The module is imported and its ``__version__`` attribute is compared with
    ``min_version``. Nothing is returned when the requirement is met.

    Parameters
    ----------
    package : str
        Name of the module to import.
    min_version : str
        Oldest acceptable version of the package.

    Raises
    ------
    ImportError
        If the package cannot be imported (the underlying traceback is
        printed first) or if the installed version is older than
        ``min_version``.
    """
    try:
        installed_version = importlib.import_module(package).__version__
        up_to_date = parse_version(installed_version) >= parse_version(min_version)
    except ImportError:
        # Show why the import failed before reporting the package as missing.
        traceback.print_exc()
        installed_version = ""
        up_to_date = False

    if up_to_date:
        return

    requirement = "scikit-learn requires {} >= {}.\n".format(package, min_version)
    instructions = (
        "Installation instructions are available on the "
        "scikit-learn website: "
        "http://scikit-learn.org/stable/install.html\n"
    )

    # An empty version string means the import itself failed.
    if installed_version:
        raise ImportError(
            "Your installation of {} {} is out-of-date.\n{}{}".format(
                package, installed_version, requirement, instructions
            )
        )
    raise ImportError(
        "{} is not installed.\n{}{}".format(package, requirement, instructions)
    )
244
245
def setup_package():
    """Assemble the package metadata and hand it to the right ``setup()``.

    When every command on the command line is one of ``egg_info``,
    ``dist_info``, ``clean`` or ``check`` (this includes an invocation with
    no command at all), ``setuptools.setup`` is used so that the call works
    without NumPy. Otherwise the Python, NumPy and SciPy versions are checked
    first and ``numpy.distutils.core.setup`` is used, with ``configuration``
    passed along as the ``configuration`` entry of the metadata.

    Raises
    ------
    RuntimeError
        If a build command is requested on a Python older than 3.7.
    ImportError
        If NumPy or SciPy is missing or too old (raised by
        ``check_package_status``).
    """
    metadata = dict(
        name=DISTNAME,
        maintainer=MAINTAINER,
        maintainer_email=MAINTAINER_EMAIL,
        description=DESCRIPTION,
        license=LICENSE,
        url=URL,
        download_url=DOWNLOAD_URL,
        project_urls=PROJECT_URLS,
        version=VERSION,
        long_description=LONG_DESCRIPTION,
        classifiers=[
            "Intended Audience :: Science/Research",
            "Intended Audience :: Developers",
            "License :: OSI Approved",
            "Programming Language :: C",
            "Programming Language :: Python",
            "Topic :: Software Development",
            "Topic :: Scientific/Engineering",
            "Development Status :: 5 - Production/Stable",
            "Operating System :: Microsoft :: Windows",
            "Operating System :: POSIX",
            "Operating System :: Unix",
            "Operating System :: MacOS",
            "Programming Language :: Python :: 3",
            "Programming Language :: Python :: 3.7",
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: Implementation :: CPython",
            "Programming Language :: Python :: Implementation :: PyPy",
        ],
        cmdclass=cmdclass,
        python_requires=">=3.7",
        install_requires=min_deps.tag_to_packages["install"],
        package_data={"": ["*.pxd"]},
        **extra_setuptools_args,
    )

    # Everything on the command line that is not an option is a command.
    commands = [arg for arg in sys.argv[1:] if not arg.startswith("-")]
    if all(
        command in ("egg_info", "dist_info", "clean", "check") for command in commands
    ):
        # These actions are required to succeed without Numpy for example when
        # pip is used to install Scikit-learn when Numpy is not yet present in
        # the system.

        # These commands use setup from setuptools
        from setuptools import setup
    else:
        # Keep this bound in sync with ``python_requires`` and with the
        # error message below.
        if sys.version_info < (3, 7):
            raise RuntimeError(
                "Scikit-learn requires Python 3.7 or later. The current"
                " Python version is %s installed in %s."
                % (platform.python_version(), sys.executable)
            )

        check_package_status("numpy", min_deps.NUMPY_MIN_VERSION)

        check_package_status("scipy", min_deps.SCIPY_MIN_VERSION)

        # These commands require the setup from numpy.distutils because they
        # may use numpy.distutils compiler classes.
        from numpy.distutils.core import setup

        metadata["configuration"] = configuration

    setup(**metadata)
316
317
# Run the setup only when this file is executed as a script, not on import.
if __name__ == "__main__":
    setup_package()
320