1"""TestSuite"""
2
3import sys
4
5from . import case
6from . import util
7
# NOTE(review): presumably a module-level marker that unittest's result
# machinery looks for (e.g. to omit this module's frames from failure
# tracebacks) -- confirm against TestResult before relying on it.
__unittest = True
9
10
def _call_if_exists(parent, attr):
    """Invoke ``parent.<attr>()`` when the attribute exists; otherwise no-op.

    The call's return value is discarded and None is always returned. An
    attribute that exists but is not callable still raises TypeError.
    """
    missing = object()
    method = getattr(parent, attr, missing)
    if method is not missing:
        method()
14
15
class BaseTestSuite(object):
    """Ordered container of tests without class or module shared fixtures.

    Running the suite simply calls every contained test with the result
    object, in insertion order.
    """

    # When true, run() releases each test right after it has been run.
    _cleanup = True

    def __init__(self, tests=()):
        self._tests = []
        # Test cases already released by _removeTestAtIndex(); keeps
        # countTestCases() stable after a run.
        self._removed_tests = 0
        self.addTests(tests)

    def __repr__(self):
        return f"<{util.strclass(self.__class__)} tests={list(self)}>"

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return list(self) == list(other)
        return NotImplemented

    def __iter__(self):
        return iter(self._tests)

    def countTestCases(self):
        """Return the total number of test cases, released ones included."""
        # Released slots hold None and are skipped by the truthiness filter.
        return sum((test.countTestCases() for test in self if test),
                   self._removed_tests)

    def addTest(self, test):
        """Append a single callable test after basic sanity checks."""
        if not callable(test):
            raise TypeError(f"{test!r} is not callable")
        is_uninstantiated = isinstance(test, type) and issubclass(
            test, (case.TestCase, TestSuite))
        if is_uninstantiated:
            raise TypeError("TestCases and TestSuites must be instantiated "
                            "before passing them to addTest()")
        self._tests.append(test)

    def addTests(self, tests):
        """Add every test from the iterable *tests* (strings are rejected)."""
        if isinstance(tests, str):
            raise TypeError("tests must be an iterable of tests, not a string")
        for item in tests:
            self.addTest(item)

    def run(self, result):
        """Call each test with *result* until ``result.shouldStop`` is set."""
        for position, test in enumerate(self):
            if result.shouldStop:
                break
            test(result)
            if self._cleanup:
                self._removeTestAtIndex(position)
        return result

    def _removeTestAtIndex(self, index):
        """Drop the suite's reference to the test stored at *index*."""
        try:
            test = self._tests[index]
        except TypeError:
            # Suites that replaced self._tests with something non-indexable
            # are left untouched.
            return
        # Objects other than TestCase/TestSuite may have been added, so only
        # count those that can report a case count.
        if hasattr(test, 'countTestCases'):
            self._removed_tests += test.countTestCases()
        self._tests[index] = None

    def __call__(self, *args, **kwds):
        return self.run(*args, **kwds)

    def debug(self):
        """Call debug() on each contained test, letting errors propagate."""
        for test in self:
            test.debug()
