1"""TestSuite"""
2
3import sys
4import unittest
5
6import six
7
8from unittest2 import case, util
9
# NOTE(review): presumably the module-level marker that unittest's result
# machinery looks for in a frame's globals so that frames from this module are
# left out of failure tracebacks -- nothing in this file reads it; confirm
# against the result module before relying on this.
__unittest = True
11
12
class BaseTestSuite(unittest.TestSuite):
    """A simple test suite that doesn't provide class or module shared fixtures.

    Tests are stored in ``self._tests`` in the order they were added.  While
    ``_cleanup`` is true, ``run()`` replaces each test with ``None`` once it
    has been executed; the test cases released that way are tallied in
    ``self._removed_tests`` so ``countTestCases()`` still accounts for them.
    """
    # When true, run() stops holding a reference to each test after running it.
    _cleanup = True

    def __init__(self, tests=()):
        """Create a suite holding every test from the iterable *tests*.

        Raises TypeError (via addTests/addTest) if *tests* is a string or
        contains something that is not an instantiated, callable test.
        """
        self._tests = []
        # Number of test cases belonging to tests already released by run().
        self._removed_tests = 0
        self.addTests(tests)

    def __repr__(self):
        return "<%s tests=%s>" % (util.strclass(self.__class__), list(self))

    def __eq__(self, other):
        """Compare by contained tests; defer for objects of unrelated classes."""
        if not isinstance(other, self.__class__):
            return NotImplemented
        return list(self) == list(other)

    def __ne__(self, other):
        return not self == other

    # Can't guarantee hash invariant, so flag as unhashable
    __hash__ = None

    def __iter__(self):
        """Iterate over the stored tests (``None`` where a test was released)."""
        return iter(self._tests)

    def countTestCases(self):
        """Return the number of test cases, including ones already released."""
        cases = self._removed_tests
        for test in self:
            # Falsy entries (e.g. the None left behind by _removeTestAtIndex)
            # have already been counted in _removed_tests or hold no cases.
            if test:
                cases += test.countTestCases()
        return cases

    def addTest(self, test):
        """Append a single test to the suite.

        Raises TypeError if *test* is not callable, or if it is a TestCase or
        TestSuite class rather than an instance.
        """
        # sanity checks
        if not hasattr(test, '__call__'):
            # %r already applies repr(); passing repr(test) as well would put
            # a second layer of quotes around the object in the message.
            raise TypeError("%r is not callable" % (test,))
        if isinstance(test, type) and issubclass(test,
                                                 (case.TestCase, TestSuite)):
            raise TypeError("TestCases and TestSuites must be instantiated "
                            "before passing them to addTest()")
        self._tests.append(test)

    def addTests(self, tests):
        """Add every test from the iterable *tests*.

        A string is rejected explicitly because it is iterable but is never a
        meaningful collection of tests.
        """
        if isinstance(tests, six.string_types):
            raise TypeError("tests must be an iterable of tests, not a string")
        for test in tests:
            self.addTest(test)

    def run(self, result):
        """Call each test with *result* in order and return *result*.

        Stops early once ``result.shouldStop`` is true.  Each test is released
        after it has run when ``_cleanup`` is enabled.
        """
        for index, test in enumerate(self):
            if result.shouldStop:
                break
            test(result)
            if self._cleanup:
                self._removeTestAtIndex(index)
        return result

    def _removeTestAtIndex(self, index):
        """Stop holding a reference to the TestCase at index."""
        try:
            test = self._tests[index]
        except TypeError:
            # support for suite implementations that have overridden
            # self._tests
            pass
        else:
            # Some unittest tests add non TestCase/TestSuite objects to
            # the suite.
            if hasattr(test, 'countTestCases'):
                self._removed_tests += test.countTestCases()
            self._tests[index] = None

    def __call__(self, *args, **kwds):
        """Calling the suite is the same as calling ``run()``."""
        return self.run(*args, **kwds)

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        for test in self:
            test.debug()
