1"""Loading unittests."""
2
3import os
4import re
5import sys
6import traceback
7import types
8import functools
9import warnings
10
11from fnmatch import fnmatch, fnmatchcase
12
13from . import case, suite, util
14
# NOTE(review): presumably a marker that unittest looks for in frame globals
# so that frames from this module are hidden in failure tracebacks -- confirm
# against unittest.result.
__unittest = True

# File names eligible for discovery: an identifier-like stem (letter or
# underscore first, then word characters) followed by ".py", matched
# case-insensitively. Only ".py" is accepted; also allowing ".pyc" (etc.)
# would require avoiding loading the same tests twice, once from "x.py" and
# again from "x.pyc".
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
21
22
class _FailedTest(case.TestCase):
    """Placeholder test case that re-raises a stored exception when run.

    The loader uses it to report import, attribute-access and load_tests
    failures as failing tests: an instance answers to exactly one test
    method name, and calling that method raises ``exception``.
    """
    # Class-level default so that ``self._testMethodName`` inside __getattr__
    # resolves without recursing back into __getattr__ when the instance
    # attribute has not been assigned yet.
    _testMethodName = None

    def __init__(self, method_name, exception):
        # Stored before delegating to TestCase.__init__, which may already
        # look up method_name (and thereby reach __getattr__ below).
        self._exception = exception
        super(_FailedTest, self).__init__(method_name)

    def __getattr__(self, name):
        """Synthesise the failing test method; reject every other name.

        __getattr__ only runs after normal attribute lookup has failed, so
        the test method named at construction time is produced here as a
        zero-argument callable that raises the stored exception.

        Raises:
            AttributeError: for any name other than the test method name.
        """
        if name != self._testMethodName:
            # Neither object nor TestCase defines __getattr__, so delegating
            # to super() only produced a misleading
            # "'super' object has no attribute '__getattr__'" error. Raise
            # the conventional message naming the missing attribute instead.
            raise AttributeError("%r object has no attribute %r"
                                 % (type(self).__name__, name))

        def testFailure():
            raise self._exception
        return testFailure
36
37
def _make_failed_import_test(name, suiteClass):
    """Build a failing-test suite for a module that could not be imported.

    Must be called while the import exception is being handled: the report
    embeds ``traceback.format_exc()``.  The failing test is named ``name``
    and raises an ImportError carrying the same report text.

    Returns:
        A ``(suite, message)`` pair.
    """
    details = traceback.format_exc()
    report = 'Failed to import test module: %s\n%s' % (name, details)
    failure = ImportError(report)
    return _make_failed_test(name, failure, suiteClass, report)
42
def _make_failed_load_tests(name, exception, suiteClass):
    """Build a failing-test suite for a module whose load_tests raised.

    Must be called while ``exception`` is being handled, because the report
    embeds ``traceback.format_exc()``.  The failing test is named ``name``
    and re-raises ``exception`` when run.

    Returns:
        A ``(suite, message)`` pair.
    """
    report = 'Failed to call load_tests:\n' + traceback.format_exc()
    return _make_failed_test(name, exception, suiteClass, report)
47
def _make_failed_test(methodname, exception, suiteClass, message):
    """Wrap a single _FailedTest in a suite and pair it with its message.

    Args:
        methodname: name of the synthetic test method.
        exception: exception the test raises when run.
        suiteClass: callable used to build the suite from a 1-tuple of tests.
        message: human-readable report returned unchanged.

    Returns:
        A ``(suite, message)`` pair.
    """
    failing_case = _FailedTest(methodname, exception)
    wrapped = suiteClass((failing_case,))
    return wrapped, message
51
def _make_skipped_test(methodname, exception, suiteClass):
    """Return a suite containing one test that is reported as skipped.

    A throwaway TestCase subclass named ``ModuleSkipped`` is created with a
    single method called ``methodname``.  That method is marked skipped with
    ``str(exception)`` as the reason, and one instance of it is wrapped with
    ``suiteClass``.
    """
    reason = str(exception)

    def testSkipped(self):
        pass

    skipped_method = case.skip(reason)(testSkipped)
    ModuleSkipped = type("ModuleSkipped", (case.TestCase,),
                         {methodname: skipped_method})
    return suiteClass((ModuleSkipped(methodname),))
59
def _jython_aware_splitext(path):
    """Return ``path`` without its extension.

    Jython compiled modules end in ``$py.class`` (matched case-insensitively),
    which os.path.splitext would only partly strip, so that suffix is
    removed as a whole.  Any other path goes through os.path.splitext.
    """
    jython_suffix = '$py.class'
    if not path.lower().endswith(jython_suffix):
        return os.path.splitext(path)[0]
    return path[:-len(jython_suffix)]
