1import os
2import operator
3import sys
4import contextlib
5import itertools
6import unittest
7from distutils.errors import DistutilsError, DistutilsOptionError
8from distutils import log
9from unittest import TestLoader
10
11from pkg_resources import (resource_listdir, resource_exists, normalize_path,
12                           working_set, _namespace_packages, evaluate_marker,
13                           add_activation_listener, require, EntryPoint)
14from setuptools import Command
15from .build_py import _unique_everseen
16
17
class ScanningLoader(TestLoader):
    """TestLoader that also scans packages and ``additional_tests`` hooks.

    Each module object is loaded at most once per loader instance.
    """

    def __init__(self):
        TestLoader.__init__(self)
        # Modules already handled by loadTestsFromModule; prevents the same
        # module's tests from being collected twice.
        self._visited = set()

    def loadTestsFromModule(self, module, pattern=None):
        """Return a suite of all tests cases contained in the given module

        If the module is a package, load tests from all the modules in it.
        If the module has an ``additional_tests`` function, call it and add
        the return value to the tests.

        ``pattern`` is accepted for signature compatibility with
        ``unittest.TestLoader`` but is ignored.  A module that was already
        visited yields an empty suite rather than ``None``: the result is
        placed into a ``TestSuite`` by callers, and ``TestSuite.addTest``
        raises ``TypeError`` for non-callable objects such as ``None``.
        """
        if module in self._visited:
            return self.suiteClass()
        self._visited.add(module)

        tests = [TestLoader.loadTestsFromModule(self, module)]

        if hasattr(module, "additional_tests"):
            tests.append(module.additional_tests())

        # Only packages have ``__path__``; recurse into their .py modules
        # and into sub-directories that contain an ``__init__.py``.
        if hasattr(module, '__path__'):
            for entry in resource_listdir(module.__name__, ''):
                if entry.endswith('.py') and entry != '__init__.py':
                    submodule = module.__name__ + '.' + entry[:-3]
                elif resource_exists(module.__name__, entry + '/__init__.py'):
                    submodule = module.__name__ + '.' + entry
                else:
                    continue
                tests.append(self.loadTestsFromName(submodule))

        if len(tests) != 1:
            return self.suiteClass(tests)
        return tests[0]  # don't create a nested suite for only one return
56
57
58# adapted from jaraco.classes.properties:NonDataProperty
class NonDataProperty:
    """Computed attribute implemented as a non-data descriptor.

    Only ``__get__`` is defined, so a value assigned directly on an
    instance (stored in its ``__dict__``) shadows the computed value.
    Looking the attribute up on the class returns the descriptor itself.
    """

    def __init__(self, fget):
        # Callable invoked with the instance to produce the attribute value.
        self.fget = fget

    def __get__(self, obj, objtype=None):
        return self if obj is None else self.fget(obj)
