1"""TestSuite"""
2
3import sys
4
5from . import case
6from . import util
7
# Module-level marker flag. NOTE(review): presumably looked up by unittest's
# result/traceback machinery to recognise (and hide) frames originating from
# this module — confirm against the result-reporting code.
__unittest = True
9
10
11def _call_if_exists(parent, attr):
12    func = getattr(parent, attr, lambda: None)
13    func()
14
15
class BaseTestSuite(object):
    """Ordered collection of tests without shared class/module fixtures.

    Tests are kept in the order they were added and are run one after the
    other; no setUpClass/setUpModule style hooks are invoked.
    """

    def __init__(self, tests=()):
        self._tests = []
        self.addTests(tests)

    def __repr__(self):
        class_name = util.strclass(self.__class__)
        return "<%s tests=%s>" % (class_name, list(self))

    def __eq__(self, other):
        # Only compare against suites of a compatible class; defer otherwise.
        if isinstance(other, self.__class__):
            return list(self) == list(other)
        return NotImplemented

    def __ne__(self, other):
        return not (self == other)

    # Equality depends on the (mutable) list of tests, so the hash invariant
    # cannot be guaranteed; mark instances as unhashable.
    __hash__ = None

    def __iter__(self):
        return iter(self._tests)

    def countTestCases(self):
        """Return the sum of countTestCases() over every contained test."""
        return sum(child.countTestCases() for child in self)

    def addTest(self, test):
        """Append one callable test, rejecting uninstantiated test classes."""
        if not hasattr(test, '__call__'):
            raise TypeError("{} is not callable".format(repr(test)))
        is_bare_class = isinstance(test, type) and issubclass(
            test, (case.TestCase, TestSuite))
        if is_bare_class:
            raise TypeError("TestCases and TestSuites must be instantiated "
                            "before passing them to addTest()")
        self._tests.append(test)

    def addTests(self, tests):
        """Add every element of the iterable *tests* through addTest()."""
        # A string is iterable, but is never a valid collection of tests.
        if isinstance(tests, basestring):
            raise TypeError("tests must be an iterable of tests, not a string")
        for child in tests:
            self.addTest(child)

    def run(self, result):
        """Call each test with *result*, stopping once result.shouldStop."""
        for child in self:
            if result.shouldStop:
                break
            child(result)
        return result

    def __call__(self, *args, **kwds):
        return self.run(*args, **kwds)

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        for child in self:
            child.debug()
