1"""TestSuite"""
2
3import sys
4import unittest
5from unittest2 import case, util
6import six
7
# Module-level marker flag; nothing in this file reads it.
# NOTE(review): presumably the unittest convention whereby frames from modules
# defining ``__unittest`` are hidden from failure tracebacks -- confirm.
__unittest = True
9
10
class BaseTestSuite(unittest.TestSuite):
    """A simple test suite that doesn't provide class or module shared fixtures.

    Tests are kept in insertion order in ``self._tests`` and are run one
    after another by :meth:`run`. Class/module fixture handling lives in
    the :class:`TestSuite` subclass.
    """

    def __init__(self, tests=()):
        # Only ``_tests`` is initialised here; the unittest.TestSuite
        # constructor is not invoked.
        self._tests = []
        self.addTests(tests)

    def __repr__(self):
        return "<%s tests=%s>" % (util.strclass(self.__class__), list(self))

    def __eq__(self, other):
        # Suites compare equal when they hold equal tests in the same order;
        # comparison with unrelated types is deferred to the other operand.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return list(self) == list(other)

    def __ne__(self, other):
        return not self == other

    # Can't guarantee hash invariant, so flag as unhashable
    __hash__ = None

    def __iter__(self):
        return iter(self._tests)

    def countTestCases(self):
        """Return the total number of test cases, summed over every child."""
        return sum(test.countTestCases() for test in self)

    def addTest(self, test):
        """Append a single test (or nested suite) to this suite.

        Raises:
            TypeError: if *test* has no ``__call__`` attribute, or if it is
                a TestCase/TestSuite class rather than an instance.
        """
        # sanity checks
        if not hasattr(test, '__call__'):
            # Format the object itself with %r; applying %r to repr(test)
            # would wrap the message in a spurious extra pair of quotes.
            raise TypeError("%r is not callable" % (test,))
        if isinstance(test, type) and issubclass(test,
                                                 (case.TestCase, TestSuite)):
            raise TypeError("TestCases and TestSuites must be instantiated "
                            "before passing them to addTest()")
        self._tests.append(test)

    def addTests(self, tests):
        """Add every test from the iterable *tests* via :meth:`addTest`.

        Raises:
            TypeError: if *tests* is a string (strings are iterable but are
                not collections of tests).
        """
        if isinstance(tests, six.string_types):
            raise TypeError("tests must be an iterable of tests, not a string")
        for test in tests:
            self.addTest(test)

    def run(self, result):
        """Call each test with *result*, stopping once result.shouldStop.

        Returns *result*.
        """
        for test in self:
            if result.shouldStop:
                break
            test(result)
        return result

    def __call__(self, *args, **kwds):
        return self.run(*args, **kwds)

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        for test in self:
            test.debug()
