1"""Test case implementation"""
2
3import sys
4import difflib
5import pprint
6import re
7import unittest
8import warnings
9
10import six
11
12from unittest2 import result
13from unittest2.util import (
14    safe_repr, safe_str, strclass,
15    unorderable_list_difference
16)
17
18from unittest2.compatibility import wraps
19
# A true module-level ``__unittest`` marks this module as part of the test
# framework itself; presumably the result module (as in the standard library's
# unittest) uses it to hide these frames from reported tracebacks — confirm.
__unittest = True

# Appended to a failure message in place of a diff whose length exceeds
# maxDiff; the %s placeholder receives the diff length in characters.
DIFF_OMITTED = (
    '\nDiff is %s characters long. '
    'Set self.maxDiff to None to see it.'
)
25
26
class SkipTest(Exception):
    """
    Raise this exception in a test to skip it.

    Usually you can use TestCase.skipTest() or one of the skipping decorators
    (skip, skipIf, skipUnless) instead of raising this directly.
    """
34
35
class _ExpectedFailure(Exception):
    """
    Internal signal that a test decorated with expectedFailure did fail.

    This is an implementation detail.

    Attributes:
        exc_info: the sys.exc_info() triple captured from the failing test.
        bugnumber: the optional bug reference given to expectedFailure.
    """

    def __init__(self, exc_info, bugnumber=None):
        # Initialise the Exception base with no arguments: ``args`` stays
        # empty and the payload is carried only by the attributes below.
        Exception.__init__(self)
        self.bugnumber = bugnumber
        self.exc_info = exc_info
48
49
class _UnexpectedSuccess(Exception):
    """
    Internal signal that a test decorated with expectedFailure returned
    normally instead of failing.

    Attributes:
        exc_info: the sys.exc_info() value captured by the raiser after the
            test returned.
        bugnumber: the optional bug reference given to expectedFailure.
    """

    def __init__(self, exc_info, bugnumber=None):
        # Initialise the Exception base with no arguments: ``args`` stays
        # empty and the payload is carried only by the attributes below.
        Exception.__init__(self)
        self.bugnumber = bugnumber
        self.exc_info = exc_info
60
61
def _id(obj):
    """Identity decorator: hand ``obj`` back untouched."""
    unchanged = obj
    return unchanged
64
65
def skip(reason):
    """Return a decorator that unconditionally skips the decorated test.

    A TestCase subclass is only marked; anything else (a test function or
    method) is replaced by a wrapper that raises SkipTest(reason) when
    called. Either way the returned object carries ``__unittest_skip__ = True``
    and ``__unittest_skip_why__ = reason``.
    """
    def decorator(test_item):
        is_test_class = (isinstance(test_item, type) and
                         issubclass(test_item, TestCase))
        if not is_test_class:
            @wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            test_item = skip_wrapper

        test_item.__unittest_skip__ = True
        test_item.__unittest_skip_why__ = reason
        return test_item
    return decorator
86
87
def skipIf(condition, reason):
    """Skip the decorated test when ``condition`` is true.

    When ``condition`` is false the decorator returns the test unchanged.
    """
    return skip(reason) if condition else _id
95
96
def skipUnless(condition, reason):
    """Skip the decorated test unless ``condition`` is true.

    When ``condition`` is true the decorator returns the test unchanged.
    """
    return _id if condition else skip(reason)
104
105
def expectedFailure(bugnumber=None):
    """Mark a test as expected to fail.

    Usable bare (``@expectedFailure``) or with an optional bug reference
    (``@expectedFailure(bug)``). The wrapped test raises _ExpectedFailure
    (carrying sys.exc_info() and the bug reference) when the original raises
    any Exception, and _UnexpectedSuccess when it returns normally.
    """
    def _decorate(func, bug):
        # Shared wrapper for both calling conventions.
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                func(*args, **kwargs)
            except Exception:
                raise _ExpectedFailure(sys.exc_info(), bug)
            raise _UnexpectedSuccess(sys.exc_info(), bug)
        return wrapper

    if callable(bugnumber):
        # Used bare: the "bug number" is really the test function itself.
        return _decorate(bugnumber, None)

    def expectedFailure_impl(func):
        return _decorate(func, bugnumber)
    return expectedFailure_impl