67
68
class test(Command):
    """Command to run unit tests after in-place build"""

    description = "run unit tests after in-place build (deprecated)"

    user_options = [
        ('test-module=', 'm', "Run 'test_suite' in specified module"),
        ('test-suite=', 's',
         "Run single test, case or suite (e.g. 'module.test_suite')"),
        ('test-runner=', 'r', "Test runner to use"),
    ]

    def initialize_options(self):
        """Reset all options to ``None``; finalize_options fills them in."""
        self.test_suite = None
        self.test_module = None
        self.test_loader = None
        self.test_runner = None

    def finalize_options(self):
        """Resolve the suite, loader and runner to use.

        Command-line values win; otherwise values are taken from the
        distribution's ``test_suite`` / ``test_loader`` / ``test_runner``
        attributes, with ScanningLoader as the default loader.

        Raises DistutilsOptionError if both a module and a suite are given.
        """
        if self.test_suite and self.test_module:
            msg = "You may specify a module or a suite, but not both"
            raise DistutilsOptionError(msg)

        if self.test_suite is None:
            if self.test_module is None:
                self.test_suite = self.distribution.test_suite
            else:
                self.test_suite = self.test_module + ".test_suite"

        if self.test_loader is None:
            self.test_loader = getattr(self.distribution, 'test_loader', None)
        if self.test_loader is None:
            self.test_loader = "setuptools.command.test:ScanningLoader"
        if self.test_runner is None:
            self.test_runner = getattr(self.distribution, 'test_runner', None)

    @NonDataProperty
    def test_args(self):
        """Arguments passed to unittest after the program name.

        A non-data descriptor, so it can be overridden per instance.
        """
        return list(self._test_args())

    def _test_args(self):
        # With no explicit suite, ask unittest to run test discovery.
        if not self.test_suite and sys.version_info >= (2, 7):
            yield 'discover'
        if self.verbose:
            yield '--verbose'
        if self.test_suite:
            yield self.test_suite

    def with_project_on_sys_path(self, func):
        """
        Backward compatibility for project_on_sys_path context.
        """
        with self.project_on_sys_path():
            func()

    @contextlib.contextmanager
    def project_on_sys_path(self, include_dists=()):
        """Build the project and make it importable for the duration.

        Runs ``egg_info`` and ``build_ext`` (plus ``build_py`` when the
        distribution sets ``use_2to3``), puts the project directory at the
        front of ``sys.path`` and ``PYTHONPATH``, and requires the project's
        own distribution.  ``sys.path``, ``sys.modules`` and the global
        ``working_set`` are restored on exit.

        ``include_dists`` is unused; it is kept (with an immutable default
        instead of a shared list) for backward compatibility.
        """
        with_2to3 = getattr(self.distribution, 'use_2to3', False)

        if with_2to3:
            # If we run 2to3 we can not do this inplace: build the sources
            # into build_lib instead.
            self.reinitialize_command('build_py', inplace=0)
            self.run_command('build_py')
            bpy_cmd = self.get_finalized_command("build_py")
            build_path = normalize_path(bpy_cmd.build_lib)

            # Ensure metadata is up-to-date, written next to the built sources
            self.reinitialize_command('egg_info', egg_base=build_path)
            self.run_command('egg_info')

            # Build extensions (not in-place)
            self.reinitialize_command('build_ext', inplace=0)
            self.run_command('build_ext')
        else:
            # Without 2to3 inplace works fine:
            self.run_command('egg_info')

            # Build extensions in-place
            self.reinitialize_command('build_ext', inplace=1)
            self.run_command('build_ext')

        ei_cmd = self.get_finalized_command("egg_info")

        # Snapshot interpreter state so it can be restored afterwards.
        old_path = sys.path[:]
        old_modules = sys.modules.copy()

        try:
            project_path = normalize_path(ei_cmd.egg_base)
            sys.path.insert(0, project_path)
            # Re-initialise the global working set (presumably re-scanning
            # the updated sys.path) so the project can be required below.
            working_set.__init__()
            add_activation_listener(lambda dist: dist.activate())
            require('%s==%s' % (ei_cmd.egg_name, ei_cmd.egg_version))
            with self.paths_on_pythonpath([project_path]):
                yield
        finally:
            sys.path[:] = old_path
            sys.modules.clear()
            sys.modules.update(old_modules)
            working_set.__init__()

    @staticmethod
    @contextlib.contextmanager
    def paths_on_pythonpath(paths):
        """
        Add the indicated paths to the head of the PYTHONPATH environment
        variable so that subprocesses will also see the packages at
        these paths.

        Do this in a context that restores the value on exit.
        """
        # Sentinel distinguishes "PYTHONPATH unset" from "set to ''".
        nothing = object()
        orig_pythonpath = os.environ.get('PYTHONPATH', nothing)
        current_pythonpath = os.environ.get('PYTHONPATH', '')
        try:
            prefix = os.pathsep.join(_unique_everseen(paths))
            to_join = filter(None, [prefix, current_pythonpath])
            new_path = os.pathsep.join(to_join)
            if new_path:
                os.environ['PYTHONPATH'] = new_path
            yield
        finally:
            if orig_pythonpath is nothing:
                os.environ.pop('PYTHONPATH', None)
            else:
                os.environ['PYTHONPATH'] = orig_pythonpath

    @staticmethod
    def install_dists(dist):
        """
        Install the requirements indicated by ``dist`` (install_requires,
        tests_require, and marker-conditional extras whose ``:marker`` key
        evaluates true) and return an iterable of the dists that were built.

        The returned iterable is a one-shot ``itertools.chain``.
        """
        ir_d = dist.fetch_build_eggs(dist.install_requires)
        tr_d = dist.fetch_build_eggs(dist.tests_require or [])
        er_d = dist.fetch_build_eggs(
            v for k, v in dist.extras_require.items()
            if k.startswith(':') and evaluate_marker(k[1:])
        )
        return itertools.chain(ir_d, tr_d, er_d)

    def run(self):
        """Install test requirements, then run the tests (unless dry-run)."""
        self.announce(
            "WARNING: Testing via this command is deprecated and will be "
            "removed in a future version. Users looking for a generic test "
            "entry point independent of test runner are encouraged to use "
            "tox.",
            log.WARN,
        )

        installed_dists = self.install_dists(self.distribution)

        cmd = ' '.join(self._argv)
        if self.dry_run:
            self.announce('skipping "%s" (dry run)' % cmd)
            return

        self.announce('running "%s"' % cmd)

        paths = map(operator.attrgetter('location'), installed_dists)
        with self.paths_on_pythonpath(paths):
            with self.project_on_sys_path():
                self.run_tests()

    def run_tests(self):
        """Run the configured tests through ``unittest.main``.

        Raises DistutilsError when the test result is not successful.
        """
        # Purge modules under test from sys.modules. The test loader will
        # re-import them from the build location. Required when 2to3 is used
        # with namespace packages.  Skipped when no suite is configured
        # (discover mode): there is no top-level module name to purge, and
        # ``None.split`` would raise.
        if getattr(self.distribution, 'use_2to3', False) and self.test_suite:
            module = self.test_suite.split('.')[0]
            if module in _namespace_packages:
                prefix = module + '.'
                # Collect first, then delete: mutating sys.modules while
                # iterating it would raise RuntimeError.
                del_modules = [
                    name for name in sys.modules
                    if name == module or name.startswith(prefix)
                ]
                for name in del_modules:
                    del sys.modules[name]

        program = unittest.main(
            None, None, self._argv,
            testLoader=self._resolve_as_ep(self.test_loader),
            testRunner=self._resolve_as_ep(self.test_runner),
            exit=False,
        )
        if not program.result.wasSuccessful():
            msg = 'Test failed: %s' % program.result
            self.announce(msg, log.ERROR)
            raise DistutilsError(msg)

    @property
    def _argv(self):
        # Full argv for unittest.main, with 'unittest' as the program name.
        return ['unittest'] + self.test_args

    @staticmethod
    def _resolve_as_ep(val):
        """
        Load the attribute named by ``val`` (``module:attr`` entry-point
        syntax) as if it were specified as an entry point, call it with no
        arguments, and return the result.  Return None if ``val`` is None.
        """
        if val is None:
            return None
        parsed = EntryPoint.parse("x=" + val)
        return parsed.resolve()()
275