1# Copyright (c) 2008-2011 testtools developers. See LICENSE for details.
2
3"""Test case related stuff."""
4
# Python 2: give classes declared in this module without explicit bases
# (such as ExpectedException) new-style semantics.
__metaclass__ = type
# Names exported by ``from testtools.testcase import *``.
__all__ = [
    'clone_test_with_new_id',
    'ExpectedException',
    'run_test_with',
    'skip',
    'skipIf',
    'skipUnless',
    'TestCase',
]
15
16import copy
17import itertools
18import re
19import sys
20import types
21import unittest
22
23from testtools import (
24    content,
25    try_import,
26    )
27from testtools.compat import advance_iterator
28from testtools.matchers import (
29    Annotate,
30    Equals,
31    )
32from testtools.monkey import patch
33from testtools.runtest import RunTest
34from testtools.testresult import TestResult
35
# ``functools.wraps`` where it can be imported, otherwise ``None``; skip()
# checks for ``None`` and falls back to a plain wrapper function.
wraps = try_import('functools.wraps', None)
37
class TestSkipped(Exception):
    """Raised within TestCase.run() when a test is skipped."""
# Resolve the skip exception in order of preference: unittest's SkipTest,
# then unittest2's, then the local class above.  Both lines must rebind
# ``TestSkipped`` itself so that each lookup falls back to the previous one.
TestSkipped = try_import('unittest2.case.SkipTest', TestSkipped)
TestSkipped = try_import('unittest.case.SkipTest', TestSkipped)
42
43
class _UnexpectedSuccess(Exception):
    """Signal that a test expected to fail passed instead.

    Private plumbing for this module's TestCase; it is not part of the
    public testtools API.
    """
# Prefer unittest's class, then unittest2's, then the local definition.  The
# inner lookup is evaluated first, so it supplies the fallback for the outer.
_UnexpectedSuccess = try_import(
    'unittest.case._UnexpectedSuccess',
    try_import('unittest2.case._UnexpectedSuccess', _UnexpectedSuccess))
54
class _ExpectedFailure(Exception):
    """Signal that an expected failure occurred.

    Private plumbing for this module's TestCase; it is not part of the
    public testtools API.
    """
# Prefer unittest's class, then unittest2's, then the local definition.  The
# inner lookup is evaluated first, so it supplies the fallback for the outer.
_ExpectedFailure = try_import(
    'unittest.case._ExpectedFailure',
    try_import('unittest2.case._ExpectedFailure', _ExpectedFailure))
65
66
def run_test_with(test_runner, **kwargs):
    """Decorate a test as using a specific ``RunTest``.

    e.g.::

      @run_test_with(CustomRunner, timeout=42)
      def test_foo(self):
          self.assertTrue(True)

    The returned decorator works by setting an attribute on the decorated
    function.  `TestCase.__init__` looks for this attribute when deciding on a
    ``RunTest`` factory.  If you wish to use multiple decorators on a test
    method, then you must either make this one the top-most decorator, or you
    must write your decorators so that they update the wrapping function with
    the attributes of the wrapped function.  The latter is recommended style
    anyway.  ``functools.wraps``, ``functools.update_wrapper`` and
    ``twisted.python.util.mergeFunctionMetadata`` can help you do this.

    :param test_runner: A ``RunTest`` factory that takes a test case and an
        optional list of exception handlers.  See ``RunTest``.
    :param kwargs: Keyword arguments to pass on as extra arguments to
        'test_runner'.
    :return: A decorator to be used for marking a test as needing a special
        runner.
    """
    def decorator(function):
        def make_runner(case, handlers=None):
            # Build the runner for ``case``, forwarding the keyword arguments
            # given to run_test_with alongside the exception handlers.
            return test_runner(case, handlers=handlers, **kwargs)
        # The function is returned unchanged apart from this attribute, which
        # TestCase.__init__ looks up on the test method.
        function._run_test_with = make_runner
        return function
    return decorator
