1import os
2import operator
3import sys
4import contextlib
5import itertools
6import unittest
7from distutils.errors import DistutilsError, DistutilsOptionError
8from distutils import log
9from unittest import TestLoader
10
11from setuptools.extern import six
12from setuptools.extern.six.moves import map, filter
13
14from pkg_resources import (resource_listdir, resource_exists, normalize_path,
15                           working_set, _namespace_packages, evaluate_marker,
16                           add_activation_listener, require, EntryPoint)
17from setuptools import Command
18from .build_py import _unique_everseen
19
# Under Python 2, a module-level ``__metaclass__ = type`` makes classes in this
# module that declare no base (e.g. ``class NonDataProperty:``) new-style
# classes, so descriptors behave the same as on Python 3.
__metaclass__ = type
21
22
class ScanningLoader(TestLoader):
    """Test loader that recurses into packages and ``additional_tests`` hooks.

    Each module is loaded at most once per loader instance. Asking for a
    module that was already loaded yields an empty suite.
    """

    def __init__(self):
        TestLoader.__init__(self)
        # Modules already handled by this loader. This guards against loading
        # the same module twice or recursing into a package repeatedly.
        self._visited = set()

    def loadTestsFromModule(self, module, pattern=None):
        """Return a suite of all tests cases contained in the given module

        If the module is a package, load tests from all the modules in it.
        If the module has an ``additional_tests`` function, call it and add
        the return value to the tests.

        ``pattern`` is accepted for signature compatibility with
        ``unittest.TestLoader.loadTestsFromModule`` but is not used here.

        A module that has already been visited by this loader yields an
        empty suite.
        """
        if module in self._visited:
            # Returning None here would make callers that wrap the result in
            # ``suiteClass(...)`` raise TypeError, because None is not a test.
            # An empty suite is a harmless no-op.
            return self.suiteClass()
        self._visited.add(module)

        tests = [TestLoader.loadTestsFromModule(self, module)]

        if hasattr(module, "additional_tests"):
            tests.append(module.additional_tests())

        if hasattr(module, '__path__'):
            # Package: load each ``*.py`` module (except ``__init__.py``) and
            # each sub-directory that is itself a package.
            for name in resource_listdir(module.__name__, ''):
                if name.endswith('.py') and name != '__init__.py':
                    submodule = module.__name__ + '.' + name[:-3]
                elif resource_exists(module.__name__, name + '/__init__.py'):
                    submodule = module.__name__ + '.' + name
                else:
                    continue
                tests.append(self.loadTestsFromName(submodule))

        if len(tests) != 1:
            return self.suiteClass(tests)
        # Don't create a nested suite when there is only one result.
        return tests[0]
61
62
63# adapted from jaraco.classes.properties:NonDataProperty
class NonDataProperty:
    """Read-only computed attribute that instances may shadow.

    Only ``__get__`` is defined, so this is a non-data descriptor. An entry
    in an instance's ``__dict__`` with the same name takes precedence over it.
    Accessed on the class itself, it returns the descriptor object.
    """

    def __init__(self, fget):
        self.fget = fget

    def __get__(self, obj, objtype=None):
        return self if obj is None else self.fget(obj)