90
91
class TestSuite(BaseTestSuite):
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.
    """

    def run(self, result, debug=False):
        """Run the contained tests, managing class- and module-level fixtures.

        The first suite to see *result* marks it with ``_testRunEntered``.
        After its loop finishes, it tears down the last class and module
        fixtures. If *debug* is true, each test's ``debug()`` is called
        instead of ``test(result)``. Returns *result*.
        """
        topLevel = False
        if getattr(result, '_testRunEntered', False) is False:
            result._testRunEntered = topLevel = True

        for index, test in enumerate(self):
            if result.shouldStop:
                break

            if _isnotsuite(test):
                # Individual test: switch fixtures when the class or module
                # differs from the previously run test.
                self._tearDownPreviousClass(test, result)
                self._handleModuleFixture(test, result)
                self._handleClassSetUp(test, result)
                result._previousTestClass = test.__class__

                # Tests whose class or module setup failed are not run.
                if (getattr(test.__class__, '_classSetupFailed', False) or
                    getattr(result, '_moduleSetUpFailed', False)):
                    continue

            if not debug:
                test(result)
            else:
                test.debug()

            if self._cleanup:
                self._removeTestAtIndex(index)

        if topLevel:
            self._tearDownPreviousClass(None, result)
            self._handleModuleTearDown(result)
            result._testRunEntered = False
        return result

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        debug = _DebugResult()
        self.run(debug, True)

    ################################

    def _handleClassSetUp(self, test, result):
        """Run setUpClass for *test*'s class when it differs from the last one.

        On failure the error is reported to *result*, and the class is flagged
        with ``_classSetupFailed``. Its class cleanups are then run. With a
        _DebugResult the setUpClass exception propagates instead.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(currentClass, "__unittest_skip__", False):
            return

        try:
            currentClass._classSetupFailed = False
        except TypeError:
            # test may actually be a function
            # so its class will be a builtin-type
            pass

        setUpClass = getattr(currentClass, 'setUpClass', None)
        if setUpClass is None:
            return

        _call_if_exists(result, '_setupStdout')
        try:
            setUpClass()
        except Exception as e:
            if isinstance(result, _DebugResult):
                raise
            try:
                currentClass._classSetupFailed = True
            except TypeError:
                # Builtin types cannot carry the flag; still report the error.
                pass
            className = util.strclass(currentClass)
            self._createClassOrModuleLevelException(result, e,
                                                    'setUpClass',
                                                    className)
        finally:
            _call_if_exists(result, '_restoreStdout')
            if getattr(currentClass, '_classSetupFailed', False) is True:
                self._runClassCleanups(result, currentClass, 'setUpClass')

    def _runClassCleanups(self, result, cls, method_name):
        """Run *cls*'s class cleanups and report the errors they recorded.

        Classes without a ``doClassCleanups`` attribute are skipped. Errors
        are attributed to *method_name*. A _DebugResult has no addError(),
        so in debug mode the first recorded error is raised instead.
        """
        doClassCleanups = getattr(cls, 'doClassCleanups', None)
        if doClassCleanups is None:
            return
        doClassCleanups()
        className = util.strclass(cls)
        for exc_info in getattr(cls, 'tearDown_exceptions', ()):
            if isinstance(result, _DebugResult):
                raise exc_info[1]
            self._createClassOrModuleLevelException(result, exc_info[1],
                                                    method_name, className,
                                                    info=exc_info)

    def _get_previous_module(self, result):
        """Return the module name of the previously run test's class, or None."""
        previousModule = None
        previousClass = getattr(result, '_previousTestClass', None)
        if previousClass is not None:
            previousModule = previousClass.__module__
        return previousModule

    def _handleModuleFixture(self, test, result):
        """Switch module fixtures when *test* comes from a different module.

        First tears down the previous module, then runs the new module's
        setUpModule. If setUpModule fails, module cleanups are run and the
        error is reported. ``result._moduleSetUpFailed`` is also set, so the
        module's tests are not run.
        """
        previousModule = self._get_previous_module(result)
        currentModule = test.__class__.__module__
        if currentModule == previousModule:
            return

        self._handleModuleTearDown(result)

        result._moduleSetUpFailed = False
        try:
            module = sys.modules[currentModule]
        except KeyError:
            return
        setUpModule = getattr(module, 'setUpModule', None)
        if setUpModule is None:
            return

        _call_if_exists(result, '_setupStdout')
        try:
            setUpModule()
        except Exception as e:
            # Run whatever cleanups were registered before the failure. In
            # debug mode a cleanup error propagates from here, chained to e.
            self._runModuleCleanups(result, 'setUpModule', currentModule)
            if isinstance(result, _DebugResult):
                raise
            result._moduleSetUpFailed = True
            self._createClassOrModuleLevelException(result, e,
                                                    'setUpModule',
                                                    currentModule)
        finally:
            _call_if_exists(result, '_restoreStdout')

    def _runModuleCleanups(self, result, method_name, module_name):
        """Run module cleanups, reporting a failure against *method_name*.

        A _DebugResult has no addError(), so in debug mode the cleanup error
        is re-raised instead of reported.
        """
        try:
            case.doModuleCleanups()
        except Exception as e:
            if isinstance(result, _DebugResult):
                raise
            self._createClassOrModuleLevelException(result, e, method_name,
                                                    module_name)

    def _createClassOrModuleLevelException(self, result, exc, method_name,
                                           parent, info=None):
        """Report *exc* as an error named ``'<method_name> (<parent>)'``."""
        errorName = f'{method_name} ({parent})'
        self._addClassOrModuleLevelException(result, exc, errorName, info)

    def _addClassOrModuleLevelException(self, result, exception, errorName,
                                        info=None):
        """Record a fixture-level error (or skip) on *result*.

        SkipTest exceptions become skips when *result* supports addSkip.
        Otherwise addError is called with *info*, or, when *info* is falsy,
        with the exception currently being handled (sys.exc_info()).
        """
        error = _ErrorHolder(errorName)
        addSkip = getattr(result, 'addSkip', None)
        if addSkip is not None and isinstance(exception, case.SkipTest):
            addSkip(error, str(exception))
        else:
            if not info:
                result.addError(error, sys.exc_info())
            else:
                result.addError(error, info)

    def _handleModuleTearDown(self, result):
        """Run tearDownModule and the module cleanups for the previous module.

        Module cleanups run even when the module defines no tearDownModule,
        and even if tearDownModule raises. Nothing runs if that module's
        setUpModule failed, since its cleanups already ran then.
        """
        previousModule = self._get_previous_module(result)
        if previousModule is None:
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return

        try:
            module = sys.modules[previousModule]
        except KeyError:
            return

        tearDownModule = getattr(module, 'tearDownModule', None)
        try:
            if tearDownModule is not None:
                _call_if_exists(result, '_setupStdout')
                try:
                    tearDownModule()
                except Exception as e:
                    if isinstance(result, _DebugResult):
                        raise
                    self._createClassOrModuleLevelException(result, e,
                                                            'tearDownModule',
                                                            previousModule)
                finally:
                    _call_if_exists(result, '_restoreStdout')
        finally:
            # Cleanups must run whether or not tearDownModule exists.
            self._runModuleCleanups(result, 'tearDownModule', previousModule)

    def _tearDownPreviousClass(self, test, result):
        """Run tearDownClass and class cleanups for the previously run class.

        This is called before each individual test. It is also called with
        test=None at the end of a top-level run. Nothing runs if the class
        is unchanged, if its setUpClass failed, if module setup failed, or if
        the class is marked skipped.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(previousClass, '_classSetupFailed', False):
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(previousClass, "__unittest_skip__", False):
            return

        tearDownClass = getattr(previousClass, 'tearDownClass', None)
        if tearDownClass is None:
            return

        _call_if_exists(result, '_setupStdout')
        try:
            tearDownClass()
        except Exception as e:
            if isinstance(result, _DebugResult):
                raise
            className = util.strclass(previousClass)
            self._createClassOrModuleLevelException(result, e,
                                                    'tearDownClass',
                                                    className)
        finally:
            _call_if_exists(result, '_restoreStdout')
            self._runClassCleanups(result, previousClass, 'tearDownClass')
308
309
class _ErrorHolder(object):
    """
    Stand-in for a TestCase inside a result. A TestResult can treat it as
    if it were a unit test; it exists so that arbitrary (fixture-level)
    errors can be recorded against a named placeholder during a run.
    """
    # Modelled on Twisted's ErrorHolder:
    # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py

    # Read by TestResult._exc_info_to_string.
    failureException = None

    def __init__(self, description):
        self.description = description

    def id(self):
        return self.description

    def shortDescription(self):
        return None

    def __repr__(self):
        return f"<ErrorHolder description={self.description!r}>"

    def __str__(self):
        return self.id()

    def run(self, result):
        # Intentionally inert: this placeholder is never meant to be run, so
        # nothing is reported to *result*.
        return None

    def __call__(self, result):
        return self.run(result)

    def countTestCases(self):
        # A placeholder contributes no real test cases.
        return 0
347
def _isnotsuite(test):
    """Return True if *test* looks like a single test rather than a suite.

    This is a crude duck-typing check: anything ``iter()`` accepts is
    treated as a suite.
    """
    try:
        iter(test)
    except TypeError:
        looks_like_suite = False
    else:
        looks_like_suite = True
    return not looks_like_suite
355
356
class _DebugResult(object):
    """Minimal result object used by TestSuite.debug().

    It carries only the bookkeeping attributes that TestSuite.run() reads
    and writes (the previous test class, the module-setup flag, and the
    stop flag), and no addError/addSkip.
    """
    shouldStop = False
    _moduleSetUpFailed = False
    _previousTestClass = None
362