1"""TestSuite"""
2
3import sys
4
5from . import case
6from . import util
7
# Module-level marker flag. NOTE(review): presumably looked up by unittest's
# result/traceback machinery to recognise (and hide) frames originating in
# this module -- confirm against the code that reads ``__unittest``.
__unittest = True
9
10
def _call_if_exists(parent, attr):
    """Call ``parent.<attr>()`` when that attribute exists; otherwise do nothing.

    Only a missing attribute is tolerated: an attribute that exists but is
    not callable still raises when it is invoked.
    """
    missing = object()
    hook = getattr(parent, attr, missing)
    if hook is not missing:
        hook()
14
15
class BaseTestSuite(object):
    """A simple test suite that doesn't provide class or module shared fixtures.

    Tests are stored in ``self._tests`` and run in insertion order. When the
    class attribute ``_cleanup`` is true, each test's slot is replaced by
    ``None`` after it has run so the suite stops referencing it. The number of
    test cases released this way is accumulated in ``_removed_tests`` so that
    countTestCases() keeps reporting the original total.
    """
    _cleanup = True

    def __init__(self, tests=()):
        self._tests = []
        self._removed_tests = 0
        self.addTests(tests)

    def __repr__(self):
        return "<%s tests=%s>" % (util.strclass(self.__class__), list(self))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return list(self) == list(other)

    def __iter__(self):
        return iter(self._tests)

    def countTestCases(self):
        """Return the number of test cases, including ones already released."""
        cases = self._removed_tests
        for test in self:
            # Released slots hold None and are already in _removed_tests.
            if test:
                cases += test.countTestCases()
        return cases

    def addTest(self, test):
        """Append a single callable test to the suite.

        Raises:
            TypeError: if *test* is not callable, or is a TestCase/TestSuite
                class rather than an instance.
        """
        # sanity checks
        if not callable(test):
            raise TypeError("{} is not callable".format(repr(test)))
        if isinstance(test, type) and issubclass(test,
                                                 (case.TestCase, TestSuite)):
            raise TypeError("TestCases and TestSuites must be instantiated "
                            "before passing them to addTest()")
        self._tests.append(test)

    def addTests(self, tests):
        """Add every test from the iterable *tests* via addTest().

        Raises:
            TypeError: if *tests* is a string (iterating it would add
                characters instead of tests).
        """
        if isinstance(tests, str):
            raise TypeError("tests must be an iterable of tests, not a string")
        for test in tests:
            self.addTest(test)

    def run(self, result):
        """Call each test with *result* until ``result.shouldStop``; return *result*."""
        for index, test in enumerate(self):
            if result.shouldStop:
                break
            test(result)
            if self._cleanup:
                self._removeTestAtIndex(index)
        return result

    def _removeTestAtIndex(self, index):
        """Stop holding a reference to the TestCase at index.

        This is best effort. Suites that replaced ``self._tests`` with
        something that cannot be indexed (e.g. ``None``) or assigned to (e.g. a
        tuple) keep their tests. In that case nothing is added to
        ``_removed_tests``, so countTestCases() does not count the test twice.
        """
        try:
            test = self._tests[index]
        except TypeError:
            # support for suite implementations that have overridden self._tests
            return
        # Some unittest tests add non TestCase/TestSuite objects to
        # the suite.
        if hasattr(test, 'countTestCases'):
            released = test.countTestCases()
        else:
            released = 0
        try:
            self._tests[index] = None
        except TypeError:
            # Read-only container: the test stays in place and is still
            # counted by countTestCases().
            return
        self._removed_tests += released

    def __call__(self, *args, **kwds):
        return self.run(*args, **kwds)

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        for test in self:
            test.debug()