72
73
class test(Command):
    """Command to run unit tests after in-place build"""

    description = "run unit tests after in-place build (deprecated)"

    user_options = [
        ('test-module=', 'm', "Run 'test_suite' in specified module"),
        ('test-suite=', 's',
         "Run single test, case or suite (e.g. 'module.test_suite')"),
        ('test-runner=', 'r', "Test runner to use"),
    ]

    def initialize_options(self):
        """Reset every option to ``None`` before option parsing."""
        self.test_suite = None
        self.test_module = None
        self.test_loader = None
        self.test_runner = None

    def finalize_options(self):
        """Fill in option defaults, falling back to distribution attributes.

        Raises:
            DistutilsOptionError: if both a test module and a test suite
                were given.
        """
        if self.test_suite and self.test_module:
            msg = "You may specify a module or a suite, but not both"
            raise DistutilsOptionError(msg)

        if self.test_suite is None:
            if self.test_module is None:
                self.test_suite = self.distribution.test_suite
            else:
                self.test_suite = self.test_module + ".test_suite"

        if self.test_loader is None:
            self.test_loader = getattr(self.distribution, 'test_loader', None)
        if self.test_loader is None:
            self.test_loader = "setuptools.command.test:ScanningLoader"
        if self.test_runner is None:
            self.test_runner = getattr(self.distribution, 'test_runner', None)

    @NonDataProperty
    def test_args(self):
        # Non-data descriptor: an instance attribute named ``test_args``
        # overrides this computed list.
        return list(self._test_args())

    def _test_args(self):
        """Yield the unittest arguments that follow the program name."""
        # With no explicit suite, ask unittest to run test discovery.
        if not self.test_suite and sys.version_info >= (2, 7):
            yield 'discover'
        if self.verbose:
            yield '--verbose'
        if self.test_suite:
            yield self.test_suite

    def with_project_on_sys_path(self, func):
        """
        Backward compatibility for project_on_sys_path context.
        """
        with self.project_on_sys_path():
            func()

    @contextlib.contextmanager
    def project_on_sys_path(self, include_dists=()):
        """Build the project and make it importable for the ``with`` body.

        This runs ``egg_info`` and ``build_ext``. When 2to3 is in use on
        Python 3, it also runs ``build_py`` into the build directory. The
        egg base is then put first on ``sys.path`` and on ``PYTHONPATH``, and
        the project is activated through ``pkg_resources``. On exit,
        ``sys.path``, ``sys.modules`` and the working set are restored.

        ``include_dists`` is not used by this body. It is kept for interface
        compatibility. It now defaults to an immutable tuple instead of a
        shared mutable list.
        """
        with_2to3 = not six.PY2 and getattr(
            self.distribution, 'use_2to3', False)

        if with_2to3:
            # If we run 2to3 we can not do this inplace:

            # Ensure metadata is up-to-date
            self.reinitialize_command('build_py', inplace=0)
            self.run_command('build_py')
            bpy_cmd = self.get_finalized_command("build_py")
            build_path = normalize_path(bpy_cmd.build_lib)

            # Build extensions
            self.reinitialize_command('egg_info', egg_base=build_path)
            self.run_command('egg_info')

            self.reinitialize_command('build_ext', inplace=0)
            self.run_command('build_ext')
        else:
            # Without 2to3 inplace works fine:
            self.run_command('egg_info')

            # Build extensions in-place
            self.reinitialize_command('build_ext', inplace=1)
            self.run_command('build_ext')

        ei_cmd = self.get_finalized_command("egg_info")

        # Snapshot interpreter state so it can be restored exactly afterwards.
        old_path = sys.path[:]
        old_modules = sys.modules.copy()

        try:
            project_path = normalize_path(ei_cmd.egg_base)
            sys.path.insert(0, project_path)
            working_set.__init__()
            add_activation_listener(lambda dist: dist.activate())
            require('%s==%s' % (ei_cmd.egg_name, ei_cmd.egg_version))
            with self.paths_on_pythonpath([project_path]):
                yield
        finally:
            sys.path[:] = old_path
            sys.modules.clear()
            sys.modules.update(old_modules)
            working_set.__init__()

    @staticmethod
    @contextlib.contextmanager
    def paths_on_pythonpath(paths):
        """
        Add the indicated paths to the head of the PYTHONPATH environment
        variable so that subprocesses will also see the packages at
        these paths.

        Do this in a context that restores the value on exit.
        """
        # Sentinel to distinguish "PYTHONPATH unset" from "set to ''".
        nothing = object()
        orig_pythonpath = os.environ.get('PYTHONPATH', nothing)
        current_pythonpath = os.environ.get('PYTHONPATH', '')
        try:
            prefix = os.pathsep.join(_unique_everseen(paths))
            # Drop empty components so no stray separator is produced.
            to_join = filter(None, [prefix, current_pythonpath])
            new_path = os.pathsep.join(to_join)
            if new_path:
                os.environ['PYTHONPATH'] = new_path
            yield
        finally:
            if orig_pythonpath is nothing:
                os.environ.pop('PYTHONPATH', None)
            else:
                os.environ['PYTHONPATH'] = orig_pythonpath

    @staticmethod
    def install_dists(dist):
        """
        Fetch the requirements indicated by ``dist`` and return an iterable
        of the dists that were built.

        Covered requirements: ``install_requires``, ``tests_require``, and
        any ``extras_require`` entry whose key is ``:<marker>`` where the
        environment marker evaluates true.
        """
        ir_d = dist.fetch_build_eggs(dist.install_requires)
        tr_d = dist.fetch_build_eggs(dist.tests_require or [])
        er_d = dist.fetch_build_eggs(
            v for k, v in dist.extras_require.items()
            if k.startswith(':') and evaluate_marker(k[1:])
        )
        return itertools.chain(ir_d, tr_d, er_d)

    def run(self):
        """Install test requirements, then run the test suite.

        The project and the fetched dists are made importable while the
        tests run. Nothing is run on a dry run.
        """
        self.announce(
            "WARNING: Testing via this command is deprecated and will be "
            "removed in a future version. Users looking for a generic test "
            "entry point independent of test runner are encouraged to use "
            "tox.",
            log.WARN,
        )

        installed_dists = self.install_dists(self.distribution)

        cmd = ' '.join(self._argv)
        if self.dry_run:
            self.announce('skipping "%s" (dry run)' % cmd)
            return

        self.announce('running "%s"' % cmd)

        paths = map(operator.attrgetter('location'), installed_dists)
        with self.paths_on_pythonpath(paths):
            with self.project_on_sys_path():
                self.run_tests()

    def run_tests(self):
        """Run the configured tests through ``unittest.main``.

        Raises:
            DistutilsError: if the test result is not successful.
        """
        # Purge modules under test from sys.modules. The test loader will
        # re-import them from the build location. Required when 2to3 is used
        # with namespace packages. Skipped when no suite is configured
        # (discover mode): there is no module name to purge, and ``None``
        # has no ``split``.
        use_2to3 = not six.PY2 and getattr(
            self.distribution, 'use_2to3', False)
        if use_2to3 and self.test_suite:
            module = self.test_suite.split('.')[0]
            if module in _namespace_packages:
                prefix = module + '.'
                del_modules = [
                    name for name in list(sys.modules)
                    if name == module or name.startswith(prefix)
                ]
                for name in del_modules:
                    del sys.modules[name]

        program = unittest.main(
            None, None, self._argv,
            testLoader=self._resolve_as_ep(self.test_loader),
            testRunner=self._resolve_as_ep(self.test_runner),
            exit=False,
        )
        if not program.result.wasSuccessful():
            msg = 'Test failed: %s' % program.result
            self.announce(msg, log.ERROR)
            raise DistutilsError(msg)

    @property
    def _argv(self):
        # Full argv for unittest.main, including the program name.
        return ['unittest'] + self.test_args

    @staticmethod
    def _resolve_as_ep(val):
        """
        Resolve ``val`` as if it were an entry-point spec (``module:attr``),
        call the resolved object with no arguments, and return the result.

        Return ``None`` when ``val`` is ``None``.
        """
        if val is None:
            return None
        parsed = EntryPoint.parse("x=" + val)
        return parsed.resolve()()
281