100
101
class TestCase(unittest.TestCase):
    """Extensions to the basic TestCase.

    :ivar exception_handlers: Exceptions to catch from setUp, runTest and
        tearDown. This list is able to be modified at any time and consists of
        (exception_class, handler(case, result, exception_value)) pairs.
    :cvar run_tests_with: A factory to make the ``RunTest`` to run tests with.
        Defaults to ``RunTest``.  The factory is expected to take a test case
        and an optional list of exception handlers.
    """

    skipException = TestSkipped

    run_tests_with = RunTest

    def __init__(self, *args, **kwargs):
        """Construct a TestCase.

        :param testMethod: The name of the method to run.
        :keyword runTest: Optional class to use to execute the test. If not
            supplied ``RunTest`` is used. The instance to be used is created
            when run() is invoked, so will be fresh each time. Overrides
            ``TestCase.run_tests_with`` if given.
        """
        runTest = kwargs.pop('runTest', None)
        unittest.TestCase.__init__(self, *args, **kwargs)
        self._cleanups = []
        self._unique_id_gen = itertools.count(1)
        # Generators to ensure unique traceback ids.  Maps traceback label to
        # iterators.
        self._traceback_id_gens = {}
        self.__setup_called = False
        self.__teardown_called = False
        # __details is lazy-initialized so that a constructed-but-not-run
        # TestCase is safe to use with clone_test_with_new_id.
        self.__details = None
        test_method = self._get_test_method()
        if runTest is None:
            # A @run_test_with marker on the test method takes precedence
            # over the class-level default.
            runTest = getattr(
                test_method, '_run_test_with', self.run_tests_with)
        self.__RunTest = runTest
        self.__exception_handlers = []
        # NOTE(review): the catch-all Exception entry is listed last,
        # presumably because RunTest uses the first matching entry -- confirm
        # against RunTest before reordering.
        self.exception_handlers = [
            (self.skipException, self._report_skip),
            (self.failureException, self._report_failure),
            (_ExpectedFailure, self._report_expected_failure),
            (_UnexpectedSuccess, self._report_unexpected_success),
            (Exception, self._report_error),
            ]
        if sys.version_info < (2, 6):
            # Catch old-style string exceptions with None as the instance
            self.exception_handlers.append((type(None), self._report_error))

    def __eq__(self, other):
        """Return whether ``other`` is the same test with the same state.

        When ``unittest.TestCase`` defines ``__eq__`` its check must pass
        first; after that every instance attribute must compare equal.
        """
        eq = getattr(unittest.TestCase, '__eq__', None)
        if eq is not None:
            base_equal = eq(self, other)
            if base_equal is NotImplemented:
                # The base class declines to compare (e.g. ``other`` is of a
                # different type).  Propagate that instead of testing
                # NotImplemented for truth: it is always truthy (and doing so
                # is deprecated), which previously fell through to
                # ``other.__dict__`` and crashed for objects without one.
                return NotImplemented
            if not base_equal:
                return False
        return self.__dict__ == other.__dict__

    # Python 3 sets __hash__ to None on any class that defines __eq__ without
    # also defining __hash__, which would make test cases unusable in sets
    # and as dict keys.  Keep the base class's hash wherever it provides one.
    if getattr(unittest.TestCase, '__hash__', None) is not None:
        __hash__ = unittest.TestCase.__hash__

    def __repr__(self):
        # We add id to the repr because it makes testing testtools easier.
        return "<%s id=0x%0x>" % (self.id(), id(self))

    def addDetail(self, name, content_object):
        """Add a detail to be reported with this test's outcome.

        For more details see pydoc testtools.TestResult.

        :param name: The name to give this detail.
        :param content_object: The content object for this detail. See
            testtools.content for more detail.
        """
        if self.__details is None:
            self.__details = {}
        self.__details[name] = content_object

    def getDetails(self):
        """Get the details dict that will be reported with this test's outcome.

        For more details see pydoc testtools.TestResult.
        """
        if self.__details is None:
            self.__details = {}
        return self.__details

    def patch(self, obj, attribute, value):
        """Monkey-patch 'obj.attribute' to 'value' while the test is running.

        If 'obj' has no attribute, then the monkey-patch will still go ahead,
        and the attribute will be deleted instead of restored to its original
        value.

        :param obj: The object to patch. Can be anything.
        :param attribute: The attribute on 'obj' to patch.
        :param value: The value to set 'obj.attribute' to.
        """
        self.addCleanup(patch(obj, attribute, value))

    def shortDescription(self):
        return self.id()

    def skipTest(self, reason):
        """Cause this test to be skipped.

        This raises self.skipException(reason). skipException is raised
        to permit a skip to be triggered at any point (during setUp or the
        testMethod itself). The run() method catches skipException and
        translates that into a call to the result objects addSkip method.

        :param reason: The reason why the test is being skipped. This must
            support being cast into a unicode string for reporting.
        """
        raise self.skipException(reason)

    # skipTest is how python2.7 spells this. Sometime in the future
    # This should be given a deprecation decorator - RBC 20100611.
    skip = skipTest

    def _formatTypes(self, classOrIterable):
        """Format a class or a bunch of classes for display in an error."""
        className = getattr(classOrIterable, '__name__', None)
        if className is None:
            className = ', '.join(klass.__name__ for klass in classOrIterable)
        return className

    def addCleanup(self, function, *arguments, **keywordArguments):
        """Add a cleanup function to be called after tearDown.

        Functions added with addCleanup will be called in reverse order of
        adding after tearDown, or after setUp if setUp raises an exception.

        If a function added with addCleanup raises an exception, the error
        will be recorded as a test error, and the next cleanup will then be
        run.

        Cleanup functions are always called before a test finishes running,
        even if setUp is aborted by an exception.
        """
        self._cleanups.append((function, arguments, keywordArguments))

    def addOnException(self, handler):
        """Add a handler to be called when an exception occurs in test code.

        This handler cannot affect what result methods are called, and is
        called before any outcome is called on the result object. An example
        use for it is to add some diagnostic state to the test details dict
        which is expensive to calculate and not interesting for reporting in
        the success case.

        Handlers are called before the outcome (such as addFailure) that
        the exception has caused.

        Handlers are called in first-added, first-called order, and if they
        raise an exception, that will propagate out of the test running
        machinery, halting test processing. As a result, do not call code that
        may unreasonably fail.
        """
        self.__exception_handlers.append(handler)

    def _add_reason(self, reason):
        """Record ``reason`` as a UTF-8 text/plain detail named 'reason'."""
        self.addDetail('reason', content.Content(
            content.ContentType('text', 'plain'),
            lambda: [reason.encode('utf8')]))

    def assertEqual(self, expected, observed, message=''):
        """Assert that 'expected' is equal to 'observed'.

        :param expected: The expected value.
        :param observed: The observed value.
        :param message: An optional message to include in the error.
        """
        matcher = Equals(expected)
        if message:
            matcher = Annotate(message, matcher)
        self.assertThat(observed, matcher)

    failUnlessEqual = assertEquals = assertEqual

    def assertIn(self, needle, haystack):
        """Assert that needle is in haystack."""
        self.assertTrue(
            needle in haystack, '%r not in %r' % (needle, haystack))

    def assertIs(self, expected, observed, message=''):
        """Assert that 'expected' is 'observed'.

        :param expected: The expected value.
        :param observed: The observed value.
        :param message: An optional message describing the error.
        """
        if message:
            message = ': ' + message
        self.assertTrue(
            expected is observed,
            '%r is not %r%s' % (expected, observed, message))

    def assertIsNot(self, expected, observed, message=''):
        """Assert that 'expected' is not 'observed'."""
        if message:
            message = ': ' + message
        self.assertTrue(
            expected is not observed,
            '%r is %r%s' % (expected, observed, message))

    def assertNotIn(self, needle, haystack):
        """Assert that needle is not in haystack."""
        self.assertTrue(
            needle not in haystack, '%r in %r' % (needle, haystack))

    def assertIsInstance(self, obj, klass, msg=None):
        """Assert that ``obj`` is an instance of ``klass``.

        :param klass: A class or a tuple of classes, as for ``isinstance``.
        :param msg: Optional message; a default naming the class(es) is
            generated when omitted.
        """
        if msg is None:
            msg = '%r is not an instance of %s' % (
                obj, self._formatTypes(klass))
        self.assertTrue(isinstance(obj, klass), msg)

    def assertRaises(self, excClass, callableObj, *args, **kwargs):
        """Fail unless an exception of class excClass is thrown
           by callableObj when invoked with arguments args and keyword
           arguments kwargs. If a different type of exception is
           thrown, it will not be caught, and the test case will be
           deemed to have suffered an error, exactly as for an
           unexpected exception.

           :return: The caught exception instance.
        """
        try:
            ret = callableObj(*args, **kwargs)
        except excClass:
            return sys.exc_info()[1]
        else:
            excName = self._formatTypes(excClass)
            self.fail("%s not raised, %r returned instead." % (excName, ret))
    failUnlessRaises = assertRaises

    def assertThat(self, matchee, matcher):
        """Assert that matchee is matched by matcher.

        Any details supplied by the mismatch are added to this test's
        details, with a ``-N`` suffix when a name is already taken.

        :param matchee: An object to match with matcher.
        :param matcher: An object meeting the testtools.Matcher protocol.
        :raises self.failureException: When matcher does not match thing.
        """
        mismatch = matcher.match(matchee)
        if not mismatch:
            return
        existing_details = self.getDetails()
        # ``content_object`` rather than ``content`` so the loop does not
        # shadow the ``content`` module inside this method.
        for (name, content_object) in mismatch.get_details().items():
            full_name = name
            suffix = 1
            while full_name in existing_details:
                full_name = "%s-%d" % (name, suffix)
                suffix += 1
            self.addDetail(full_name, content_object)
        self.fail('Match failed. Matchee: "%s"\nMatcher: %s\nDifference: %s\n'
            % (matchee, matcher, mismatch.describe()))

    def defaultTestResult(self):
        return TestResult()

    def expectFailure(self, reason, predicate, *args, **kwargs):
        """Check that a test fails in a particular way.

        If the test fails in the expected way, an expected failure is
        reported. If it succeeds an unexpected success is reported.

        The expected use of expectFailure is as a barrier at the point in a
        test where the test would fail. For example:
        >>> def test_foo(self):
        >>>    self.expectFailure("1 should be 0", self.assertNotEqual, 1, 0)
        >>>    self.assertEqual(1, 0)

        If in the future 1 were to equal 0, the expectFailure call can simply
        be removed. This separation preserves the original intent of the test
        while it is in the expectFailure mode.
        """
        self._add_reason(reason)
        try:
            predicate(*args, **kwargs)
        except self.failureException:
            # GZ 2010-08-12: Don't know how to avoid exc_info cycle as the new
            #                unittest _ExpectedFailure wants old traceback
            exc_info = sys.exc_info()
            try:
                self._report_traceback(exc_info)
                raise _ExpectedFailure(exc_info)
            finally:
                del exc_info
        else:
            raise _UnexpectedSuccess(reason)

    def getUniqueInteger(self):
        """Get an integer unique to this test.

        Returns an integer that is guaranteed to be unique to this instance.
        Use this when you need an arbitrary integer in your test, or as a
        helper for custom anonymous factory methods.
        """
        return advance_iterator(self._unique_id_gen)

    def getUniqueString(self, prefix=None):
        """Get a string unique to this test.

        Returns a string that is guaranteed to be unique to this instance. Use
        this when you need an arbitrary string in your test, or as a helper
        for custom anonymous factory methods.

        :param prefix: The prefix of the string. If not provided, defaults
            to the id of the tests.
        :return: A bytestring of '<prefix>-<unique_int>'.
        """
        if prefix is None:
            prefix = self.id()
        return '%s-%d' % (prefix, self.getUniqueInteger())

    def onException(self, exc_info, tb_label='traceback'):
        """Called when an exception propagates from test code.

        Records the traceback as a detail (except for skip and
        expected/unexpected-outcome signals), then calls every handler
        registered with addOnException.

        :seealso addOnException:
        """
        if exc_info[0] not in [
            TestSkipped, _UnexpectedSuccess, _ExpectedFailure]:
            self._report_traceback(exc_info, tb_label=tb_label)
        for handler in self.__exception_handlers:
            handler(exc_info)

    # The _report_* handlers are staticmethods taking the case explicitly:
    # they are stored in exception_handlers and invoked as
    # handler(case, result, exception_value).
    @staticmethod
    def _report_error(self, result, err):
        result.addError(self, details=self.getDetails())

    @staticmethod
    def _report_expected_failure(self, result, err):
        result.addExpectedFailure(self, details=self.getDetails())

    @staticmethod
    def _report_failure(self, result, err):
        result.addFailure(self, details=self.getDetails())

    @staticmethod
    def _report_skip(self, result, err):
        if err.args:
            reason = err.args[0]
        else:
            reason = "no reason given."
        self._add_reason(reason)
        result.addSkip(self, details=self.getDetails())

    def _report_traceback(self, exc_info, tb_label='traceback'):
        """Add ``exc_info`` as a traceback detail.

        The first traceback for a label uses the bare label; later ones get
        ``-1``, ``-2``... suffixes.
        """
        id_gen = self._traceback_id_gens.setdefault(
            tb_label, itertools.count(0))
        tb_id = advance_iterator(id_gen)
        if tb_id:
            tb_label = '%s-%d' % (tb_label, tb_id)
        self.addDetail(tb_label, content.TracebackContent(exc_info, self))

    @staticmethod
    def _report_unexpected_success(self, result, err):
        result.addUnexpectedSuccess(self, details=self.getDetails())

    def run(self, result=None):
        return self.__RunTest(self, self.exception_handlers).run(result)

    def _run_setup(self, result):
        """Run the setUp function for this test.

        :param result: A testtools.TestResult to report activity to.
        :raises ValueError: If the base class setUp is not called, a
            ValueError is raised.
        """
        ret = self.setUp()
        if not self.__setup_called:
            raise ValueError(
                "TestCase.setUp was not called. Have you upcalled all the "
                "way up the hierarchy from your setUp? e.g. Call "
                "super(%s, self).setUp() from your setUp()."
                % self.__class__.__name__)
        return ret

    def _run_teardown(self, result):
        """Run the tearDown function for this test.

        :param result: A testtools.TestResult to report activity to.
        :raises ValueError: If the base class tearDown is not called, a
            ValueError is raised.
        """
        ret = self.tearDown()
        if not self.__teardown_called:
            raise ValueError(
                "TestCase.tearDown was not called. Have you upcalled all the "
                "way up the hierarchy from your tearDown? e.g. Call "
                "super(%s, self).tearDown() from your tearDown()."
                % self.__class__.__name__)
        return ret

    def _get_test_method(self):
        """Return the bound test method this case was constructed to run."""
        absent_attr = object()
        # Python 2.5+
        method_name = getattr(self, '_testMethodName', absent_attr)
        if method_name is absent_attr:
            # Python 2.4
            method_name = getattr(self, '_TestCase__testMethodName')
        return getattr(self, method_name)

    def _run_test_method(self, result):
        """Run the test method for this test.

        :param result: A testtools.TestResult to report activity to.
        :return: None.
        """
        return self._get_test_method()()

    def useFixture(self, fixture):
        """Use fixture in a test case.

        The fixture will be setUp, and self.addCleanup(fixture.cleanUp) called.
        The fixture's details are also merged into this test's details during
        cleanup.

        :param fixture: The fixture to use.
        :return: The fixture, after setting it up and scheduling a cleanup for
           it.
        """
        fixture.setUp()
        self.addCleanup(fixture.cleanUp)
        self.addCleanup(self._gather_details, fixture.getDetails)
        return fixture

    def _gather_details(self, getDetails):
        """Merge the details from getDetails() into self.getDetails().

        Names already present on this test get a ``-N`` suffix.  Each
        detail's bytes are read immediately, so the copies do not depend on
        the source remaining readable afterwards.
        """
        def snapshot(content_object):
            # Create the replay callback in its own scope so that it captures
            # *this* detail's bytes.  A lambda written directly in the loop
            # below would close over the loop variable and return the bytes
            # of the last detail for every gathered detail.
            content_bytes = list(content_object.iter_bytes())
            return content.Content(
                content_object.content_type, lambda: content_bytes)

        details = getDetails()
        my_details = self.getDetails()
        for name, content_object in details.items():
            new_name = name
            disambiguator = itertools.count(1)
            while new_name in my_details:
                new_name = '%s-%d' % (name, advance_iterator(disambiguator))
            self.addDetail(new_name, snapshot(content_object))

    def setUp(self):
        unittest.TestCase.setUp(self)
        self.__setup_called = True

    def tearDown(self):
        unittest.TestCase.tearDown(self)
        self.__teardown_called = True
