1"""Test case implementation"""
2
3import sys
4import difflib
5import pprint
6import re
7import unittest
8import warnings
9
10from unittest2 import result
11from unittest2.util import (
12    safe_repr, safe_str, strclass,
13    unorderable_list_difference
14)
15
16from unittest2.compatibility import wraps
17
# Module marker flag. NOTE(review): presumably this follows the stdlib
# unittest convention where the result machinery hides frames from modules
# defining `__unittest` in failure tracebacks; the check lives outside this
# file -- confirm.
__unittest = True


# Replacement text appended by _truncateMessage when a diff is longer than
# maxDiff; the %s placeholder receives the diff's length in characters.
DIFF_OMITTED = ('\nDiff is %s characters long. '
                'Set self.maxDiff to None to see it.')
23
24
class SkipTest(Exception):
    """
    Exception signalling that the current test should be skipped.

    Raising it from setUp() or from a test method makes TestCase report a
    skip, using the exception's text as the reason, instead of an error.
    TestCase.skipTest() and the skipping decorators are usually the more
    convenient way to trigger it.
    """
32
33
34class _ExpectedFailure(Exception):
35    """
36    Raise this when a test is expected to fail.
37
38    This is an implementation detail.
39    """
40
41    def __init__(self, exc_info, bugnumber=None):
42        # can't use super because Python 2.4 exceptions are old style
43        Exception.__init__(self)
44        self.exc_info = exc_info
45        self.bugnumber = bugnumber
46
47
48class _UnexpectedSuccess(Exception):
49    """
50    The test was supposed to fail, but it didn't!
51    """
52
53    def __init__(self, exc_info, bugnumber=None):
54        # can't use super because Python 2.4 exceptions are old style
55        Exception.__init__(self)
56        self.exc_info = exc_info
57        self.bugnumber = bugnumber
58
59
60def _id(obj):
61    return obj
62
63
def skip(reason):
    """
    Decorator that skips the decorated test unconditionally.

    A TestCase subclass is marked in place. Any other object is replaced by a
    functools-style wrapper that raises SkipTest(reason) when called. Either
    way the returned object carries ``__unittest_skip__ = True`` and
    ``__unittest_skip_why__ = reason``, which TestCase.run() inspects.
    """
    def decorator(test_item):
        is_test_class = (isinstance(test_item, type) and
                         issubclass(test_item, TestCase))
        if not is_test_class:
            original = test_item

            @wraps(original)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)

            test_item = skip_wrapper

        test_item.__unittest_skip__ = True
        test_item.__unittest_skip_why__ = reason
        return test_item

    return decorator
84
85
def skipIf(condition, reason):
    """
    Skip the decorated test when *condition* is truthy; otherwise return a
    decorator that leaves the test unchanged.
    """
    return skip(reason) if condition else _id
93
94
def skipUnless(condition, reason):
    """
    Skip the decorated test when *condition* is falsy; otherwise return a
    decorator that leaves the test unchanged.
    """
    return _id if condition else skip(reason)
102
103
def expectedFailure(bugnumber=None):
    """
    Mark a test as expected to fail.

    Usable bare (``@expectedFailure``) or with an optional bug reference
    (``@expectedFailure(bugnumber)``). The wrapped test raises
    _ExpectedFailure when the original raises any Exception, and
    _UnexpectedSuccess when it returns normally. Both carry the bug
    reference, which is None for bare use.
    """
    def _wrap(func, bug):
        # Single implementation shared by both decorator spellings.
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                func(*args, **kwargs)
            except Exception:
                raise _ExpectedFailure(sys.exc_info(), bug)
            raise _UnexpectedSuccess(sys.exc_info(), bug)
        return wrapper

    if callable(bugnumber):
        # Bare use: the "bugnumber" argument is really the test function.
        return _wrap(bugnumber, None)

    def expectedFailure_impl(func):
        return _wrap(func, bugnumber)
    return expectedFailure_impl
