1"""
2Monkey patching of distutils.
3"""
4
5import sys
6import distutils.filelist
7import platform
8import types
9import functools
10from importlib import import_module
11import inspect
12
13import setuptools
14
# Everything in this module is private. Contact the project team
# if you think you need this functionality.
__all__ = []
20
21
def _get_mro(cls):
    """
    Return cls together with its base classes, ordered by the MRO.

    On CPython and most other implementations this is exactly
    inspect.getmro(cls). Jython's inspect.getmro can drop base classes that
    share a name (https://github.com/pypa/setuptools/issues/1024). There,
    the result is instead cls followed by its direct bases (cls.__bases__).
    """
    if platform.python_implementation() != "Jython":
        return inspect.getmro(cls)
    # Jython workaround: only the class itself plus its immediate bases.
    return (cls,) + cls.__bases__
34
35
def get_unpatched(item):
    """
    Return the original (pre-setuptools) version of item.

    Classes are resolved with get_unpatched_class and plain Python functions
    with get_unpatched_function. Any other object yields None.
    """
    if isinstance(item, type):
        return get_unpatched_class(item)
    if isinstance(item, types.FunctionType):
        return get_unpatched_function(item)
    return None
43
44
def get_unpatched_class(cls):
    """Protect against re-patching the distutils if reloaded.

    Also ensures that no other distutils extension monkeypatched the distutils
    first.

    Walks the MRO of ``cls`` (via ``_get_mro``) and skips every class whose
    ``__module__`` starts with ``'setuptools'``. The first remaining class is
    returned.

    :param cls: a class, typically a setuptools subclass of a distutils class.
    :returns: the first non-setuptools class in the MRO of ``cls``.
    :raises AssertionError: if that class does not come from a module whose
        name starts with ``'distutils'``. The message names that foreign
        class.
    """
    # Use a distinct loop name so the parameter ``cls`` is not shadowed.
    external_bases = (
        candidate
        for candidate in _get_mro(cls)
        if not candidate.__module__.startswith('setuptools')
    )
    base = next(external_bases)
    if not base.__module__.startswith('distutils'):
        # Report ``base``: it is the foreign class sitting in front of
        # distutils, i.e. the one that patched it. ``cls`` is merely the
        # setuptools class we were asked about.
        msg = "distutils has already been patched by %r" % base
        raise AssertionError(msg)
    return base
61
62
def patch_all():
    """
    Install setuptools' replacements into the distutils modules.

    Replaces distutils.core.Command, the ``Distribution`` attribute of
    distutils.dist, distutils.core and distutils.cmd, and the ``Extension``
    attribute of distutils.core, distutils.extension and (only if it is
    already imported) distutils.command.build_ext. On the interpreter versions
    selected below, it also replaces distutils.filelist.findall and
    PyPIRCCommand.DEFAULT_REPOSITORY. It also patches DistributionMetadata
    methods via _patch_distribution_metadata() and calls
    patch_for_msvc_specialized_compiler(), which only patches on Windows.

    NOTE(review): distutils.core, .config, .dist, .cmd and .extension are not
    imported by name in this module. Presumably importing setuptools loads
    them; confirm before relying on this in isolation.
    """
    # we can't patch distutils.cmd, alas
    distutils.core.Command = setuptools.Command

    # NOTE(review): sys.version_info is a 5-tuple such as (3, 5, 3, 'final', 0),
    # which compares *greater* than (3, 5, 3). So this is False on 3.5.3
    # itself and effectively means "< 3.5.3". Confirm whether 3.5.3 was
    # meant to be included.
    has_issue_12885 = sys.version_info <= (3, 5, 3)

    if has_issue_12885:
        # fix findall bug in distutils (http://bugs.python.org/issue12885)
        distutils.filelist.findall = setuptools.findall

    # Interpreter versions whose default upload repository is replaced with
    # the URL below. The same 5-tuple caveat applies to the final
    # "<= (3, 5, 3)" bound.
    needs_warehouse = (
        sys.version_info < (2, 7, 13)
        or
        (3, 4) < sys.version_info < (3, 4, 6)
        or
        (3, 5) < sys.version_info <= (3, 5, 3)
    )

    if needs_warehouse:
        warehouse = 'https://upload.pypi.org/legacy/'
        distutils.config.PyPIRCCommand.DEFAULT_REPOSITORY = warehouse

    _patch_distribution_metadata()

    # Install Distribution throughout the distutils
    for module in distutils.dist, distutils.core, distutils.cmd:
        module.Distribution = setuptools.dist.Distribution

    # Install the patched Extension
    distutils.core.Extension = setuptools.extension.Extension
    distutils.extension.Extension = setuptools.extension.Extension
    # build_ext is not imported here. If something already imported it, its
    # module-level Extension name must be replaced as well.
    if 'distutils.command.build_ext' in sys.modules:
        sys.modules['distutils.command.build_ext'].Extension = (
            setuptools.extension.Extension
        )

    patch_for_msvc_specialized_compiler()