545
546
class PlaceHolder(object):
    """A stand-in test that always succeeds when run.

    ``PlaceHolder`` provides much of the ``TestCase`` interface (``id``,
    ``shortDescription``, ``countTestCases``, ``run``...), which makes it
    convenient to hand to a ``TestResult`` in place of a real test.
    """

    def __init__(self, test_id, short_description=None):
        """Construct a `PlaceHolder`.

        :param test_id: The id reported for this placeholder test.
        :param short_description: Optional text for shortDescription(); the
            id is used when it is omitted.
        """
        self._test_id = test_id
        self._short_description = short_description

    def __call__(self, result=None):
        return self.run(result=result)

    def __repr__(self):
        shown = (self._test_id,)
        if self._short_description is not None:
            shown = shown + (self._short_description,)
        klass = self.__class__
        return "<%s.%s(%s)>" % (
            klass.__module__,
            klass.__name__,
            ", ".join([repr(field) for field in shown]))

    def __str__(self):
        return self.id()

    def countTestCases(self):
        return 1

    def debug(self):
        """Do nothing: a placeholder has no test code to debug."""

    def id(self):
        return self._test_id

    def run(self, result=None):
        """Report a success for this placeholder to ``result``.

        A fresh ``TestResult`` is used when no result is supplied.
        """
        if result is None:
            result = TestResult()
        result.startTest(self)
        result.addSuccess(self)
        result.stopTest(self)

    def shortDescription(self):
        if self._short_description is not None:
            return self._short_description
        return self.id()