125
126
127class _AssertRaisesContext(object):
128    """A context manager used to implement TestCase.assertRaises* methods."""
129
130    def __init__(self, expected, test_case, expected_regexp=None):
131        self.expected = expected
132        self.failureException = test_case.failureException
133        self.expected_regexp = expected_regexp
134
135    def __enter__(self):
136        return self
137
138    def __exit__(self, exc_type, exc_value, tb):
139        if exc_type is None:
140            try:
141                exc_name = self.expected.__name__
142            except AttributeError:
143                exc_name = str(self.expected)
144            raise self.failureException(
145                "%s not raised" % (exc_name,))
146        if not issubclass(exc_type, self.expected):
147            # let unexpected exceptions pass through
148            return False
149        self.exception = exc_value  # store for later retrieval
150        if self.expected_regexp is None:
151            return True
152
153        expected_regexp = self.expected_regexp
154        if isinstance(expected_regexp, str):
155            expected_regexp = re.compile(expected_regexp)
156        if not expected_regexp.search(str(exc_value)):
157            raise self.failureException(
158                '"%s" does not match "%s"' %
159                (expected_regexp.pattern, str(exc_value)))
160        return True
161
162
163class _TypeEqualityDict(object):
164
165    def __init__(self, testcase):
166        self.testcase = testcase
167        self._store = {}
168
169    def __setitem__(self, key, value):
170        self._store[key] = value
171
172    def __getitem__(self, key):
173        value = self._store[key]
174        if isinstance(value, str):
175            return getattr(self.testcase, value)
176        return value
177
178    def get(self, key, default=None):
179        if key in self._store:
180            return self[key]
181        return default
182
183
184class TestCase(unittest.TestCase):
185    """A class whose instances are single test cases.
186
187    By default, the test code itself should be placed in a method named
188    'runTest'.
189
190    If the fixture may be used for many test cases, create as
191    many test methods as are needed. When instantiating such a TestCase
192    subclass, specify in the constructor arguments the name of the test method
193    that the instance is to execute.
194
195    Test authors should subclass TestCase for their own tests. Construction
196    and deconstruction of the test's environment ('fixture') can be
197    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
198
199    If it is necessary to override the __init__ method, the base class
200    __init__ method must always be called. It is important that subclasses
201    should not change the signature of their __init__ method, since instances
202    of the classes are instantiated automatically by parts of the framework
203    in order to be run.
204    """
205
206    # This attribute determines which exception will be raised when
207    # the instance's assertion methods fail; test methods raising this
208    # exception will be deemed to have 'failed' rather than 'errored'
209
210    failureException = AssertionError
211
212    # This attribute sets the maximum length of a diff in failure messages
213    # by assert methods using difflib. It is looked up as an instance attribute
214    # so can be configured by individual tests if required.
215
216    maxDiff = 80 * 8
217
218    # This attribute determines whether long messages (including repr of
219    # objects used in assert methods) will be printed on failure in *addition*
220    # to any explicit message passed.
221
222    longMessage = True
223
224    # Attribute used by TestSuite for classSetUp
225
226    _classSetupFailed = False
227
228    def __init__(self, methodName='runTest'):
229        """Create an instance of the class that will use the named test
230           method when executed. Raises a ValueError if the instance does
231           not have a method with the specified name.
232        """
233        self._testMethodName = methodName
234        self._resultForDoCleanups = None
235        try:
236            testMethod = getattr(self, methodName)
237        except AttributeError:
238            raise ValueError("no such test method in %s: %s" %
239                             (self.__class__, methodName))
240        self._testMethodDoc = testMethod.__doc__
241        self._cleanups = []
242
243        # Map types to custom assertEqual functions that will compare
244        # instances of said type in more detail to generate a more useful
245        # error message.
246        self._type_equality_funcs = _TypeEqualityDict(self)
247        self.addTypeEqualityFunc(dict, 'assertDictEqual')
248        self.addTypeEqualityFunc(list, 'assertListEqual')
249        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
250        self.addTypeEqualityFunc(set, 'assertSetEqual')
251        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
252        self.addTypeEqualityFunc(str, 'assertMultiLineEqual')
253
254    def addTypeEqualityFunc(self, typeobj, function):
255        """Add a type specific assertEqual style function to compare a type.
256
257        This method is for use by TestCase subclasses that need to register
258        their own type equality functions to provide nicer error messages.
259
260        Args:
261            typeobj: The data type to call this function on when both values
262                    are of the same type in assertEqual().
263            function: The callable taking two arguments and an optional
264                    msg= argument that raises self.failureException with a
265                    useful error message when the two arguments are not equal.
266        """
267        self._type_equality_funcs[typeobj] = function
268
269    def addCleanup(self, function, *args, **kwargs):
270        """Add a function, with arguments, to be called when the test is
271        completed. Functions added are called on a LIFO basis and are
272        called after tearDown on test failure or success.
273
274        Cleanup items are called even if setUp fails (unlike tearDown)."""
275        self._cleanups.append((function, args, kwargs))
276
277    def setUp(self):
278        "Hook method for setting up the test fixture before exercising it."
279
280    @classmethod
281    def setUpClass(cls):
282        "Hook method for setting up class fixture before running tests in the class."
283
284    @classmethod
285    def tearDownClass(cls):
286        "Hook method for deconstructing the class fixture after running all tests in the class."
287
288    def tearDown(self):
289        "Hook method for deconstructing the test fixture after testing it."
290
291    def countTestCases(self):
292        return 1
293
294    def defaultTestResult(self):
295        return result.TestResult()
296
297    def shortDescription(self):
298        """Returns a one-line description of the test, or None if no
299        description has been provided.
300
301        The default implementation of this method returns the first line of
302        the specified test method's docstring.
303        """
304        doc = self._testMethodDoc
305        return doc and doc.split("\n")[0].strip() or None
306
307    def id(self):
308        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
309
310    def __eq__(self, other):
311        if not isinstance(self, type(other)):
312            return NotImplemented
313
314        return self._testMethodName == other._testMethodName
315
316    def __ne__(self, other):
317        return not self == other
318
319    def __hash__(self):
320        return hash((type(self), self._testMethodName))
321
322    def __str__(self):
323        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
324
325    def __repr__(self):
326        return "<%s testMethod=%s>" % \
327               (strclass(self.__class__), self._testMethodName)
328
329    def _addSkip(self, result, reason):
330        addSkip = getattr(result, 'addSkip', None)
331        if addSkip is not None:
332            addSkip(self, reason)
333        else:
334            warnings.warn(
335                "Use of a TestResult without an addSkip method is deprecated",
336                DeprecationWarning,
337                2)
338            result.addSuccess(self)
339
340    def run(self, result=None):
341        orig_result = result
342        if result is None:
343            result = self.defaultTestResult()
344            startTestRun = getattr(result, 'startTestRun', None)
345            if startTestRun is not None:
346                startTestRun()
347
348        self._resultForDoCleanups = result
349        result.startTest(self)
350
351        testMethod = getattr(self, self._testMethodName)
352
353        if (getattr(self.__class__, "__unittest_skip__", False) or
354                getattr(testMethod, "__unittest_skip__", False)):
355            # If the class or method was skipped.
356            try:
357                skip_why = (
358                    getattr(
359                        self.__class__,
360                        '__unittest_skip_why__',
361                        '') or getattr(
362                        testMethod,
363                        '__unittest_skip_why__',
364                        ''))
365                self._addSkip(result, skip_why)
366            finally:
367                result.stopTest(self)
368            return
369        try:
370            success = False
371            try:
372                self.setUp()
373            except SkipTest as e:
374                self._addSkip(result, str(e))
375            except Exception:
376                result.addError(self, sys.exc_info())
377            else:
378                success = self.runMethod(testMethod, result)
379
380                try:
381                    self.tearDown()
382                except Exception:
383                    result.addCleanupError(self, sys.exc_info())
384                    success = False
385
386                self.dumpSessionInfo()
387
388            cleanUpSuccess = self.doCleanups()
389            success = success and cleanUpSuccess
390            if success:
391                result.addSuccess(self)
392        finally:
393            result.stopTest(self)
394            if orig_result is None:
395                stopTestRun = getattr(result, 'stopTestRun', None)
396                if stopTestRun is not None:
397                    stopTestRun()
398
399    def runMethod(self, testMethod, result):
400        """Runs the test method and catches any exception that might be thrown.
401
402        This is factored out of TestCase.run() to ensure that any exception
403        thrown during the test goes out of scope before tearDown.  Otherwise, an
404        exception could hold references to Python objects that are bound to
405        SB objects and prevent them from being deleted in time.
406        """
407        try:
408            testMethod()
409        except self.failureException:
410            result.addFailure(self, sys.exc_info())
411        except _ExpectedFailure as e:
412            addExpectedFailure = getattr(result, 'addExpectedFailure', None)
413            if addExpectedFailure is not None:
414                addExpectedFailure(self, e.exc_info, e.bugnumber)
415            else:
416                warnings.warn(
417                    "Use of a TestResult without an addExpectedFailure method is deprecated",
418                    DeprecationWarning)
419                result.addSuccess(self)
420        except _UnexpectedSuccess as x:
421            addUnexpectedSuccess = getattr(
422                result, 'addUnexpectedSuccess', None)
423            if addUnexpectedSuccess is not None:
424                addUnexpectedSuccess(self, x.bugnumber)
425            else:
426                warnings.warn(
427                    "Use of a TestResult without an addUnexpectedSuccess method is deprecated",
428                    DeprecationWarning)
429                result.addFailure(self, sys.exc_info())
430        except SkipTest as e:
431            self._addSkip(result, str(e))
432        except Exception:
433            result.addError(self, sys.exc_info())
434        else:
435            return True
436        return False
437
438    def doCleanups(self):
439        """Execute all cleanup functions. Normally called for you after
440        tearDown."""
441        result = self._resultForDoCleanups
442        ok = True
443        while self._cleanups:
444            function, args, kwargs = self._cleanups.pop(-1)
445            try:
446                function(*args, **kwargs)
447            except Exception:
448                ok = False
449                result.addError(self, sys.exc_info())
450        return ok
451
452    def __call__(self, *args, **kwds):
453        return self.run(*args, **kwds)
454
455    def debug(self):
456        """Run the test without collecting errors in a TestResult"""
457        self.setUp()
458        getattr(self, self._testMethodName)()
459        self.tearDown()
460        while self._cleanups:
461            function, args, kwargs = self._cleanups.pop(-1)
462            function(*args, **kwargs)
463
464    def skipTest(self, reason):
465        """Skip this test."""
466        raise SkipTest(reason)
467
468    def fail(self, msg=None):
469        """Fail immediately, with the given message."""
470        raise self.failureException(msg)
471
472    def assertFalse(self, expr, msg=None):
473        "Fail the test if the expression is true."
474        if expr:
475            msg = self._formatMessage(msg, "%s is not False" % safe_repr(expr))
476            raise self.failureException(msg)
477
478    def assertTrue(self, expr, msg=None):
479        """Fail the test unless the expression is true."""
480        if not expr:
481            msg = self._formatMessage(msg, "%s is not True" % safe_repr(expr))
482            raise self.failureException(msg)
483
484    def _formatMessage(self, msg, standardMsg):
485        """Honour the longMessage attribute when generating failure messages.
486        If longMessage is False this means:
487        * Use only an explicit message if it is provided
488        * Otherwise use the standard message for the assert
489
490        If longMessage is True:
491        * Use the standard message
492        * If an explicit message is provided, plus ' : ' and the explicit message
493        """
494        if not self.longMessage:
495            return msg or standardMsg
496        if msg is None:
497            return standardMsg
498        try:
499            return '%s : %s' % (standardMsg, msg)
500        except UnicodeDecodeError:
501            return '%s : %s' % (safe_str(standardMsg), safe_str(msg))
502
503    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
504        """Fail unless an exception of class excClass is thrown
505           by callableObj when invoked with arguments args and keyword
506           arguments kwargs. If a different type of exception is
507           thrown, it will not be caught, and the test case will be
508           deemed to have suffered an error, exactly as for an
509           unexpected exception.
510
511           If called with callableObj omitted or None, will return a
512           context object used like this::
513
514                with self.assertRaises(SomeException):
515                    do_something()
516
517           The context manager keeps a reference to the exception as
518           the 'exception' attribute. This allows you to inspect the
519           exception after the assertion::
520
521               with self.assertRaises(SomeException) as cm:
522                   do_something()
523               the_exception = cm.exception
524               self.assertEqual(the_exception.error_code, 3)
525        """
526        if callableObj is None:
527            return _AssertRaisesContext(excClass, self)
528        try:
529            callableObj(*args, **kwargs)
530        except excClass:
531            return
532
533        if hasattr(excClass, '__name__'):
534            excName = excClass.__name__
535        else:
536            excName = str(excClass)
537        raise self.failureException("%s not raised" % excName)
538
539    def _getAssertEqualityFunc(self, first, second):
540        """Get a detailed comparison function for the types of the two args.
541
542        Returns: A callable accepting (first, second, msg=None) that will
543        raise a failure exception if first != second with a useful human
544        readable error message for those types.
545        """
546        #
547        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
548        # and vice versa.  I opted for the conservative approach in case
549        # subclasses are not intended to be compared in detail to their super
550        # class instances using a type equality func.  This means testing
551        # subtypes won't automagically use the detailed comparison.  Callers
552        # should use their type specific assertSpamEqual method to compare
553        # subclasses if the detailed comparison is desired and appropriate.
554        # See the discussion in http://bugs.python.org/issue2578.
555        #
556        if isinstance(first, type(second)):
557            asserter = self._type_equality_funcs.get(type(first))
558            if asserter is not None:
559                return asserter
560
561        return self._baseAssertEqual
562
563    def _baseAssertEqual(self, first, second, msg=None):
564        """The default assertEqual implementation, not type specific."""
565        if not first == second:
566            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
567            msg = self._formatMessage(msg, standardMsg)
568            raise self.failureException(msg)
569
570    def assertEqual(self, first, second, msg=None):
571        """Fail if the two objects are unequal as determined by the '=='
572           operator.
573        """
574        assertion_func = self._getAssertEqualityFunc(first, second)
575        assertion_func(first, second, msg=msg)
576
577    def assertNotEqual(self, first, second, msg=None):
578        """Fail if the two objects are equal as determined by the '=='
579           operator.
580        """
581        if not first != second:
582            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
583                                                         safe_repr(second)))
584            raise self.failureException(msg)
585
586    def assertAlmostEqual(
587            self,
588            first,
589            second,
590            places=None,
591            msg=None,
592            delta=None):
593        """Fail if the two objects are unequal as determined by their
594           difference rounded to the given number of decimal places
595           (default 7) and comparing to zero, or by comparing that the
596           between the two objects is more than the given delta.
597
598           Note that decimal places (from zero) are usually not the same
599           as significant digits (measured from the most signficant digit).
600
601           If the two objects compare equal then they will automatically
602           compare almost equal.
603        """
604        if first == second:
605            # shortcut
606            return
607        if delta is not None and places is not None:
608            raise TypeError("specify delta or places not both")
609
610        if delta is not None:
611            if abs(first - second) <= delta:
612                return
613
614            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
615                                                        safe_repr(second),
616                                                        safe_repr(delta))
617        else:
618            if places is None:
619                places = 7
620
621            if round(abs(second - first), places) == 0:
622                return
623
624            standardMsg = '%s != %s within %r places' % (safe_repr(first),
625                                                         safe_repr(second),
626                                                         places)
627        msg = self._formatMessage(msg, standardMsg)
628        raise self.failureException(msg)
629
630    def assertNotAlmostEqual(
631            self,
632            first,
633            second,
634            places=None,
635            msg=None,
636            delta=None):
637        """Fail if the two objects are equal as determined by their
638           difference rounded to the given number of decimal places
639           (default 7) and comparing to zero, or by comparing that the
640           between the two objects is less than the given delta.
641
642           Note that decimal places (from zero) are usually not the same
643           as significant digits (measured from the most signficant digit).
644
645           Objects that are equal automatically fail.
646        """
647        if delta is not None and places is not None:
648            raise TypeError("specify delta or places not both")
649        if delta is not None:
650            if not (first == second) and abs(first - second) > delta:
651                return
652            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
653                                                        safe_repr(second),
654                                                        safe_repr(delta))
655        else:
656            if places is None:
657                places = 7
658            if not (first == second) and round(
659                    abs(second - first), places) != 0:
660                return
661            standardMsg = '%s == %s within %r places' % (safe_repr(first),
662                                                         safe_repr(second),
663                                                         places)
664
665        msg = self._formatMessage(msg, standardMsg)
666        raise self.failureException(msg)
667
    # Synonyms for assertion methods

    # The plurals are undocumented.  Keep them that way to discourage use.
    # Do not add more.  Do not remove.
    # Going through a deprecation cycle on these would annoy many people.
    # NOTE: these bind the function objects defined above at class-creation
    # time, so a subclass that overrides e.g. assertEqual does not change
    # what assertEquals calls.
    assertEquals = assertEqual
    assertNotEquals = assertNotEqual
    assertAlmostEquals = assertAlmostEqual
    assertNotAlmostEquals = assertNotAlmostEqual
    assert_ = assertTrue