100
101
def _patch_distribution_metadata():
    """
    Install setuptools' metadata handlers on distutils' DistributionMetadata.

    The attributes write_pkg_file, read_pkg_file and get_metadata_version of
    distutils.dist.DistributionMetadata are overwritten with the same-named
    objects from setuptools.dist, so they meet higher metadata standards.
    """
    metadata_cls = distutils.dist.DistributionMetadata
    for name in ('write_pkg_file', 'read_pkg_file', 'get_metadata_version'):
        setattr(metadata_cls, name, getattr(setuptools.dist, name))
107
108
def patch_func(replacement, target_mod, func_name):
    """
    Replace target_mod.<func_name> with replacement.

    Before replacing, the current target_mod.<func_name> is recorded as
    ``replacement.unpatched``, unless replacement already carries an
    ``unpatched`` entry in its __dict__. In that case the existing entry is
    kept.

    Important - the original is looked up by name at call time. Combined
    with keeping an existing ``unpatched`` entry, this avoids recording an
    already patched function as the original.
    """
    previous = getattr(target_mod, func_name)

    namespace = vars(replacement)
    if 'unpatched' not in namespace:
        namespace['unpatched'] = previous

    setattr(target_mod, func_name, replacement)
124
125
def get_unpatched_function(candidate):
    """
    Return the original function recorded on candidate by patch_func.

    Raises AttributeError if candidate has no ``unpatched`` attribute.
    """
    return candidate.unpatched
128
129
def patch_for_msvc_specialized_compiler():
    """
    Patch functions in distutils to use standalone Microsoft Visual C++
    compilers.

    setuptools.msvc is always imported. Patching only happens on Windows.
    Each group of targets below is attempted independently. A group is
    abandoned (silently) as soon as its distutils module cannot be imported
    or lacks the named function. Both cases surface as ImportError.
    """
    # Imported here rather than at module level to avoid circular imports
    # on Python < 3.5.
    msvc = import_module('setuptools.msvc')

    if platform.system() != 'Windows':
        # The specialized compilers exist only on Microsoft Windows.
        return

    def resolve(mod_name, func_name):
        """
        Build the (replacement, module, name) arguments for patch_func.

        The replacement is the setuptools.msvc attribute named after
        func_name with leading underscores stripped. It is prefixed with
        'msvc9_' for msvc9 modules and 'msvc14_' otherwise.
        """
        if 'msvc9' in mod_name:
            prefix = 'msvc9_'
        else:
            prefix = 'msvc14_'
        replacement = getattr(msvc, prefix + func_name.lstrip('_'))
        target = import_module(mod_name)
        if not hasattr(target, func_name):
            raise ImportError(func_name)
        return replacement, target, func_name

    # distutils.msvc9compiler: Python 2.7 to 3.4
    msvc9 = functools.partial(resolve, 'distutils.msvc9compiler')
    # distutils._msvccompiler: Python 3.5+
    msvc14 = functools.partial(resolve, 'distutils._msvccompiler')

    patch_groups = (
        (msvc9, ('find_vcvarsall', 'query_vcvarsall')),
        (msvc14, ('_get_vc_env',)),
        # gen_lib_options is patched for Numpy
        (msvc14, ('gen_lib_options',)),
    )
    for params_for, func_names in patch_groups:
        try:
            for func_name in func_names:
                patch_func(*params_for(func_name))
        except ImportError:
            pass
178