93
94
class TestSuite(BaseTestSuite):
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.

    On top of BaseTestSuite, run() drives class and module level fixtures
    (setUpClass/tearDownClass, setUpModule/tearDownModule).  The bookkeeping
    for that is stored on the result object: ``_previousTestClass``,
    ``_moduleSetUpFailed`` and ``_testRunEntered``.
    """

    def run(self, result, debug=False):
        """Run the contained tests in order and return *result*.

        result -- object the tests report to; also carries the fixture state.
        debug  -- when true, call ``test.debug()`` instead of ``test(result)``.

        Only the outermost run() call for a given result performs the final
        class/module teardown; nested suites see ``_testRunEntered`` already
        set and leave that to the outermost call.
        """
        topLevel = False
        if getattr(result, '_testRunEntered', False) is False:
            result._testRunEntered = topLevel = True

        for index, test in enumerate(self):
            if result.shouldStop:
                break

            # Fixtures are only handled for individual tests; nested suites
            # take care of the tests they contain themselves.
            if _isnotsuite(test):
                self._tearDownPreviousClass(test, result)
                self._handleModuleFixture(test, result)
                self._handleClassSetUp(test, result)
                result._previousTestClass = test.__class__

                # Skip the test when one of its shared fixtures failed.
                if (getattr(test.__class__, '_classSetupFailed', False) or
                        getattr(result, '_moduleSetUpFailed', False)):
                    continue

            if not debug:
                test(result)
            else:
                test.debug()

            if self._cleanup:
                self._removeTestAtIndex(index)

        if topLevel:
            self._tearDownPreviousClass(None, result)
            self._handleModuleTearDown(result)
            # Clear the marker so that a later run() with the same result is
            # recognised as top-level again and performs its own final
            # teardown.
            result._testRunEntered = False
        return result

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        debug = _DebugResult()
        self.run(debug, True)

    ################################
    # Class and module fixture handling

    def _handleClassSetUp(self, test, result):
        """Call setUpClass for the class of *test* if a new class is starting.

        Does nothing when the class is unchanged, when module setup failed,
        or when the class is marked with ``__unittest_skip__``.  A failure is
        recorded on *result* and flagged on the class as ``_classSetupFailed``;
        with a _DebugResult the exception propagates instead.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(currentClass, "__unittest_skip__", False):
            return

        try:
            currentClass._classSetupFailed = False
        except TypeError:
            # test may actually be a function
            # so its class will be a builtin-type
            pass

        setUpClass = getattr(currentClass, 'setUpClass', None)
        if setUpClass is not None:
            try:
                setUpClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                currentClass._classSetupFailed = True
                className = util.strclass(currentClass)
                errorName = 'setUpClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)

    def _get_previous_module(self, result):
        """Return the module name of the previously run test class, or None."""
        previousModule = None
        previousClass = getattr(result, '_previousTestClass', None)
        if previousClass is not None:
            previousModule = previousClass.__module__
        return previousModule

    def _handleModuleFixture(self, test, result):
        """Switch module fixtures when *test* comes from a different module.

        Tears down the previous module, resets ``_moduleSetUpFailed`` and
        calls the new module's setUpModule if it defines one.  A failure is
        recorded on *result* and sets ``_moduleSetUpFailed``; with a
        _DebugResult the exception propagates instead.
        """
        previousModule = self._get_previous_module(result)
        currentModule = test.__class__.__module__
        if currentModule == previousModule:
            return

        self._handleModuleTearDown(result)

        result._moduleSetUpFailed = False
        try:
            module = sys.modules[currentModule]
        except KeyError:
            # The module is not in sys.modules, so there is no fixture to run.
            return
        setUpModule = getattr(module, 'setUpModule', None)
        if setUpModule is not None:
            try:
                setUpModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                result._moduleSetUpFailed = True
                errorName = 'setUpModule (%s)' % currentModule
                self._addClassOrModuleLevelException(result, e, errorName)

    def _addClassOrModuleLevelException(self, result, exception, errorName):
        """Report a fixture failure to *result* under the name *errorName*.

        A case.SkipTest is reported through ``result.addSkip`` when the result
        has that method; anything else goes to ``result.addError``.  Must be
        called while *exception* is being handled, because the error details
        are taken from ``sys.exc_info()``.
        """
        error = _ErrorHolder(errorName)
        addSkip = getattr(result, 'addSkip', None)
        if addSkip is not None and isinstance(exception, case.SkipTest):
            addSkip(error, str(exception))
        else:
            result.addError(error, sys.exc_info())

    def _handleModuleTearDown(self, result):
        """Call tearDownModule for the module of the previously run class.

        Does nothing when there is no previous class, when that module's
        setup failed, or when the module is no longer in sys.modules.
        """
        previousModule = self._get_previous_module(result)
        if previousModule is None:
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return

        try:
            module = sys.modules[previousModule]
        except KeyError:
            return

        tearDownModule = getattr(module, 'tearDownModule', None)
        if tearDownModule is not None:
            try:
                tearDownModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                errorName = 'tearDownModule (%s)' % previousModule
                self._addClassOrModuleLevelException(result, e, errorName)

    def _tearDownPreviousClass(self, test, result):
        """Call tearDownClass on the previous class when the class changes.

        *test* is the test about to run, or None at the end of a run.
        Teardown is skipped when the class is unchanged, when class or module
        setup failed, or when the previous class is marked with
        ``__unittest_skip__``.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(previousClass, '_classSetupFailed', False):
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(previousClass, "__unittest_skip__", False):
            return

        tearDownClass = getattr(previousClass, 'tearDownClass', None)
        if tearDownClass is not None:
            try:
                tearDownClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                className = util.strclass(previousClass)
                errorName = 'tearDownClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
263
264
class _ErrorHolder(object):
    """
    Stand-in for a TestCase that can be handed to a TestResult.

    It offers the small part of the test interface defined below, which lets
    an error that does not belong to any individual test be recorded in a
    test run under a descriptive name.
    """
    # Inspired by the ErrorHolder from Twisted:
    # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py

    # attribute used by TestResult._exc_info_to_string
    failureException = None

    def __init__(self, description):
        # The description doubles as the identifier returned by id().
        self.description = description

    def __repr__(self):
        return "<ErrorHolder description={0!r}>".format(self.description)

    def __str__(self):
        return self.id()

    def __call__(self, result):
        # Calling the holder is the same as calling run().
        return self.run(result)

    def id(self):
        return self.description

    def shortDescription(self):
        # There is no one-line description to offer.
        return None

    def countTestCases(self):
        # A holder never contributes to the number of tests.
        return 0

    def run(self, result):
        # could call result.addError(...) - but this test-like object
        # shouldn't be run anyway
        return None
302
def _isnotsuite(test):
    """Return True when *test* cannot be iterated, False when it can.

    Crude duck-typing: suites are iterable, individual test cases are not.
    """
    try:
        iter(test)
    except TypeError:
        iterable = False
    else:
        iterable = True
    return not iterable
310
311
class _DebugResult(object):
    """Used by the TestSuite to hold previous class when running in debug.

    Provides only the state attributes below; it has no methods for
    recording test outcomes.
    """
    # Never asks the suite to stop early.
    shouldStop = False
    # No module-level setup has failed yet.
    _moduleSetUpFailed = False
    # No test class has been run yet.
    _previousTestClass = None
317