678
679    # These fail* assertion method names are pending deprecation and will
680    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
681    def _deprecate(original_func):
682        def deprecated_func(*args, **kwargs):
683            warnings.warn(
684                ('Please use %s instead.' % original_func.__name__),
685                PendingDeprecationWarning, 2)
686            return original_func(*args, **kwargs)
687        return deprecated_func
688
    # Legacy fail* spellings. Each call emits a PendingDeprecationWarning
    # ('Please use <assert name> instead.') and then delegates to the
    # assert* function captured here at class-creation time.
    failUnlessEqual = _deprecate(assertEqual)
    failIfEqual = _deprecate(assertNotEqual)
    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
    failUnless = _deprecate(assertTrue)
    failUnlessRaises = _deprecate(assertRaises)
    failIf = _deprecate(assertFalse)
696
697    def assertSequenceEqual(self, seq1, seq2,
698                            msg=None, seq_type=None, max_diff=80 * 8):
699        """An equality assertion for ordered sequences (like lists and tuples).
700
701        For the purposes of this function, a valid ordered sequence type is one
702        which can be indexed, has a length, and has an equality operator.
703
704        Args:
705            seq1: The first sequence to compare.
706            seq2: The second sequence to compare.
707            seq_type: The expected datatype of the sequences, or None if no
708                    datatype should be enforced.
709            msg: Optional message to use on failure instead of a list of
710                    differences.
711            max_diff: Maximum size off the diff, larger diffs are not shown
712        """
713        if seq_type is not None:
714            seq_type_name = seq_type.__name__
715            if not isinstance(seq1, seq_type):
716                raise self.failureException('First sequence is not a %s: %s'
717                                            % (seq_type_name, safe_repr(seq1)))
718            if not isinstance(seq2, seq_type):
719                raise self.failureException('Second sequence is not a %s: %s'
720                                            % (seq_type_name, safe_repr(seq2)))
721        else:
722            seq_type_name = "sequence"
723
724        differing = None
725        try:
726            len1 = len(seq1)
727        except (TypeError, NotImplementedError):
728            differing = 'First %s has no length.    Non-sequence?' % (
729                seq_type_name)
730
731        if differing is None:
732            try:
733                len2 = len(seq2)
734            except (TypeError, NotImplementedError):
735                differing = 'Second %s has no length.    Non-sequence?' % (
736                    seq_type_name)
737
738        if differing is None:
739            if seq1 == seq2:
740                return
741
742            seq1_repr = repr(seq1)
743            seq2_repr = repr(seq2)
744            if len(seq1_repr) > 30:
745                seq1_repr = seq1_repr[:30] + '...'
746            if len(seq2_repr) > 30:
747                seq2_repr = seq2_repr[:30] + '...'
748            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
749            differing = '%ss differ: %s != %s\n' % elements
750
751            for i in range(min(len1, len2)):
752                try:
753                    item1 = seq1[i]
754                except (TypeError, IndexError, NotImplementedError):
755                    differing += ('\nUnable to index element %d of first %s\n' %
756                                  (i, seq_type_name))
757                    break
758
759                try:
760                    item2 = seq2[i]
761                except (TypeError, IndexError, NotImplementedError):
762                    differing += ('\nUnable to index element %d of second %s\n' %
763                                  (i, seq_type_name))
764                    break
765
766                if item1 != item2:
767                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
768                                  (i, item1, item2))
769                    break
770            else:
771                if (len1 == len2 and seq_type is None and
772                        not isinstance(seq1, type(seq2))):
773                    # The sequences are the same, but have differing types.
774                    return
775
776            if len1 > len2:
777                differing += ('\nFirst %s contains %d additional '
778                              'elements.\n' % (seq_type_name, len1 - len2))
779                try:
780                    differing += ('First extra element %d:\n%s\n' %
781                                  (len2, seq1[len2]))
782                except (TypeError, IndexError, NotImplementedError):
783                    differing += ('Unable to index element %d '
784                                  'of first %s\n' % (len2, seq_type_name))
785            elif len1 < len2:
786                differing += ('\nSecond %s contains %d additional '
787                              'elements.\n' % (seq_type_name, len2 - len1))
788                try:
789                    differing += ('First extra element %d:\n%s\n' %
790                                  (len1, seq2[len1]))
791                except (TypeError, IndexError, NotImplementedError):
792                    differing += ('Unable to index element %d '
793                                  'of second %s\n' % (len1, seq_type_name))
794        standardMsg = differing
795        diffMsg = '\n' + '\n'.join(
796            difflib.ndiff(pprint.pformat(seq1).splitlines(),
797                          pprint.pformat(seq2).splitlines()))
798
799        standardMsg = self._truncateMessage(standardMsg, diffMsg)
800        msg = self._formatMessage(msg, standardMsg)
801        self.fail(msg)
802
803    def _truncateMessage(self, message, diff):
804        max_diff = self.maxDiff
805        if max_diff is None or len(diff) <= max_diff:
806            return message + diff
807        return message + (DIFF_OMITTED % len(diff))
808
809    def assertListEqual(self, list1, list2, msg=None):
810        """A list-specific equality assertion.
811
812        Args:
813            list1: The first list to compare.
814            list2: The second list to compare.
815            msg: Optional message to use on failure instead of a list of
816                    differences.
817
818        """
819        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
820
821    def assertTupleEqual(self, tuple1, tuple2, msg=None):
822        """A tuple-specific equality assertion.
823
824        Args:
825            tuple1: The first tuple to compare.
826            tuple2: The second tuple to compare.
827            msg: Optional message to use on failure instead of a list of
828                    differences.
829        """
830        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
831
832    def assertSetEqual(self, set1, set2, msg=None):
833        """A set-specific equality assertion.
834
835        Args:
836            set1: The first set to compare.
837            set2: The second set to compare.
838            msg: Optional message to use on failure instead of a list of
839                    differences.
840
841        assertSetEqual uses ducktyping to support
842        different types of sets, and is optimized for sets specifically
843        (parameters must support a difference method).
844        """
845        try:
846            difference1 = set1.difference(set2)
847        except TypeError as e:
848            self.fail('invalid type when attempting set difference: %s' % e)
849        except AttributeError as e:
850            self.fail('first argument does not support set difference: %s' % e)
851
852        try:
853            difference2 = set2.difference(set1)
854        except TypeError as e:
855            self.fail('invalid type when attempting set difference: %s' % e)
856        except AttributeError as e:
857            self.fail(
858                'second argument does not support set difference: %s' %
859                e)
860
861        if not (difference1 or difference2):
862            return
863
864        lines = []
865        if difference1:
866            lines.append('Items in the first set but not the second:')
867            for item in difference1:
868                lines.append(repr(item))
869        if difference2:
870            lines.append('Items in the second set but not the first:')
871            for item in difference2:
872                lines.append(repr(item))
873
874        standardMsg = '\n'.join(lines)
875        self.fail(self._formatMessage(msg, standardMsg))
876
877    def assertIn(self, member, container, msg=None):
878        """Just like self.assertTrue(a in b), but with a nicer default message."""
879        if member not in container:
880            standardMsg = '%s not found in %s' % (safe_repr(member),
881                                                  safe_repr(container))
882            self.fail(self._formatMessage(msg, standardMsg))
883
884    def assertNotIn(self, member, container, msg=None):
885        """Just like self.assertTrue(a not in b), but with a nicer default message."""
886        if member in container:
887            standardMsg = '%s unexpectedly found in %s' % (
888                safe_repr(member), safe_repr(container))
889            self.fail(self._formatMessage(msg, standardMsg))
890
891    def assertIs(self, expr1, expr2, msg=None):
892        """Just like self.assertTrue(a is b), but with a nicer default message."""
893        if expr1 is not expr2:
894            standardMsg = '%s is not %s' % (safe_repr(expr1), safe_repr(expr2))
895            self.fail(self._formatMessage(msg, standardMsg))
896
897    def assertIsNot(self, expr1, expr2, msg=None):
898        """Just like self.assertTrue(a is not b), but with a nicer default message."""
899        if expr1 is expr2:
900            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
901            self.fail(self._formatMessage(msg, standardMsg))
902
903    def assertDictEqual(self, d1, d2, msg=None):
904        self.assert_(
905            isinstance(
906                d1,
907                dict),
908            'First argument is not a dictionary')
909        self.assert_(
910            isinstance(
911                d2,
912                dict),
913            'Second argument is not a dictionary')
914
915        if d1 != d2:
916            standardMsg = '%s != %s' % (
917                safe_repr(d1, True), safe_repr(d2, True))
918            diff = ('\n' + '\n'.join(difflib.ndiff(
919                           pprint.pformat(d1).splitlines(),
920                           pprint.pformat(d2).splitlines())))
921            standardMsg = self._truncateMessage(standardMsg, diff)
922            self.fail(self._formatMessage(msg, standardMsg))
923
924    def assertDictContainsSubset(self, expected, actual, msg=None):
925        """Checks whether actual is a superset of expected."""
926        missing = []
927        mismatched = []
928        for key, value in expected.iteritems():
929            if key not in actual:
930                missing.append(key)
931            elif value != actual[key]:
932                mismatched.append('%s, expected: %s, actual: %s' %
933                                  (safe_repr(key), safe_repr(value),
934                                   safe_repr(actual[key])))
935
936        if not (missing or mismatched):
937            return
938
939        standardMsg = ''
940        if missing:
941            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
942                                                   missing)
943        if mismatched:
944            if standardMsg:
945                standardMsg += '; '
946            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
947
948        self.fail(self._formatMessage(msg, standardMsg))
949
950    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
951        """An unordered sequence specific comparison. It asserts that
952        expected_seq and actual_seq contain the same elements. It is
953        the equivalent of::
954
955            self.assertEqual(sorted(expected_seq), sorted(actual_seq))
956
957        Raises with an error message listing which elements of expected_seq
958        are missing from actual_seq and vice versa if any.
959
960        Asserts that each element has the same count in both sequences.
961        Example:
962            - [0, 1, 1] and [1, 0, 1] compare equal.
963            - [0, 0, 1] and [0, 1] compare unequal.
964        """
965        try:
966            expected = sorted(expected_seq)
967            actual = sorted(actual_seq)
968        except TypeError:
969            # Unsortable items (example: set(), complex(), ...)
970            expected = list(expected_seq)
971            actual = list(actual_seq)
972            missing, unexpected = unorderable_list_difference(
973                expected, actual, ignore_duplicate=False
974            )
975        else:
976            return self.assertSequenceEqual(expected, actual, msg=msg)
977
978        errors = []
979        if missing:
980            errors.append('Expected, but missing:\n    %s' %
981                          safe_repr(missing))
982        if unexpected:
983            errors.append('Unexpected, but present:\n    %s' %
984                          safe_repr(unexpected))
985        if errors:
986            standardMsg = '\n'.join(errors)
987            self.fail(self._formatMessage(msg, standardMsg))
988
989    def assertMultiLineEqual(self, first, second, msg=None):
990        """Assert that two multi-line strings are equal."""
991        self.assert_(isinstance(first, str), (
992            'First argument is not a string'))
993        self.assert_(isinstance(second, str), (
994            'Second argument is not a string'))
995
996        if first != second:
997            standardMsg = '%s != %s' % (
998                safe_repr(first, True), safe_repr(second, True))
999            diff = '\n' + ''.join(difflib.ndiff(first.splitlines(True),
1000                                                second.splitlines(True)))
1001            standardMsg = self._truncateMessage(standardMsg, diff)
1002            self.fail(self._formatMessage(msg, standardMsg))
1003
1004    def assertLess(self, a, b, msg=None):
1005        """Just like self.assertTrue(a < b), but with a nicer default message."""
1006        if not a < b:
1007            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
1008            self.fail(self._formatMessage(msg, standardMsg))
1009
1010    def assertLessEqual(self, a, b, msg=None):
1011        """Just like self.assertTrue(a <= b), but with a nicer default message."""
1012        if not a <= b:
1013            standardMsg = '%s not less than or equal to %s' % (
1014                safe_repr(a), safe_repr(b))
1015            self.fail(self._formatMessage(msg, standardMsg))
1016
1017    def assertGreater(self, a, b, msg=None):
1018        """Just like self.assertTrue(a > b), but with a nicer default message."""
1019        if not a > b:
1020            standardMsg = '%s not greater than %s' % (
1021                safe_repr(a), safe_repr(b))
1022            self.fail(self._formatMessage(msg, standardMsg))
1023
1024    def assertGreaterEqual(self, a, b, msg=None):
1025        """Just like self.assertTrue(a >= b), but with a nicer default message."""
1026        if not a >= b:
1027            standardMsg = '%s not greater than or equal to %s' % (
1028                safe_repr(a), safe_repr(b))
1029            self.fail(self._formatMessage(msg, standardMsg))
1030
1031    def assertIsNone(self, obj, msg=None):
1032        """Same as self.assertTrue(obj is None), with a nicer default message."""
1033        if obj is not None:
1034            standardMsg = '%s is not None' % (safe_repr(obj),)
1035            self.fail(self._formatMessage(msg, standardMsg))
1036
1037    def assertIsNotNone(self, obj, msg=None):
1038        """Included for symmetry with assertIsNone."""
1039        if obj is None:
1040            standardMsg = 'unexpectedly None'
1041            self.fail(self._formatMessage(msg, standardMsg))
1042
1043    def assertIsInstance(self, obj, cls, msg=None):
1044        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
1045        default message."""
1046        if not isinstance(obj, cls):
1047            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
1048            self.fail(self._formatMessage(msg, standardMsg))
1049
1050    def assertNotIsInstance(self, obj, cls, msg=None):
1051        """Included for symmetry with assertIsInstance."""
1052        if isinstance(obj, cls):
1053            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
1054            self.fail(self._formatMessage(msg, standardMsg))
1055
1056    def assertRaisesRegexp(self, expected_exception, expected_regexp,
1057                           callable_obj=None, *args, **kwargs):
1058        """Asserts that the message in a raised exception matches a regexp.
1059
1060        Args:
1061            expected_exception: Exception class expected to be raised.
1062            expected_regexp: Regexp (re pattern object or string) expected
1063                    to be found in error message.
1064            callable_obj: Function to be called.
1065            args: Extra args.
1066            kwargs: Extra kwargs.
1067        """
1068        if callable_obj is None:
1069            return _AssertRaisesContext(
1070                expected_exception, self, expected_regexp)
1071        try:
1072            callable_obj(*args, **kwargs)
1073        except expected_exception as exc_value:
1074            if isinstance(expected_regexp, str):
1075                expected_regexp = re.compile(expected_regexp)
1076            if not expected_regexp.search(str(exc_value)):
1077                raise self.failureException(
1078                    '"%s" does not match "%s"' %
1079                    (expected_regexp.pattern, str(exc_value)))
1080        else:
1081            if hasattr(expected_exception, '__name__'):
1082                excName = expected_exception.__name__
1083            else:
1084                excName = str(expected_exception)
1085            raise self.failureException("%s not raised" % excName)
1086
1087    def assertRegexpMatches(self, text, expected_regexp, msg=None):
1088        """Fail the test unless the text matches the regular expression."""
1089        if isinstance(expected_regexp, str):
1090            expected_regexp = re.compile(expected_regexp)
1091        if not expected_regexp.search(text):
1092            msg = msg or "Regexp didn't match"
1093            msg = '%s: %r not found in %r' % (
1094                msg, expected_regexp.pattern, text)
1095            raise self.failureException(msg)
1096
1097    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1098        """Fail the test if the text matches the regular expression."""
1099        if isinstance(unexpected_regexp, str):
1100            unexpected_regexp = re.compile(unexpected_regexp)
1101        match = unexpected_regexp.search(text)
1102        if match:
1103            msg = msg or "Regexp matched"
1104            msg = '%s: %r matches %r in %r' % (msg,
1105                                               text[match.start():match.end()],
1106                                               unexpected_regexp.pattern,
1107                                               text)
1108            raise self.failureException(msg)
1109
1110
class FunctionTestCase(TestCase):
    """Adapt a plain test function to the TestCase interface.

    Lets pre-existing standalone test functions run inside the unittest
    framework. Optional set-up and tear-down callables may be supplied; as
    with TestCase, the tear-down ('tearDown') callable is always invoked if
    the set-up ('setUp') callable ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        # These four values drive the hooks below and define equality/hash.
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._testFunc = testFunc
        self._description = description

    def setUp(self):
        """Call the user-supplied set-up callable, if one was given."""
        set_up = self._setUpFunc
        if set_up is not None:
            set_up()

    def tearDown(self):
        """Call the user-supplied tear-down callable, if one was given."""
        tear_down = self._tearDownFunc
        if tear_down is not None:
            tear_down()

    def runTest(self):
        """Run the wrapped test function."""
        self._testFunc()

    def id(self):
        """Use the wrapped function's ``__name__`` as the test id."""
        return self._testFunc.__name__

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented

        return (self._setUpFunc == other._setUpFunc
                and self._tearDownFunc == other._tearDownFunc
                and self._testFunc == other._testFunc
                and self._description == other._description)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        identity = (type(self), self._setUpFunc, self._tearDownFunc,
                    self._testFunc, self._description)
        return hash(identity)

    def __str__(self):
        class_name = strclass(self.__class__)
        return "%s (%s)" % (class_name, self._testFunc.__name__)

    def __repr__(self):
        class_name = strclass(self.__class__)
        return "<%s testFunc=%s>" % (class_name, self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the docstring's first line.

        Returns None when neither an explicit description nor a non-blank
        first docstring line is available.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        summary = doc.split("\n")[0].strip()
        return summary or None
1170