76
77
class TestSuite(BaseTestSuite):
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.

    Unlike BaseTestSuite, this suite also drives class-level
    (setUpClass/tearDownClass) and module-level (setUpModule/tearDownModule)
    fixtures. The bookkeeping for this (the previously run test class and
    whether module setup failed) is stored as attributes on the result object.
    """

    def run(self, result, debug=False):
        """Run every test in the suite against *result* and return it.

        :param result: result object. It must expose ``shouldStop`` and
            accept attribute assignment, because fixture bookkeeping
            (``_testRunEntered``, ``_previousTestClass``,
            ``_moduleSetUpFailed``) is stored on it.
        :param debug: if true, call ``test.debug()`` instead of
            ``test(result)``.
        :returns: *result*.
        """
        topLevel = False
        # Only the outermost suite (the first to see this result) performs the
        # final class/module teardown; nested suites share the same result.
        if getattr(result, '_testRunEntered', False) is False:
            result._testRunEntered = topLevel = True

        for test in self:
            if result.shouldStop:
                break

            if _isnotsuite(test):
                # Individual test: close out the previous class/module if we
                # moved on, then set up the new ones as needed.
                self._tearDownPreviousClass(test, result)
                self._handleModuleFixture(test, result)
                self._handleClassSetUp(test, result)
                result._previousTestClass = test.__class__

                # Do not run tests whose class or module fixture failed; the
                # fixture handler has already recorded the failure (or, for a
                # _DebugResult, re-raised it).
                if (getattr(test.__class__, '_classSetupFailed', False) or
                    getattr(result, '_moduleSetUpFailed', False)):
                    continue

            if not debug:
                test(result)
            else:
                test.debug()

        if topLevel:
            # Passing None forces teardown of the last class, because
            # None.__class__ (NoneType) differs from any real test class.
            self._tearDownPreviousClass(None, result)
            self._handleModuleTearDown(result)
            result._testRunEntered = False
        return result

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        debug = _DebugResult()
        self.run(debug, True)

    ################################

    def _handleClassSetUp(self, test, result):
        """Call ``setUpClass`` on *test*'s class when entering a new class.

        Does nothing if the class is the same as the previous test's, if
        module setup failed, or if the class is marked ``__unittest_skip__``.
        On failure the class is flagged with ``_classSetupFailed`` and the
        error is recorded on *result* (re-raised for a ``_DebugResult``).
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        # getattr for consistency with run()/_tearDownPreviousClass(): tolerate
        # result objects that never had the flag set.
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(currentClass, "__unittest_skip__", False):
            return

        try:
            currentClass._classSetupFailed = False
        except TypeError:
            # test may actually be a function
            # so its class will be a builtin-type
            pass

        setUpClass = getattr(currentClass, 'setUpClass', None)
        if setUpClass is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                setUpClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                currentClass._classSetupFailed = True
                className = util.strclass(currentClass)
                errorName = 'setUpClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _get_previous_module(self, result):
        """Return the ``__module__`` of the result's previous test class.

        Returns None when no test class has been recorded on *result* yet.
        """
        previousModule = None
        previousClass = getattr(result, '_previousTestClass', None)
        if previousClass is not None:
            previousModule = previousClass.__module__
        return previousModule


    def _handleModuleFixture(self, test, result):
        """Switch module fixtures when *test* comes from a new module.

        Tears down the previous module, resets ``_moduleSetUpFailed`` and then
        calls the new module's ``setUpModule`` (if the module is in
        ``sys.modules`` and defines one). On failure the flag is set and the
        error is recorded on *result* (re-raised for a ``_DebugResult``).
        """
        previousModule = self._get_previous_module(result)
        currentModule = test.__class__.__module__
        if currentModule == previousModule:
            return

        self._handleModuleTearDown(result)

        result._moduleSetUpFailed = False
        try:
            module = sys.modules[currentModule]
        except KeyError:
            return
        setUpModule = getattr(module, 'setUpModule', None)
        if setUpModule is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                setUpModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                result._moduleSetUpFailed = True
                errorName = 'setUpModule (%s)' % currentModule
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _addClassOrModuleLevelException(self, result, exception, errorName):
        """Record a fixture failure on *result* under the name *errorName*.

        A ``case.SkipTest`` is reported through ``result.addSkip`` when the
        result has one; anything else goes to ``result.addError``. Must be
        called from inside an ``except`` block, because the error details are
        taken from ``sys.exc_info()``.
        """
        error = _ErrorHolder(errorName)
        addSkip = getattr(result, 'addSkip', None)
        if addSkip is not None and isinstance(exception, case.SkipTest):
            addSkip(error, str(exception))
        else:
            result.addError(error, sys.exc_info())

    def _handleModuleTearDown(self, result):
        """Call ``tearDownModule`` for the previously run test's module.

        Does nothing if no test has run yet, if that module's setup failed,
        or if the module is no longer in ``sys.modules``. Failures are
        recorded on *result* (re-raised for a ``_DebugResult``).
        """
        previousModule = self._get_previous_module(result)
        if previousModule is None:
            return
        # getattr for consistency with run()/_tearDownPreviousClass().
        if getattr(result, '_moduleSetUpFailed', False):
            return

        try:
            module = sys.modules[previousModule]
        except KeyError:
            return

        tearDownModule = getattr(module, 'tearDownModule', None)
        if tearDownModule is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                tearDownModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                errorName = 'tearDownModule (%s)' % previousModule
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _tearDownPreviousClass(self, test, result):
        """Call ``tearDownClass`` on the previous class once *test* leaves it.

        Skipped when *test* belongs to the same class, when that class's setup
        failed, when module setup failed, or when the class is marked
        ``__unittest_skip__``. *test* may be None to force teardown of the
        last class at the end of a run. Failures are recorded on *result*
        (re-raised for a ``_DebugResult``).
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(previousClass, '_classSetupFailed', False):
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(previousClass, "__unittest_skip__", False):
            return

        tearDownClass = getattr(previousClass, 'tearDownClass', None)
        if tearDownClass is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                tearDownClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                className = util.strclass(previousClass)
                errorName = 'tearDownClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')
250
251
252class _ErrorHolder(object):
253    """
254    Placeholder for a TestCase inside a result. As far as a TestResult
255    is concerned, this looks exactly like a unit test. Used to insert
256    arbitrary errors into a test suite run.
257    """
258    # Inspired by the ErrorHolder from Twisted:
259    # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py
260
261    # attribute used by TestResult._exc_info_to_string
262    failureException = None
263
264    def __init__(self, description):
265        self.description = description
266
267    def id(self):
268        return self.description
269
270    def shortDescription(self):
271        return None
272
273    def __repr__(self):
274        return "<ErrorHolder description=%r>" % (self.description,)
275
276    def __str__(self):
277        return self.id()
278
279    def run(self, result):
280        # could call result.addError(...) - but this test-like object
281        # shouldn't be run anyway
282        pass
283
284    def __call__(self, result):
285        return self.run(result)
286
287    def countTestCases(self):
288        return 0
289
290def _isnotsuite(test):
291    "A crude way to tell apart testcases and suites with duck-typing"
292    try:
293        iter(test)
294    except TypeError:
295        return True
296    return False
297
298
299class _DebugResult(object):
300    "Used by the TestSuite to hold previous class when running in debug."
301    _previousTestClass = None
302    _moduleSetUpFailed = False
303    shouldStop = False
304