1"""Loading unittests."""
2
3import os
4import re
5import sys
6import traceback
7import types
8
9from functools import cmp_to_key as _CmpToKey
10from fnmatch import fnmatch
11
12from . import case, suite
13
# Module-level marker flag. NOTE(review): presumably inspected by unittest's
# result machinery to hide this module's frames from reported tracebacks —
# confirm against the result/runner code.
__unittest = True

# File-name filter applied by discovery (see TestLoader._find_tests): accepts
# a Python identifier (letter or underscore followed by word characters) with
# a '.py' extension, matched case-insensitively. Compiled bytecode files
# ('.pyc', '.pyo', ...) are deliberately not accepted: supporting them would
# require de-duplicating the same test module found under several extensions.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
20
21
def _make_failed_import_test(name, suiteClass):
    """Turn the exception currently being handled into a failing test.

    Intended to be called from inside an ``except`` block: the active
    traceback is formatted into an ImportError whose message names the
    module ``name``. Returns a suite (built with ``suiteClass``) holding a
    single 'ModuleImportFailure' test that raises that ImportError.
    """
    formatted_tb = traceback.format_exc()
    error = ImportError('Failed to import test module: %s\n%s'
                        % (name, formatted_tb))
    return _make_failed_test('ModuleImportFailure', name, error, suiteClass)
26
def _make_failed_load_tests(name, exception, suiteClass):
    """Return a suite whose single 'LoadTestsFailure' test re-raises exception.

    Used when a module's or package's ``load_tests`` hook raises; ``name``
    becomes the name of the synthetic test method.
    """
    return _make_failed_test(classname='LoadTestsFailure',
                             methodname=name,
                             exception=exception,
                             suiteClass=suiteClass)
29
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Return a suite holding one synthetic test that raises ``exception``.

    A TestCase subclass named ``classname`` is created on the fly with a
    single method called ``methodname``. Running that test raises the stored
    exception, so a loading problem surfaces as a test error when the suite
    is run.
    """
    def testFailure(self):
        raise exception

    namespace = {methodname: testFailure}
    FailingCase = type(classname, (case.TestCase,), namespace)
    only_test = FailingCase(methodname)
    return suiteClass((only_test,))
36
37
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite.

    Class attributes (override on an instance or subclass to customise):
      testMethodPrefix     -- attribute-name prefix that marks test methods.
      sortTestMethodsUsing -- cmp-style function used to order test method
                              names; a false value disables sorting.
      suiteClass           -- callable used to wrap collections of tests.
      _top_level_dir       -- project root recorded by discover() so that
                              nested discover() calls (e.g. from load_tests)
                              may omit it.
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = cmp
    suiteClass = suite.TestSuite
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass.

        One instance of testCaseClass is created per test method name. If no
        method matches testMethodPrefix but the class defines ``runTest``, a
        single 'runTest' instance is created instead.

        Raises TypeError if testCaseClass is a TestSuite subclass.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all test cases contained in the given module.

        Every TestCase subclass found among the module's attributes is
        loaded. If use_load_tests is true and the module defines a
        ``load_tests`` function, it is called as
        ``load_tests(self, tests, None)`` and its result is returned instead;
        an exception raised by it becomes a single failing
        'LoadTestsFailure' test.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises ImportError if no dotted prefix of name is importable (when
        module is None), AttributeError if a component cannot be resolved,
        and TypeError if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of name; the
            # remaining components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__('a.b') returns the top-level package 'a', so the
            # attribute walk starts after the first component.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            # A single test method: instantiate its class for that method only.
            name = parts[-1]
            inst = parent(name)
            return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted list of method names found within testCaseClass.

        A name qualifies when it starts with testMethodPrefix and the
        attribute it names is callable. The list is ordered with
        sortTestMethodsUsing unless that attribute is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return (attrname.startswith(prefix) and
                    hasattr(getattr(testCaseClass, attrname), '__call__'))
        # Build a real list so .sort() below works regardless of whether
        # filter() returns a list or a lazy iterator.
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        start_dir may also be a dotted module name. Raises ImportError if the
        start directory / module is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # All test modules must be importable from the top level directory.
        # Remember whether *we* added it, so that only our own insertion is
        # undone below; a pre-existing entry belongs to the caller and must
        # not be removed from sys.path.
        added_to_path = top_level_dir not in sys.path
        if added_to_path:
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(
                    os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
                if set_implicit_top:
                    # The implicit top (derived from a dotted name, not a
                    # real directory) is replaced by the directory that
                    # actually contains the top-level package.
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    if added_to_path:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which the already-imported module
        ``module_name`` can be imported (the parent of a package directory,
        or the directory holding a plain module)."""
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file/directory path under the top-level directory into
        a dotted module name (extension stripped)."""
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module ``name`` and return the module object
        itself (not the top-level package __import__ returns)."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        """Return True if the file name ``path`` matches ``pattern``."""
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Files whose name is a valid module name and matches the pattern are
        imported and loaded; an import failure yields a failing
        'ModuleImportFailure' test instead. Package directories are
        recursed into unless they match the pattern and define load_tests,
        in which case load_tests decides what the package contributes.
        """
        paths = os.listdir(start_dir)

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # Deliberately broad: anything a test module raises at
                    # import time (including SystemExit) is reported as a
                    # failing test rather than aborting discovery.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    # Guard against the name resolving to a same-named module
                    # elsewhere on sys.path (e.g. a globally installed copy).
                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
                    realpath = os.path.splitext(os.path.realpath(mod_file))[0]
                    fullpath_noext = os.path.splitext(os.path.realpath(full_path))[0]
                    if realpath.lower() != fullpath_noext.lower():
                        module_dir = os.path.dirname(realpath)
                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    try:
                        package = self._get_module_from_name(name)
                    except:
                        # Same policy as for plain modules: report the broken
                        # package as a failing test instead of aborting the
                        # whole discovery, and do not recurse into it.
                        yield _make_failed_import_test(name, self.suiteClass)
                        continue
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
295
# Module-level TestLoader instance using the class defaults (prefix 'test',
# cmp ordering, suite.TestSuite), available for shared use.
defaultTestLoader = TestLoader()
297
298
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Build a fresh TestLoader configured with the given prefix and ordering.

    ``suiteClass`` replaces the loader's default suite class only when it is
    truthy; otherwise the class default is kept.
    """
    configured = TestLoader()
    configured.sortTestMethodsUsing, configured.testMethodPrefix = sortUsing, prefix
    if not suiteClass:
        return configured
    configured.suiteClass = suiteClass
    return configured
306
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return testCaseClass's test method names (legacy functional helper).

    Names start with ``prefix`` and are ordered with the cmp-style
    ``sortUsing`` function.
    """
    loader = _makeLoader(prefix, sortUsing)
    names = loader.getTestCaseNames(testCaseClass)
    return names
309
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of all tests in testCaseClass (legacy functional helper).

    Equivalent to configuring a TestLoader with ``prefix``, ``sortUsing`` and
    ``suiteClass`` and calling its loadTestsFromTestCase().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    built_suite = loader.loadTestsFromTestCase(testCaseClass)
    return built_suite
313
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of all tests found in module (legacy functional helper).

    Equivalent to configuring a TestLoader with ``prefix``, ``sortUsing`` and
    ``suiteClass`` and calling its loadTestsFromModule().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    found = loader.loadTestsFromModule(module)
    return found
317