600
601
class ErrorHolder(PlaceHolder):
    """A placeholder test that reports an error whenever it is run."""

    failureException = None

    def __init__(self, test_id, error, short_description=None):
        """Construct an `ErrorHolder`.

        :param test_id: The id of the test.
        :param error: The exc info tuple that will be used as the test's error.
        :param short_description: An optional short description of the test.
        """
        super(ErrorHolder, self).__init__(
            test_id, short_description=short_description)
        self._error = error

    def __repr__(self):
        shown = (self._test_id, self._error)
        if self._short_description is not None:
            shown = shown + (self._short_description,)
        klass = self.__class__
        return "<%s.%s(%s)>" % (
            klass.__module__,
            klass.__name__,
            ", ".join([repr(field) for field in shown]))

    def run(self, result=None):
        """Report the stored error for this test to ``result``.

        A fresh ``TestResult`` is used when no result is supplied.
        """
        if result is None:
            result = TestResult()
        result.startTest(self)
        result.addError(self, self._error)
        result.stopTest(self)
633
634
# Python 2.4 did not know how to copy functions; teach the copy module to
# treat them as immutable.  ``_copy_dispatch`` and ``_copy_immutable`` are
# private to the copy module and may not exist on every interpreter, so only
# touch them when present instead of failing at import time.
_copy_dispatch = getattr(copy, '_copy_dispatch', None)
if (_copy_dispatch is not None
        and types.FunctionType not in _copy_dispatch
        and hasattr(copy, '_copy_immutable')):
    _copy_dispatch[types.FunctionType] = copy._copy_immutable