90
91
class TestSuite(BaseTestSuite):
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.

    Unlike BaseTestSuite, this suite also drives the shared fixtures
    setUpClass/tearDownClass and setUpModule/tearDownModule. It remembers the
    class of the previously run test on the result object
    (``result._previousTestClass``) to detect class and module boundaries.
    """

    def run(self, result, debug=False):
        """Run every test in the suite against *result* and return *result*.

        When *debug* is true, each test's ``debug()`` method is called instead
        of ``test(result)``. Only the outermost call (the one that first sets
        ``result._testRunEntered``) performs the final tearDownClass and
        tearDownModule after all tests have run.
        """
        topLevel = False
        if getattr(result, '_testRunEntered', False) is False:
            # Nested suites share the same result object; only the first
            # entry owns the final fixture teardown.
            result._testRunEntered = topLevel = True

        for index, test in enumerate(self):
            if result.shouldStop:
                break

            if _isnotsuite(test):
                # Fixture bookkeeping happens per individual test; nested
                # suites are simply called with the result below.
                self._tearDownPreviousClass(test, result)
                self._handleModuleFixture(test, result)
                self._handleClassSetUp(test, result)
                result._previousTestClass = test.__class__

                if (getattr(test.__class__, '_classSetupFailed', False) or
                    getattr(result, '_moduleSetUpFailed', False)):
                    # setUpClass or setUpModule failed: skip the test itself.
                    continue

            if not debug:
                test(result)
            else:
                test.debug()

            if self._cleanup:
                self._removeTestAtIndex(index)

        if topLevel:
            self._tearDownPreviousClass(None, result)
            self._handleModuleTearDown(result)
            result._testRunEntered = False
        return result

    def debug(self):
        """Run the tests without collecting errors in a TestResult.

        A _DebugResult is used, so exceptions raised by class- and
        module-level fixtures are re-raised rather than recorded.
        """
        debug_result = _DebugResult()
        self.run(debug_result, True)

    ################################

    def _handleClassSetUp(self, test, result):
        """Call setUpClass when *test*'s class differs from the previous test's.

        Nothing is done when the class is unchanged, when the module setup
        failed, or when the class is marked ``__unittest_skip__``. If
        setUpClass raises, the class's ``_classSetupFailed`` flag is set and
        the error is reported on *result*. Under a _DebugResult the error is
        re-raised instead.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(currentClass, "__unittest_skip__", False):
            return

        try:
            currentClass._classSetupFailed = False
        except TypeError:
            # test may actually be a function
            # so its class will be a builtin-type
            pass

        setUpClass = getattr(currentClass, 'setUpClass', None)
        if setUpClass is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                setUpClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                try:
                    currentClass._classSetupFailed = True
                except TypeError:
                    # Same builtin-type case as above: the flag cannot be set.
                    pass
                className = util.strclass(currentClass)
                errorName = 'setUpClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _get_previous_module(self, result):
        """Return the module name of the previous test's class, or None."""
        previousModule = None
        previousClass = getattr(result, '_previousTestClass', None)
        if previousClass is not None:
            previousModule = previousClass.__module__
        return previousModule

    def _handleModuleFixture(self, test, result):
        """Switch module fixtures when *test* comes from a new module.

        When the module differs from the previous test's module, the previous
        module is torn down and the new module's setUpModule (if any) is
        called. A failure sets ``result._moduleSetUpFailed`` and is reported
        on *result*. Under a _DebugResult the failure is re-raised instead.
        """
        previousModule = self._get_previous_module(result)
        currentModule = test.__class__.__module__
        if currentModule == previousModule:
            return

        self._handleModuleTearDown(result)

        result._moduleSetUpFailed = False
        try:
            module = sys.modules[currentModule]
        except KeyError:
            # Module not importable by name (e.g. removed): no fixture to run.
            return
        setUpModule = getattr(module, 'setUpModule', None)
        if setUpModule is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                setUpModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                result._moduleSetUpFailed = True
                errorName = 'setUpModule (%s)' % currentModule
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _addClassOrModuleLevelException(self, result, exception, errorName):
        """Record a fixture failure on *result* under a placeholder test.

        SkipTest is reported through ``result.addSkip`` when available; any
        other exception is reported through ``result.addError``. This must be
        called from inside an ``except`` block, because the traceback is
        taken from sys.exc_info().
        """
        error = _ErrorHolder(errorName)
        addSkip = getattr(result, 'addSkip', None)
        if addSkip is not None and isinstance(exception, case.SkipTest):
            addSkip(error, str(exception))
        else:
            result.addError(error, sys.exc_info())

    def _handleModuleTearDown(self, result):
        """Call tearDownModule for the previous test's module, if any.

        Nothing is done when no test has run yet, when the module's
        setUpModule failed, or when the module is not in sys.modules.
        """
        previousModule = self._get_previous_module(result)
        if previousModule is None:
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return

        try:
            module = sys.modules[previousModule]
        except KeyError:
            return

        tearDownModule = getattr(module, 'tearDownModule', None)
        if tearDownModule is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                tearDownModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                errorName = 'tearDownModule (%s)' % previousModule
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _tearDownPreviousClass(self, test, result):
        """Call tearDownClass on the previous test's class if the class changed.

        *test* is the test about to run, or None for the final teardown at the
        end of a top-level run. Nothing is torn down when the previous class's
        setUpClass failed, when the module setup failed, or when the previous
        class is marked ``__unittest_skip__``.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        # With test=None this is NoneType, so any previous class counts as changed.
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(previousClass, '_classSetupFailed', False):
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(previousClass, "__unittest_skip__", False):
            return

        tearDownClass = getattr(previousClass, 'tearDownClass', None)
        if tearDownClass is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                tearDownClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                className = util.strclass(previousClass)
                errorName = 'tearDownClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')
268
269
class _ErrorHolder(object):
    """
    Stand-in for a TestCase when reporting errors that belong to no single test.

    It offers the small test-like surface a TestResult touches (id(),
    shortDescription(), countTestCases(), being callable), so class- and
    module-level fixture failures can be recorded against it.
    """
    # Inspired by the ErrorHolder from Twisted:
    # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py

    # attribute used by TestResult._exc_info_to_string
    failureException = None

    def __init__(self, description):
        self.description = description

    def id(self):
        return self.description

    def shortDescription(self):
        # A placeholder has no docstring-derived summary.
        return None

    def countTestCases(self):
        # Placeholders never count as real tests.
        return 0

    def run(self, result):
        # Deliberately a no-op: this object only carries a description and is
        # never meant to be executed.
        return None

    def __call__(self, result):
        return self.run(result)

    def __str__(self):
        return self.id()

    def __repr__(self):
        return "<ErrorHolder description={!r}>".format(self.description)
307
def _isnotsuite(test):
    """Return True when *test* looks like a single test rather than a suite.

    This is pure duck typing: anything that ``iter()`` accepts is treated as
    a suite, and anything that makes ``iter()`` raise TypeError is a test.
    """
    try:
        iter(test)
    except TypeError:
        single_test = True
    else:
        single_test = False
    return single_test
315
316
class _DebugResult(object):
    """Minimal result object handed to TestSuite.run() by TestSuite.debug().

    It carries only the bookkeeping attributes the suite reads while running.
    The suite checks for this type to re-raise fixture errors instead of
    recording them.
    """
    # Never requests an early stop.
    shouldStop = False
    # No module setup has failed yet.
    _moduleSetUpFailed = False
    # No test has run yet, so there is no previous class.
    _previousTestClass = None
322