64
65
class TestLoader(object):
    """Load tests according to various criteria and wrap them in a TestSuite.

    Configuration lives in attributes that may be overridden per class or
    per instance:

    * ``testMethodPrefix`` -- attribute names must start with this prefix to
      be treated as test methods.
    * ``sortTestMethodsUsing`` -- three-way comparison function used to order
      test method names; a falsy value disables sorting.
    * ``testNamePatterns`` -- optional iterable of fnmatch-style patterns;
      when set, a test method is kept only if its fully qualified name
      (``module.QualName.method``) matches at least one pattern.
    * ``suiteClass`` -- callable used to build every returned suite.

    Most import and load failures are reported rather than raised: a message
    is appended to ``self.errors`` and a suite containing a failing
    placeholder test is returned in place of the real tests.
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    testNamePatterns = None
    suiteClass = suite.TestSuite
    # Set by discover(); reused by nested discover() calls (e.g. from a
    # package's load_tests) so they need not pass top_level_dir again.
    _top_level_dir = None

    def __init__(self):
        super(TestLoader, self).__init__()
        # Messages describing import/load failures encountered so far.
        self.errors = []
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass.

        One instance of testCaseClass is created per selected test method
        name. A class with no matching methods but a ``runTest`` attribute
        yields a single ``runTest`` instance.

        Raises:
            TypeError: if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    # XXX After Python 3.5, remove backward compatibility hacks for
    # use_load_tests deprecation via *args and **kws.  See issue 16662.
    def loadTestsFromModule(self, module, *args, pattern=None, **kws):
        """Return a suite of all test cases contained in the given module.

        Every TestCase subclass among the module's attributes is loaded with
        loadTestsFromTestCase(). If the module defines ``load_tests``, it is
        called as ``load_tests(self, tests, pattern)`` and its result is
        returned instead. An exception raised by load_tests is recorded in
        ``self.errors`` and reported as a failing test.
        """
        # This method used to take an undocumented and unofficial
        # use_load_tests argument.  For backward compatibility, we still
        # accept the argument (which can also be the first position) but we
        # ignore it and issue a deprecation warning if it's present.
        if len(args) > 0 or 'use_load_tests' in kws:
            warnings.warn('use_load_tests is deprecated and ignored',
                          DeprecationWarning)
            kws.pop('use_load_tests', None)
        if len(args) > 1:
            # Complain about the number of arguments, but don't forget the
            # required `module` argument.
            complaint = len(args) + 1
            raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
        if len(kws) != 0:
            # Since the keyword arguments are unsorted (see PEP 468), just
            # pick the alphabetically sorted first argument to complain about,
            # if multiple were given.  At least the error message will be
            # predictable.
            complaint = sorted(kws)[0]
            raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises:
            TypeError: if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        error_case, error_message = None, None
        if module is None:
            # Import the longest importable dotted prefix of the name; the
            # remaining components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
                    break
                except ImportError:
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
                    if not parts_copy:
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
            # __import__ returns the top-level package, so attribute traversal
            # starts again from the second component.
            parts = parts[1:]
        obj = module
        for part in parts:
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # obj has a __path__, so it is a package (per the
                    # importlib docs only packages define __path__), and we
                    # encountered an error importing something. We cannot tell
                    # the difference between package.WrongNameTestClass and
                    # package.wrong_module_name so we just report the
                    # ImportError - it is more informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
                        part, e, self.suiteClass,
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name is selected when it starts with ``testMethodPrefix``, refers to
        a callable, and (if ``testNamePatterns`` is set) its fully qualified
        ``module.QualName.method`` name matches one of the patterns
        case-sensitively. Sorting uses ``sortTestMethodsUsing`` when truthy.
        """
        def shouldIncludeMethod(attrname):
            if not attrname.startswith(self.testMethodPrefix):
                return False
            testFunc = getattr(testCaseClass, attrname)
            if not callable(testFunc):
                return False
            fullName = '%s.%s.%s' % (
                testCaseClass.__module__, testCaseClass.__qualname__, attrname
            )
            return self.testNamePatterns is None or \
                any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns)
        testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with (loader, tests, pattern) unless
        the package has already had load_tests called from the same discovery
        invocation, in which case the package module object is not scanned for
        tests - this ensures that when a package uses discover to further
        discover child tests that infinite recursion does not happen.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.

        Raises:
            ImportError: if the start directory is not importable.
            TypeError: if start_dir names a module that cannot be discovered
                from (for example a builtin module).
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Remember whether *we* added the entry, so that the clean-up below
        # never removes a sys.path entry the caller already had.
        inserted_top_level_dir = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
            inserted_top_level_dir = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        is_namespace = False
        tests = []
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                try:
                    start_dir = os.path.abspath(
                       os.path.dirname((the_module.__file__)))
                except AttributeError:
                    # look for namespace packages
                    try:
                        spec = the_module.__spec__
                    except AttributeError:
                        spec = None

                    if spec and spec.loader is None:
                        if spec.submodule_search_locations is not None:
                            is_namespace = True

                            for path in the_module.__path__:
                                if (not set_implicit_top and
                                    not path.startswith(top_level_dir)):
                                    continue
                                self._top_level_dir = \
                                    (path.split(the_module.__name__
                                         .replace(".", os.path.sep))[0])
                                tests.extend(self._find_tests(path,
                                                              pattern,
                                                              namespace=True))
                    elif the_module.__name__ in sys.builtin_module_names:
                        # builtin module
                        raise TypeError('Can not use builtin modules '
                                        'as dotted module names') from None
                    else:
                        raise TypeError(
                            'don\'t know how to discover from {!r}'
                            .format(the_module)) from None

                if set_implicit_top:
                    if not is_namespace:
                        self._top_level_dir = \
                           self._get_directory_containing_module(top_part)
                    # Here top_level_dir was derived from a dotted module name
                    # and is not an existing directory; drop it again, but
                    # only if this call was the one that inserted it.
                    if inserted_top_level_dir:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        if not is_namespace:
            tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which the already-imported module_name
        can be imported: the parent of a package's directory, or the
        directory holding a plain module file.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a path under the top-level directory to a dotted module
        name; the top-level directory itself maps to '.'.
        """
        if path == self._top_level_dir:
            return '.'
        path = _jython_aware_splitext(os.path.normpath(path))

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return the module object."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern, namespace=False):
        """Used by discovery. Yields test suites it loads.

        The package at start_dir (if it is one) is handled first; its
        directory entries are then visited in sorted order, recursing into
        sub-packages that did not take over loading via load_tests.
        """
        # Handle the __init__ in this package
        name = self._get_name_from_path(start_dir)
        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
        # definition not a package).
        if name != '.' and name not in self._loading_packages:
            # name is in self._loading_packages while we have called into
            # loadTestsFromModule with name.
            tests, should_recurse = self._find_test_path(
                start_dir, pattern, namespace)
            if tests is not None:
                yield tests
            if not should_recurse:
                # Either an error occurred, or load_tests was used by the
                # package.
                return
        # Handle the contents.
        paths = sorted(os.listdir(start_dir))
        for path in paths:
            full_path = os.path.join(start_dir, path)
            tests, should_recurse = self._find_test_path(
                full_path, pattern, namespace)
            if tests is not None:
                yield tests
            if should_recurse:
                # we found a package that didn't use load_tests.
                name = self._get_name_from_path(full_path)
                self._loading_packages.add(name)
                try:
                    yield from self._find_tests(full_path, pattern, namespace)
                finally:
                    self._loading_packages.discard(name)

    def _find_test_path(self, full_path, pattern, namespace=False):
        """Used by discovery.

        Loads tests from a single file, or a directories' __init__.py when
        passed the directory.

        Returns a tuple (None_or_tests_from_file, should_recurse).

        Raises:
            ImportError: if a matching file is imported from a location other
                than full_path (e.g. a globally installed copy).
        """
        basename = os.path.basename(full_path)
        if os.path.isfile(full_path):
            if not VALID_MODULE_NAME.match(basename):
                # valid Python identifiers only
                return None, False
            if not self._match_path(basename, full_path, pattern):
                return None, False
            # if the test file matches, load it
            name = self._get_name_from_path(full_path)
            try:
                module = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Deliberately broad: any failure while importing a test
                # module is reported as a failing test instead of aborting.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                mod_file = os.path.abspath(
                    getattr(module, '__file__', full_path))
                realpath = _jython_aware_splitext(
                    os.path.realpath(mod_file))
                fullpath_noext = _jython_aware_splitext(
                    os.path.realpath(full_path))
                if realpath.lower() != fullpath_noext.lower():
                    module_dir = os.path.dirname(realpath)
                    mod_name = _jython_aware_splitext(
                        os.path.basename(full_path))
                    expected_dir = os.path.dirname(full_path)
                    msg = ("%r module incorrectly imported from %r. Expected "
                           "%r. Is this module globally installed?")
                    raise ImportError(
                        msg % (mod_name, module_dir, expected_dir))
                return self.loadTestsFromModule(module, pattern=pattern), False
        elif os.path.isdir(full_path):
            if (not namespace and
                not os.path.isfile(os.path.join(full_path, '__init__.py'))):
                return None, False

            load_tests = None
            tests = None
            name = self._get_name_from_path(full_path)
            try:
                package = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Deliberately broad, as for module files above.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                load_tests = getattr(package, 'load_tests', None)
                # Mark this package as being in load_tests (possibly ;))
                self._loading_packages.add(name)
                try:
                    tests = self.loadTestsFromModule(package, pattern=pattern)
                    if load_tests is not None:
                        # loadTestsFromModule(package) has loaded tests for us.
                        return tests, False
                    return tests, True
                finally:
                    self._loading_packages.discard(name)
        else:
            return None, False
492
493
# Module-level TestLoader instance, created once at import time.
defaultTestLoader = TestLoader()
495
496
def _makeLoader(prefix, sortUsing, suiteClass=None, testNamePatterns=None):
    """Return a fresh TestLoader configured from the given options.

    ``suiteClass`` replaces the loader's default only when truthy; the other
    options are always assigned.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    configured.testNamePatterns = testNamePatterns
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
505
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNamePatterns=None):
    """Return the test method names of testCaseClass, selected by prefix
    and optional name patterns and ordered with sortUsing."""
    loader = _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns)
    return loader.getTestCaseNames(testCaseClass)
508
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of all tests in testCaseClass, loaded by a throwaway
    loader configured with the given prefix, sort function and suite class."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
513
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of all tests in module, loaded by a throwaway loader
    configured with the given prefix, sort function and suite class."""
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)
518