del _copy_dispatch
638
639
def clone_test_with_new_id(test, new_id):
    """Return a copy of ``test`` whose ``id()`` reports ``new_id``.

    The copy is shallow (``copy.copy``), so mutable attributes are shared
    with the original.  That is why this is only expected to be used on
    tests that have been constructed but not yet executed.
    """
    clone = copy.copy(test)

    def _fixed_id():
        return new_id

    clone.id = _fixed_id
    return clone
649
650
def skip(reason):
    """Decorate a test method so that running it skips the test.

    Syntactic sugar so that tests need not change in order to migrate to
    Python 2.7, which provides the @unittest.skip decorator.

    :param reason: Why the test is skipped; passed to the skip exception.
    :return: A decorator that replaces the test method with a skipping one.
    """
    def decorator(test_item):
        if wraps is None:
            # No functools.wraps available: the replacement receives the
            # test case instance and asks it to skip itself.
            def skip_wrapper(case):
                case.skip(reason)
            return skip_wrapper

        def skip_wrapper(*args, **kwargs):
            raise TestCase.skipException(reason)
        return wraps(test_item)(skip_wrapper)
    return decorator
668
669
def skipIf(condition, reason):
    """Skip the decorated test when ``condition`` is true.

    :return: ``skip(reason)`` when ``condition`` is truthy, otherwise a
        decorator that returns the decorated object unchanged.
    """
    if not condition:
        def _unchanged(obj):
            return obj
        return _unchanged
    return skip(reason)