72
73
class TestSuite(BaseTestSuite):
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.

    Unlike BaseTestSuite, this suite also drives shared fixtures: it calls
    setUpClass/tearDownClass whenever the class of consecutive tests
    changes, and setUpModule/tearDownModule (looked up in sys.modules)
    whenever their module changes. The fixture bookkeeping is stored on the
    result object (``_previousTestClass``, ``_moduleSetUpFailed``), so nested
    TestSuites run through ``_wrapped_run`` share it.
    """

    def run(self, result):
        """Run every test into *result*, then run any pending teardowns.

        Passing ``None`` as the "next test" to _tearDownPreviousClass makes
        its class comparison fail, so the last class seen gets its
        tearDownClass; the last module then gets its tearDownModule.
        Returns *result*.
        """
        self._wrapped_run(result)
        self._tearDownPreviousClass(None, result)
        self._handleModuleTearDown(result)
        return result

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        # With a _DebugResult, the fixture helpers re-raise exceptions from
        # setUpClass/setUpModule/tearDownClass/tearDownModule instead of
        # recording them, and leaf tests are run through test.debug().
        debug = _DebugResult()
        self._wrapped_run(debug, True)
        self._tearDownPreviousClass(None, debug)
        self._handleModuleTearDown(debug)

    ################################
    # private methods
    def _wrapped_run(self, result, debug=False):
        """Run the contained tests, handling fixtures around single tests.

        :param result: object receiving outcomes; it also carries the
            fixture bookkeeping attributes and ``shouldStop``.
        :param debug: when true, leaf tests are run via ``test.debug()``
            instead of ``test(result)``.
        """
        for test in self:
            if result.shouldStop:
                break

            if _isnotsuite(test):
                # For a single test: tear down the previous class if the class
                # changed, switch module fixtures if the module changed, then
                # set up this test's class, and remember the class.
                self._tearDownPreviousClass(test, result)
                self._handleModuleFixture(test, result)
                self._handleClassSetUp(test, result)
                result._previousTestClass = test.__class__

                # Tests whose class or module setup failed are not run; the
                # failure was already recorded by the fixture helper.
                if (getattr(test.__class__, '_classSetupFailed', False) or
                        getattr(result, '_moduleSetUpFailed', False)):
                    continue

            # Nested TestSuites recurse with the same result (and therefore
            # the same fixture state); anything else is called directly.
            if hasattr(test, '_wrapped_run'):
                test._wrapped_run(result, debug)
            elif not debug:
                test(result)
            else:
                test.debug()

    def _handleClassSetUp(self, test, result):
        """Call setUpClass for *test*'s class when the class changes.

        Nothing happens when the class equals the previous test's class, when
        the current module's setup failed, or when the class is marked with
        ``__unittest_skip__``. If setUpClass raises, ``_classSetupFailed`` is
        set on the class and an error (or skip) is recorded; with a
        _DebugResult the exception propagates instead.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if result._moduleSetUpFailed:
            return
        if getattr(currentClass, "__unittest_skip__", False):
            return

        try:
            currentClass._classSetupFailed = False
        except TypeError:
            # test may actually be a function
            # so its class will be a builtin-type
            pass

        setUpClass = getattr(currentClass, 'setUpClass', None)
        if setUpClass is not None:
            try:
                setUpClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                currentClass._classSetupFailed = True
                className = util.strclass(currentClass)
                errorName = 'setUpClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)

    def _get_previous_module(self, result):
        """Return the module name of the last recorded test class, or None."""
        previousModule = None
        previousClass = getattr(result, '_previousTestClass', None)
        if previousClass is not None:
            previousModule = previousClass.__module__
        return previousModule

    def _handleModuleFixture(self, test, result):
        """Run setUpModule when *test* comes from a different module.

        The previous module is torn down first and
        ``result._moduleSetUpFailed`` is reset. Modules missing from
        sys.modules are ignored. If setUpModule raises, the flag is set and an
        error (or skip) is recorded; with a _DebugResult the exception
        propagates instead.
        """
        previousModule = self._get_previous_module(result)
        currentModule = test.__class__.__module__
        if currentModule == previousModule:
            return

        self._handleModuleTearDown(result)

        result._moduleSetUpFailed = False
        try:
            module = sys.modules[currentModule]
        except KeyError:
            return
        setUpModule = getattr(module, 'setUpModule', None)
        if setUpModule is not None:
            try:
                setUpModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                result._moduleSetUpFailed = True
                errorName = 'setUpModule (%s)' % currentModule
                self._addClassOrModuleLevelException(result, e, errorName)

    def _addClassOrModuleLevelException(self, result, exception, errorName):
        """Record a fixture failure against an _ErrorHolder placeholder.

        A SkipTest is reported through ``result.addSkip`` when the result has
        one. Anything else goes to ``result.addError`` together with
        ``sys.exc_info()``, which is why every caller invokes this from inside
        the ``except`` block that caught *exception*.
        """
        error = _ErrorHolder(errorName)
        addSkip = getattr(result, 'addSkip', None)
        if addSkip is not None and isinstance(exception, case.SkipTest):
            addSkip(error, str(exception))
        else:
            result.addError(error, sys.exc_info())

    def _handleModuleTearDown(self, result):
        """Call tearDownModule for the module of the last recorded test class.

        Nothing happens when no test class has been recorded, when that
        module's setup failed, or when the module is not in sys.modules.
        Errors are recorded; with a _DebugResult they propagate instead.
        """
        previousModule = self._get_previous_module(result)
        if previousModule is None:
            return
        if result._moduleSetUpFailed:
            return

        try:
            module = sys.modules[previousModule]
        except KeyError:
            return

        tearDownModule = getattr(module, 'tearDownModule', None)
        if tearDownModule is not None:
            try:
                tearDownModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                errorName = 'tearDownModule (%s)' % previousModule
                self._addClassOrModuleLevelException(result, e, errorName)

    def _tearDownPreviousClass(self, test, result):
        """Call tearDownClass on the previous test class when it is finished.

        *test* is the next test to run, or None at the end of a run. Nothing
        happens if its class equals the previous class, if the previous
        class's setUpClass failed, if ``result._moduleSetUpFailed`` is set, or
        if the previous class is marked ``__unittest_skip__``. Errors are
        recorded; with a _DebugResult they propagate instead.
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(previousClass, '_classSetupFailed', False):
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(previousClass, "__unittest_skip__", False):
            return

        tearDownClass = getattr(previousClass, 'tearDownClass', None)
        if tearDownClass is not None:
            try:
                tearDownClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                className = util.strclass(previousClass)
                errorName = 'tearDownClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
233
234
235class _ErrorHolder(object):
236    """
237    Placeholder for a TestCase inside a result. As far as a TestResult
238    is concerned, this looks exactly like a unit test. Used to insert
239    arbitrary errors into a test suite run.
240    """
241    # Inspired by the ErrorHolder from Twisted:
242    # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py
243
244    # attribute used by TestResult._exc_info_to_string
245    failureException = None
246
247    def __init__(self, description):
248        self.description = description
249
250    def id(self):
251        return self.description
252
253    def shortDescription(self):
254        return None
255
256    def __repr__(self):
257        return "<ErrorHolder description=%r>" % (self.description,)
258
259    def __str__(self):
260        return self.id()
261
262    def run(self, result):
263        # could call result.addError(...) - but this test-like object
264        # shouldn't be run anyway
265        pass
266
267    def __call__(self, result):
268        return self.run(result)
269
270    def countTestCases(self):
271        return 0
272
273
274def _isnotsuite(test):
275    "A crude way to tell apart testcases and suites with duck-typing"
276    try:
277        iter(test)
278    except TypeError:
279        return True
280    return False
281
282
283class _DebugResult(object):
284    "Used by the TestSuite to hold previous class when running in debug."
285    _previousTestClass = None
286    _moduleSetUpFailed = False
287    shouldStop = False
288