127
128
class _AssertRaisesContext(object):
    """Context manager implementing the ``with`` form of assertRaises*.

    On exit it checks that the block raised ``expected`` (a class, or
    anything else issubclass accepts such as a tuple of classes). A matching
    exception is suppressed and stored as ``self.exception``. If
    ``expected_regexp`` is given (a pattern string or compiled regex), the
    exception's str() must also match it via search(). Any other exception
    propagates unchanged.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        self.expected = expected
        self.failureException = test_case.failureException
        self.expected_regexp = expected_regexp

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            # The block finished without raising anything.
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                exc_name = str(self.expected)
            raise self.failureException("%s not raised" % (exc_name,))

        if not issubclass(exc_type, self.expected):
            # Not the exception we are waiting for: let it propagate.
            return False

        self.exception = exc_value  # kept for inspection after the block
        pattern = self.expected_regexp
        if pattern is None:
            return True

        if isinstance(pattern, six.string_types):
            pattern = re.compile(pattern)
        text = str(exc_value)
        if pattern.search(text):
            return True
        raise self.failureException(
            '"%s" does not match "%s"' % (pattern.pattern, text))
163
164
class _TypeEqualityDict(object):
    """Mapping from a type to the assertEqual helper used for that type.

    Stored values are either callables or the *names* of methods on
    ``testcase``. Names are resolved against ``testcase`` with getattr each
    time they are looked up.
    """

    def __init__(self, testcase):
        self.testcase = testcase
        self._store = {}

    def __setitem__(self, key, value):
        self._store[key] = value

    def __getitem__(self, key):
        entry = self._store[key]
        if not isinstance(entry, six.string_types):
            return entry
        return getattr(self.testcase, entry)

    def get(self, key, default=None):
        """Return the helper registered for ``key``, or ``default``."""
        if key not in self._store:
            return default
        return self[key]
184
185
186class TestCase(unittest.TestCase):
187    """A class whose instances are single test cases.
188
189    By default, the test code itself should be placed in a method named
190    'runTest'.
191
192    If the fixture may be used for many test cases, create as
193    many test methods as are needed. When instantiating such a TestCase
194    subclass, specify in the constructor arguments the name of the test method
195    that the instance is to execute.
196
197    Test authors should subclass TestCase for their own tests. Construction
198    and deconstruction of the test's environment ('fixture') can be
199    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
200
201    If it is necessary to override the __init__ method, the base class
202    __init__ method must always be called. It is important that subclasses
203    should not change the signature of their __init__ method, since instances
204    of the classes are instantiated automatically by parts of the framework
205    in order to be run.
206    """
207
208    # This attribute determines which exception will be raised when
209    # the instance's assertion methods fail; test methods raising this
210    # exception will be deemed to have 'failed' rather than 'errored'
211
212    failureException = AssertionError
213
214    # This attribute sets the maximum length of a diff in failure messages
215    # by assert methods using difflib. It is looked up as an instance attribute
216    # so can be configured by individual tests if required.
217
218    maxDiff = 80 * 8
219
220    # This attribute determines whether long messages (including repr of
221    # objects used in assert methods) will be printed on failure in *addition*
222    # to any explicit message passed.
223
224    longMessage = True
225
226    # Attribute used by TestSuite for classSetUp
227
228    _classSetupFailed = False
229
230    def __init__(self, methodName='runTest'):
231        """Create an instance of the class that will use the named test
232           method when executed. Raises a ValueError if the instance does
233           not have a method with the specified name.
234        """
235        self._testMethodName = methodName
236        self._resultForDoCleanups = None
237        try:
238            testMethod = getattr(self, methodName)
239        except AttributeError:
240            raise ValueError("no such test method in %s: %s" %
241                             (self.__class__, methodName))
242        self._testMethodDoc = testMethod.__doc__
243        self._cleanups = []
244
245        # Map types to custom assertEqual functions that will compare
246        # instances of said type in more detail to generate a more useful
247        # error message.
248        self._type_equality_funcs = _TypeEqualityDict(self)
249        self.addTypeEqualityFunc(dict, 'assertDictEqual')
250        self.addTypeEqualityFunc(list, 'assertListEqual')
251        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
252        self.addTypeEqualityFunc(set, 'assertSetEqual')
253        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
254        if six.PY2:
255            self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual')
256        else:
257            self.addTypeEqualityFunc(str, 'assertMultiLineEqual')
258
259    def addTypeEqualityFunc(self, typeobj, function):
260        """Add a type specific assertEqual style function to compare a type.
261
262        This method is for use by TestCase subclasses that need to register
263        their own type equality functions to provide nicer error messages.
264
265        Args:
266            typeobj: The data type to call this function on when both values
267                    are of the same type in assertEqual().
268            function: The callable taking two arguments and an optional
269                    msg= argument that raises self.failureException with a
270                    useful error message when the two arguments are not equal.
271        """
272        self._type_equality_funcs[typeobj] = function
273
274    def addCleanup(self, function, *args, **kwargs):
275        """Add a function, with arguments, to be called when the test is
276        completed. Functions added are called on a LIFO basis and are
277        called after tearDown on test failure or success.
278
279        Cleanup items are called even if setUp fails (unlike tearDown)."""
280        self._cleanups.append((function, args, kwargs))
281
282    def setUp(self):
283        "Hook method for setting up the test fixture before exercising it."
284
285    @classmethod
286    def setUpClass(cls):
287        "Hook method for setting up class fixture before running tests in the class."
288
289    @classmethod
290    def tearDownClass(cls):
291        "Hook method for deconstructing the class fixture after running all tests in the class."
292
293    def tearDown(self):
294        "Hook method for deconstructing the test fixture after testing it."
295
296    def countTestCases(self):
297        return 1
298
299    def defaultTestResult(self):
300        return result.TestResult()
301
302    def shortDescription(self):
303        """Returns a one-line description of the test, or None if no
304        description has been provided.
305
306        The default implementation of this method returns the first line of
307        the specified test method's docstring.
308        """
309        doc = self._testMethodDoc
310        return doc and doc.split("\n")[0].strip() or None
311
312    def id(self):
313        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
314
315    def __eq__(self, other):
316        if not isinstance(self, type(other)):
317            return NotImplemented
318
319        return self._testMethodName == other._testMethodName
320
321    def __ne__(self, other):
322        return not self == other
323
324    def __hash__(self):
325        return hash((type(self), self._testMethodName))
326
327    def __str__(self):
328        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
329
330    def __repr__(self):
331        return "<%s testMethod=%s>" % \
332               (strclass(self.__class__), self._testMethodName)
333
334    def _addSkip(self, result, reason):
335        addSkip = getattr(result, 'addSkip', None)
336        if addSkip is not None:
337            addSkip(self, reason)
338        else:
339            warnings.warn(
340                "Use of a TestResult without an addSkip method is deprecated",
341                DeprecationWarning,
342                2)
343            result.addSuccess(self)
344
345    def run(self, result=None):
346        orig_result = result
347        if result is None:
348            result = self.defaultTestResult()
349            startTestRun = getattr(result, 'startTestRun', None)
350            if startTestRun is not None:
351                startTestRun()
352
353        self._resultForDoCleanups = result
354        result.startTest(self)
355
356        testMethod = getattr(self, self._testMethodName)
357
358        if (getattr(self.__class__, "__unittest_skip__", False) or
359                getattr(testMethod, "__unittest_skip__", False)):
360            # If the class or method was skipped.
361            try:
362                skip_why = (
363                    getattr(
364                        self.__class__,
365                        '__unittest_skip_why__',
366                        '') or getattr(
367                        testMethod,
368                        '__unittest_skip_why__',
369                        ''))
370                self._addSkip(result, skip_why)
371            finally:
372                result.stopTest(self)
373            return
374        try:
375            success = False
376            try:
377                self.setUp()
378            except SkipTest as e:
379                self._addSkip(result, str(e))
380            except Exception:
381                result.addError(self, sys.exc_info())
382            else:
383                success = self.runMethod(testMethod, result)
384
385                try:
386                    self.tearDown()
387                except Exception:
388                    result.addCleanupError(self, sys.exc_info())
389                    success = False
390
391                self.dumpSessionInfo()
392
393            cleanUpSuccess = self.doCleanups()
394            success = success and cleanUpSuccess
395            if success:
396                result.addSuccess(self)
397        finally:
398            result.stopTest(self)
399            if orig_result is None:
400                stopTestRun = getattr(result, 'stopTestRun', None)
401                if stopTestRun is not None:
402                    stopTestRun()
403
404    def runMethod(self, testMethod, result):
405        """Runs the test method and catches any exception that might be thrown.
406
407        This is factored out of TestCase.run() to ensure that any exception
408        thrown during the test goes out of scope before tearDown.  Otherwise, an
409        exception could hold references to Python objects that are bound to
410        SB objects and prevent them from being deleted in time.
411        """
412        try:
413            testMethod()
414        except self.failureException:
415            result.addFailure(self, sys.exc_info())
416        except _ExpectedFailure as e:
417            addExpectedFailure = getattr(result, 'addExpectedFailure', None)
418            if addExpectedFailure is not None:
419                addExpectedFailure(self, e.exc_info, e.bugnumber)
420            else:
421                warnings.warn(
422                    "Use of a TestResult without an addExpectedFailure method is deprecated",
423                    DeprecationWarning)
424                result.addSuccess(self)
425        except _UnexpectedSuccess as x:
426            addUnexpectedSuccess = getattr(
427                result, 'addUnexpectedSuccess', None)
428            if addUnexpectedSuccess is not None:
429                addUnexpectedSuccess(self, x.bugnumber)
430            else:
431                warnings.warn(
432                    "Use of a TestResult without an addUnexpectedSuccess method is deprecated",
433                    DeprecationWarning)
434                result.addFailure(self, sys.exc_info())
435        except SkipTest as e:
436            self._addSkip(result, str(e))
437        except Exception:
438            result.addError(self, sys.exc_info())
439        else:
440            return True
441        return False
442
443    def doCleanups(self):
444        """Execute all cleanup functions. Normally called for you after
445        tearDown."""
446        result = self._resultForDoCleanups
447        ok = True
448        while self._cleanups:
449            function, args, kwargs = self._cleanups.pop(-1)
450            try:
451                function(*args, **kwargs)
452            except Exception:
453                ok = False
454                result.addError(self, sys.exc_info())
455        return ok
456
457    def __call__(self, *args, **kwds):
458        return self.run(*args, **kwds)
459
460    def debug(self):
461        """Run the test without collecting errors in a TestResult"""
462        self.setUp()
463        getattr(self, self._testMethodName)()
464        self.tearDown()
465        while self._cleanups:
466            function, args, kwargs = self._cleanups.pop(-1)
467            function(*args, **kwargs)
468
469    def skipTest(self, reason):
470        """Skip this test."""
471        raise SkipTest(reason)
472
473    def fail(self, msg=None):
474        """Fail immediately, with the given message."""
475        raise self.failureException(msg)
476
477    def assertFalse(self, expr, msg=None):
478        "Fail the test if the expression is true."
479        if expr:
480            msg = self._formatMessage(msg, "%s is not False" % safe_repr(expr))
481            raise self.failureException(msg)
482
483    def assertTrue(self, expr, msg=None):
484        """Fail the test unless the expression is true."""
485        if not expr:
486            msg = self._formatMessage(msg, "%s is not True" % safe_repr(expr))
487            raise self.failureException(msg)
488
489    def _formatMessage(self, msg, standardMsg):
490        """Honour the longMessage attribute when generating failure messages.
491        If longMessage is False this means:
492        * Use only an explicit message if it is provided
493        * Otherwise use the standard message for the assert
494
495        If longMessage is True:
496        * Use the standard message
497        * If an explicit message is provided, plus ' : ' and the explicit message
498        """
499        if not self.longMessage:
500            return msg or standardMsg
501        if msg is None:
502            return standardMsg
503        try:
504            return '%s : %s' % (standardMsg, msg)
505        except UnicodeDecodeError:
506            return '%s : %s' % (safe_str(standardMsg), safe_str(msg))
507
508    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
509        """Fail unless an exception of class excClass is thrown
510           by callableObj when invoked with arguments args and keyword
511           arguments kwargs. If a different type of exception is
512           thrown, it will not be caught, and the test case will be
513           deemed to have suffered an error, exactly as for an
514           unexpected exception.
515
516           If called with callableObj omitted or None, will return a
517           context object used like this::
518
519                with self.assertRaises(SomeException):
520                    do_something()
521
522           The context manager keeps a reference to the exception as
523           the 'exception' attribute. This allows you to inspect the
524           exception after the assertion::
525
526               with self.assertRaises(SomeException) as cm:
527                   do_something()
528               the_exception = cm.exception
529               self.assertEqual(the_exception.error_code, 3)
530        """
531        if callableObj is None:
532            return _AssertRaisesContext(excClass, self)
533        try:
534            callableObj(*args, **kwargs)
535        except excClass:
536            return
537
538        if hasattr(excClass, '__name__'):
539            excName = excClass.__name__
540        else:
541            excName = str(excClass)
542        raise self.failureException("%s not raised" % excName)
543
544    def _getAssertEqualityFunc(self, first, second):
545        """Get a detailed comparison function for the types of the two args.
546
547        Returns: A callable accepting (first, second, msg=None) that will
548        raise a failure exception if first != second with a useful human
549        readable error message for those types.
550        """
551        #
552        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
553        # and vice versa.  I opted for the conservative approach in case
554        # subclasses are not intended to be compared in detail to their super
555        # class instances using a type equality func.  This means testing
556        # subtypes won't automagically use the detailed comparison.  Callers
557        # should use their type specific assertSpamEqual method to compare
558        # subclasses if the detailed comparison is desired and appropriate.
559        # See the discussion in http://bugs.python.org/issue2578.
560        #
561        if isinstance(first, type(second)):
562            asserter = self._type_equality_funcs.get(type(first))
563            if asserter is not None:
564                return asserter
565
566        return self._baseAssertEqual
567
568    def _baseAssertEqual(self, first, second, msg=None):
569        """The default assertEqual implementation, not type specific."""
570        if not first == second:
571            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
572            msg = self._formatMessage(msg, standardMsg)
573            raise self.failureException(msg)
574
575    def assertEqual(self, first, second, msg=None):
576        """Fail if the two objects are unequal as determined by the '=='
577           operator.
578        """
579        assertion_func = self._getAssertEqualityFunc(first, second)
580        assertion_func(first, second, msg=msg)
581
582    def assertNotEqual(self, first, second, msg=None):
583        """Fail if the two objects are equal as determined by the '=='
584           operator.
585        """
586        if not first != second:
587            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
588                                                         safe_repr(second)))
589            raise self.failureException(msg)
590
591    def assertAlmostEqual(
592            self,
593            first,
594            second,
595            places=None,
596            msg=None,
597            delta=None):
598        """Fail if the two objects are unequal as determined by their
599           difference rounded to the given number of decimal places
600           (default 7) and comparing to zero, or by comparing that the
601           between the two objects is more than the given delta.
602
603           Note that decimal places (from zero) are usually not the same
604           as significant digits (measured from the most signficant digit).
605
606           If the two objects compare equal then they will automatically
607           compare almost equal.
608        """
609        if first == second:
610            # shortcut
611            return
612        if delta is not None and places is not None:
613            raise TypeError("specify delta or places not both")
614
615        if delta is not None:
616            if abs(first - second) <= delta:
617                return
618
619            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
620                                                        safe_repr(second),
621                                                        safe_repr(delta))
622        else:
623            if places is None:
624                places = 7
625
626            if round(abs(second - first), places) == 0:
627                return
628
629            standardMsg = '%s != %s within %r places' % (safe_repr(first),
630                                                         safe_repr(second),
631                                                         places)
632        msg = self._formatMessage(msg, standardMsg)
633        raise self.failureException(msg)
634
635    def assertNotAlmostEqual(
636            self,
637            first,
638            second,
639            places=None,
640            msg=None,
641            delta=None):
642        """Fail if the two objects are equal as determined by their
643           difference rounded to the given number of decimal places
644           (default 7) and comparing to zero, or by comparing that the
645           between the two objects is less than the given delta.
646
647           Note that decimal places (from zero) are usually not the same
648           as significant digits (measured from the most signficant digit).
649
650           Objects that are equal automatically fail.
651        """
652        if delta is not None and places is not None:
653            raise TypeError("specify delta or places not both")
654        if delta is not None:
655            if not (first == second) and abs(first - second) > delta:
656                return
657            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
658                                                        safe_repr(second),
659                                                        safe_repr(delta))
660        else:
661            if places is None:
662                places = 7
663            if not (first == second) and round(
664                    abs(second - first), places) != 0:
665                return
666            standardMsg = '%s == %s within %r places' % (safe_repr(first),
667                                                         safe_repr(second),
668                                                         places)
669
670        msg = self._formatMessage(msg, standardMsg)
671        raise self.failureException(msg)
672
673    # Synonyms for assertion methods
674
675    # The plurals are undocumented.  Keep them that way to discourage use.
676    # Do not add more.  Do not remove.
677    # Going through a deprecation cycle on these would annoy many people.
678    assertEquals = assertEqual
679    assertNotEquals = assertNotEqual
680    assertAlmostEquals = assertAlmostEqual
681    assertNotAlmostEquals = assertNotAlmostEqual
682    assert_ = assertTrue
683
684    # These fail* assertion method names are pending deprecation and will
685    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
686    def _deprecate(original_func):
687        def deprecated_func(*args, **kwargs):
688            warnings.warn(
689                ('Please use %s instead.' % original_func.__name__),
690                PendingDeprecationWarning, 2)
691            return original_func(*args, **kwargs)
692        return deprecated_func
693
694    failUnlessEqual = _deprecate(assertEqual)
695    failIfEqual = _deprecate(assertNotEqual)
696    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
697    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
698    failUnless = _deprecate(assertTrue)
699    failUnlessRaises = _deprecate(assertRaises)
700    failIf = _deprecate(assertFalse)
701
702    def assertSequenceEqual(self, seq1, seq2,
703                            msg=None, seq_type=None, max_diff=80 * 8):
704        """An equality assertion for ordered sequences (like lists and tuples).
705
706        For the purposes of this function, a valid ordered sequence type is one
707        which can be indexed, has a length, and has an equality operator.
708
709        Args:
710            seq1: The first sequence to compare.
711            seq2: The second sequence to compare.
712            seq_type: The expected datatype of the sequences, or None if no
713                    datatype should be enforced.
714            msg: Optional message to use on failure instead of a list of
715                    differences.
716            max_diff: Maximum size off the diff, larger diffs are not shown
717        """
718        if seq_type is not None:
719            seq_type_name = seq_type.__name__
720            if not isinstance(seq1, seq_type):
721                raise self.failureException('First sequence is not a %s: %s'
722                                            % (seq_type_name, safe_repr(seq1)))
723            if not isinstance(seq2, seq_type):
724                raise self.failureException('Second sequence is not a %s: %s'
725                                            % (seq_type_name, safe_repr(seq2)))
726        else:
727            seq_type_name = "sequence"
728
729        differing = None
730        try:
731            len1 = len(seq1)
732        except (TypeError, NotImplementedError):
733            differing = 'First %s has no length.    Non-sequence?' % (
734                seq_type_name)
735
736        if differing is None:
737            try:
738                len2 = len(seq2)
739            except (TypeError, NotImplementedError):
740                differing = 'Second %s has no length.    Non-sequence?' % (
741                    seq_type_name)
742
743        if differing is None:
744            if seq1 == seq2:
745                return
746
747            seq1_repr = repr(seq1)
748            seq2_repr = repr(seq2)
749            if len(seq1_repr) > 30:
750                seq1_repr = seq1_repr[:30] + '...'
751            if len(seq2_repr) > 30:
752                seq2_repr = seq2_repr[:30] + '...'
753            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
754            differing = '%ss differ: %s != %s\n' % elements
755
756            for i in range(min(len1, len2)):
757                try:
758                    item1 = seq1[i]
759                except (TypeError, IndexError, NotImplementedError):
760                    differing += ('\nUnable to index element %d of first %s\n' %
761                                  (i, seq_type_name))
762                    break
763
764                try:
765                    item2 = seq2[i]
766                except (TypeError, IndexError, NotImplementedError):
767                    differing += ('\nUnable to index element %d of second %s\n' %
768                                  (i, seq_type_name))
769                    break
770
771                if item1 != item2:
772                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
773                                  (i, item1, item2))
774                    break
775            else:
776                if (len1 == len2 and seq_type is None and
777                        not isinstance(seq1, type(seq2))):
778                    # The sequences are the same, but have differing types.
779                    return
780
781            if len1 > len2:
782                differing += ('\nFirst %s contains %d additional '
783                              'elements.\n' % (seq_type_name, len1 - len2))
784                try:
785                    differing += ('First extra element %d:\n%s\n' %
786                                  (len2, seq1[len2]))
787                except (TypeError, IndexError, NotImplementedError):
788                    differing += ('Unable to index element %d '
789                                  'of first %s\n' % (len2, seq_type_name))
790            elif len1 < len2:
791                differing += ('\nSecond %s contains %d additional '
792                              'elements.\n' % (seq_type_name, len2 - len1))
793                try:
794                    differing += ('First extra element %d:\n%s\n' %
795                                  (len1, seq2[len1]))
796                except (TypeError, IndexError, NotImplementedError):
797                    differing += ('Unable to index element %d '
798                                  'of second %s\n' % (len1, seq_type_name))
799        standardMsg = differing
800        diffMsg = '\n' + '\n'.join(
801            difflib.ndiff(pprint.pformat(seq1).splitlines(),
802                          pprint.pformat(seq2).splitlines()))
803
804        standardMsg = self._truncateMessage(standardMsg, diffMsg)
805        msg = self._formatMessage(msg, standardMsg)
806        self.fail(msg)
807
808    def _truncateMessage(self, message, diff):
809        max_diff = self.maxDiff
810        if max_diff is None or len(diff) <= max_diff:
811            return message + diff
812        return message + (DIFF_OMITTED % len(diff))
813
814    def assertListEqual(self, list1, list2, msg=None):
815        """A list-specific equality assertion.
816
817        Args:
818            list1: The first list to compare.
819            list2: The second list to compare.
820            msg: Optional message to use on failure instead of a list of
821                    differences.
822
823        """
824        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
825
826    def assertTupleEqual(self, tuple1, tuple2, msg=None):
827        """A tuple-specific equality assertion.
828
829        Args:
830            tuple1: The first tuple to compare.
831            tuple2: The second tuple to compare.
832            msg: Optional message to use on failure instead of a list of
833                    differences.
834        """
835        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
836
837    def assertSetEqual(self, set1, set2, msg=None):
838        """A set-specific equality assertion.
839
840        Args:
841            set1: The first set to compare.
842            set2: The second set to compare.
843            msg: Optional message to use on failure instead of a list of
844                    differences.
845
846        assertSetEqual uses ducktyping to support
847        different types of sets, and is optimized for sets specifically
848        (parameters must support a difference method).
849        """
850        try:
851            difference1 = set1.difference(set2)
852        except TypeError as e:
853            self.fail('invalid type when attempting set difference: %s' % e)
854        except AttributeError as e:
855            self.fail('first argument does not support set difference: %s' % e)
856
857        try:
858            difference2 = set2.difference(set1)
859        except TypeError as e:
860            self.fail('invalid type when attempting set difference: %s' % e)
861        except AttributeError as e:
862            self.fail(
863                'second argument does not support set difference: %s' %
864                e)
865
866        if not (difference1 or difference2):
867            return
868
869        lines = []
870        if difference1:
871            lines.append('Items in the first set but not the second:')
872            for item in difference1:
873                lines.append(repr(item))
874        if difference2:
875            lines.append('Items in the second set but not the first:')
876            for item in difference2:
877                lines.append(repr(item))
878
879        standardMsg = '\n'.join(lines)
880        self.fail(self._formatMessage(msg, standardMsg))
881
882    def assertIn(self, member, container, msg=None):
883        """Just like self.assertTrue(a in b), but with a nicer default message."""
884        if member not in container:
885            standardMsg = '%s not found in %s' % (safe_repr(member),
886                                                  safe_repr(container))
887            self.fail(self._formatMessage(msg, standardMsg))
888
889    def assertNotIn(self, member, container, msg=None):
890        """Just like self.assertTrue(a not in b), but with a nicer default message."""
891        if member in container:
892            standardMsg = '%s unexpectedly found in %s' % (
893                safe_repr(member), safe_repr(container))
894            self.fail(self._formatMessage(msg, standardMsg))
895
896    def assertIs(self, expr1, expr2, msg=None):
897        """Just like self.assertTrue(a is b), but with a nicer default message."""
898        if expr1 is not expr2:
899            standardMsg = '%s is not %s' % (safe_repr(expr1), safe_repr(expr2))
900            self.fail(self._formatMessage(msg, standardMsg))
901
902    def assertIsNot(self, expr1, expr2, msg=None):
903        """Just like self.assertTrue(a is not b), but with a nicer default message."""
904        if expr1 is expr2:
905            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
906            self.fail(self._formatMessage(msg, standardMsg))
907
908    def assertDictEqual(self, d1, d2, msg=None):
909        self.assert_(
910            isinstance(
911                d1,
912                dict),
913            'First argument is not a dictionary')
914        self.assert_(
915            isinstance(
916                d2,
917                dict),
918            'Second argument is not a dictionary')
919
920        if d1 != d2:
921            standardMsg = '%s != %s' % (
922                safe_repr(d1, True), safe_repr(d2, True))
923            diff = ('\n' + '\n'.join(difflib.ndiff(
924                           pprint.pformat(d1).splitlines(),
925                           pprint.pformat(d2).splitlines())))
926            standardMsg = self._truncateMessage(standardMsg, diff)
927            self.fail(self._formatMessage(msg, standardMsg))
928
929    def assertDictContainsSubset(self, expected, actual, msg=None):
930        """Checks whether actual is a superset of expected."""
931        missing = []
932        mismatched = []
933        for key, value in expected.iteritems():
934            if key not in actual:
935                missing.append(key)
936            elif value != actual[key]:
937                mismatched.append('%s, expected: %s, actual: %s' %
938                                  (safe_repr(key), safe_repr(value),
939                                   safe_repr(actual[key])))
940
941        if not (missing or mismatched):
942            return
943
944        standardMsg = ''
945        if missing:
946            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
947                                                   missing)
948        if mismatched:
949            if standardMsg:
950                standardMsg += '; '
951            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
952
953        self.fail(self._formatMessage(msg, standardMsg))
954
955    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
956        """An unordered sequence specific comparison. It asserts that
957        expected_seq and actual_seq contain the same elements. It is
958        the equivalent of::
959
960            self.assertEqual(sorted(expected_seq), sorted(actual_seq))
961
962        Raises with an error message listing which elements of expected_seq
963        are missing from actual_seq and vice versa if any.
964
965        Asserts that each element has the same count in both sequences.
966        Example:
967            - [0, 1, 1] and [1, 0, 1] compare equal.
968            - [0, 0, 1] and [0, 1] compare unequal.
969        """
970        try:
971            expected = sorted(expected_seq)
972            actual = sorted(actual_seq)
973        except TypeError:
974            # Unsortable items (example: set(), complex(), ...)
975            expected = list(expected_seq)
976            actual = list(actual_seq)
977            missing, unexpected = unorderable_list_difference(
978                expected, actual, ignore_duplicate=False
979            )
980        else:
981            return self.assertSequenceEqual(expected, actual, msg=msg)
982
983        errors = []
984        if missing:
985            errors.append('Expected, but missing:\n    %s' %
986                          safe_repr(missing))
987        if unexpected:
988            errors.append('Unexpected, but present:\n    %s' %
989                          safe_repr(unexpected))
990        if errors:
991            standardMsg = '\n'.join(errors)
992            self.fail(self._formatMessage(msg, standardMsg))
993
994    def assertMultiLineEqual(self, first, second, msg=None):
995        """Assert that two multi-line strings are equal."""
996        self.assert_(isinstance(first, six.string_types), (
997            'First argument is not a string'))
998        self.assert_(isinstance(second, six.string_types), (
999            'Second argument is not a string'))
1000
1001        if first != second:
1002            standardMsg = '%s != %s' % (
1003                safe_repr(first, True), safe_repr(second, True))
1004            diff = '\n' + ''.join(difflib.ndiff(first.splitlines(True),
1005                                                second.splitlines(True)))
1006            standardMsg = self._truncateMessage(standardMsg, diff)
1007            self.fail(self._formatMessage(msg, standardMsg))
1008
1009    def assertLess(self, a, b, msg=None):
1010        """Just like self.assertTrue(a < b), but with a nicer default message."""
1011        if not a < b:
1012            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
1013            self.fail(self._formatMessage(msg, standardMsg))
1014
1015    def assertLessEqual(self, a, b, msg=None):
1016        """Just like self.assertTrue(a <= b), but with a nicer default message."""
1017        if not a <= b:
1018            standardMsg = '%s not less than or equal to %s' % (
1019                safe_repr(a), safe_repr(b))
1020            self.fail(self._formatMessage(msg, standardMsg))
1021
1022    def assertGreater(self, a, b, msg=None):
1023        """Just like self.assertTrue(a > b), but with a nicer default message."""
1024        if not a > b:
1025            standardMsg = '%s not greater than %s' % (
1026                safe_repr(a), safe_repr(b))
1027            self.fail(self._formatMessage(msg, standardMsg))
1028
1029    def assertGreaterEqual(self, a, b, msg=None):
1030        """Just like self.assertTrue(a >= b), but with a nicer default message."""
1031        if not a >= b:
1032            standardMsg = '%s not greater than or equal to %s' % (
1033                safe_repr(a), safe_repr(b))
1034            self.fail(self._formatMessage(msg, standardMsg))
1035
1036    def assertIsNone(self, obj, msg=None):
1037        """Same as self.assertTrue(obj is None), with a nicer default message."""
1038        if obj is not None:
1039            standardMsg = '%s is not None' % (safe_repr(obj),)
1040            self.fail(self._formatMessage(msg, standardMsg))
1041
1042    def assertIsNotNone(self, obj, msg=None):
1043        """Included for symmetry with assertIsNone."""
1044        if obj is None:
1045            standardMsg = 'unexpectedly None'
1046            self.fail(self._formatMessage(msg, standardMsg))
1047
1048    def assertIsInstance(self, obj, cls, msg=None):
1049        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
1050        default message."""
1051        if not isinstance(obj, cls):
1052            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
1053            self.fail(self._formatMessage(msg, standardMsg))
1054
1055    def assertNotIsInstance(self, obj, cls, msg=None):
1056        """Included for symmetry with assertIsInstance."""
1057        if isinstance(obj, cls):
1058            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
1059            self.fail(self._formatMessage(msg, standardMsg))
1060
1061    def assertRaisesRegexp(self, expected_exception, expected_regexp,
1062                           callable_obj=None, *args, **kwargs):
1063        """Asserts that the message in a raised exception matches a regexp.
1064
1065        Args:
1066            expected_exception: Exception class expected to be raised.
1067            expected_regexp: Regexp (re pattern object or string) expected
1068                    to be found in error message.
1069            callable_obj: Function to be called.
1070            args: Extra args.
1071            kwargs: Extra kwargs.
1072        """
1073        if callable_obj is None:
1074            return _AssertRaisesContext(
1075                expected_exception, self, expected_regexp)
1076        try:
1077            callable_obj(*args, **kwargs)
1078        except expected_exception as exc_value:
1079            if isinstance(expected_regexp, six.string_types):
1080                expected_regexp = re.compile(expected_regexp)
1081            if not expected_regexp.search(str(exc_value)):
1082                raise self.failureException(
1083                    '"%s" does not match "%s"' %
1084                    (expected_regexp.pattern, str(exc_value)))
1085        else:
1086            if hasattr(expected_exception, '__name__'):
1087                excName = expected_exception.__name__
1088            else:
1089                excName = str(expected_exception)
1090            raise self.failureException("%s not raised" % excName)
1091
1092    def assertRegexpMatches(self, text, expected_regexp, msg=None):
1093        """Fail the test unless the text matches the regular expression."""
1094        if isinstance(expected_regexp, six.string_types):
1095            expected_regexp = re.compile(expected_regexp)
1096        if not expected_regexp.search(text):
1097            msg = msg or "Regexp didn't match"
1098            msg = '%s: %r not found in %r' % (
1099                msg, expected_regexp.pattern, text)
1100            raise self.failureException(msg)
1101
1102    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1103        """Fail the test if the text matches the regular expression."""
1104        if isinstance(unexpected_regexp, six.string_types):
1105            unexpected_regexp = re.compile(unexpected_regexp)
1106        match = unexpected_regexp.search(text)
1107        if match:
1108            msg = msg or "Regexp matched"
1109            msg = '%s: %r matches %r in %r' % (msg,
1110                                               text[match.start():match.end()],
1111                                               unexpected_regexp.pattern,
1112                                               text)
1113            raise self.failureException(msg)
1114
1115
class FunctionTestCase(TestCase):
    """Run a stand-alone test function as a TestCase.

    This is useful for slipping pre-existing test functions into the
    unittest framework.  Optional set-up and tear-down callables may be
    supplied; they are invoked from ``setUp`` and ``tearDown``.  As with
    TestCase, the tear-down callable is always called if the set-up
    callable ran successfully.

    Two instances compare (and hash) equal when they wrap the same test
    function, set-up, tear-down and description.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        # Callables and description are read back by the hooks and by the
        # comparison/representation methods below.
        self._testFunc = testFunc
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._description = description

    def setUp(self):
        """Invoke the optional set-up callable."""
        hook = self._setUpFunc
        if hook is not None:
            hook()

    def tearDown(self):
        """Invoke the optional tear-down callable."""
        hook = self._tearDownFunc
        if hook is not None:
            hook()

    def runTest(self):
        """Call the wrapped test function with no arguments."""
        self._testFunc()

    def id(self):
        """Use the wrapped function's ``__name__`` as the test id."""
        return self._testFunc.__name__

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash((type(self), self._setUpFunc, self._tearDownFunc,
                     self._testFunc, self._description))

    def __str__(self):
        return "%s (%s)" % (strclass(self.__class__), self._testFunc.__name__)

    def __repr__(self):
        return "<%s testFunc=%s>" % (strclass(self.__class__), self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the docstring's first line.

        Returns None when no description was supplied and the wrapped
        function has no docstring, or its first line is blank.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        first_line = doc.split("\n")[0].strip()
        return first_line or None
1175