677
678
def skipUnless(condition, reason):
    """Skip the decorated test unless ``condition`` is true.

    :return: A decorator that returns the decorated object unchanged when
        ``condition`` is truthy, otherwise ``skip(reason)``.
    """
    if condition:
        def _unchanged(obj):
            return obj
        return _unchanged
    return skip(reason)
686
687
class ExpectedException:
    """A context manager asserting that its block raises a given exception.

    In Python 2.5 or later::

      def test_foo(self):
          with ExpectedException(ValueError, 'fo.*'):
              raise ValueError('foo')

    will pass.  An exception whose type is not the specified type is left to
    propagate.  An AssertionError is raised when the exception's 'str()' does
    not match the regular expression (checked with ``re.match``, so it is
    anchored at the start only), or when the block raises nothing.
    """

    def __init__(self, exc_type, value_re):
        """Construct an `ExpectedException`.

        :param exc_type: The type of exception to expect.
        :param value_re: A regular expression to match against the
            'str()' of the raised exception.
        """
        self.exc_type = exc_type
        self.value_re = value_re

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            # The block finished without raising anything.
            raise AssertionError('%s not raised.' % self.exc_type.__name__)
        if exc_type != self.exc_type:
            # Not the expected type: returning False lets it propagate.
            return False
        text = str(exc_value)
        if not re.match(self.value_re, text):
            raise AssertionError('"%s" does not match "%s".' %
                                 (text, self.value_re))
        # Returning True suppresses the expected exception.
        return True
725