1"""Test case implementation"""
2
3import sys
4import difflib
5import pprint
6import re
7import unittest
8import warnings
9
10from unittest2 import result
11from unittest2.util import (
12    safe_repr, safe_str, strclass,
13    unorderable_list_difference
14)
15
16from unittest2.compatibility import wraps
17
# Module-level marker.  NOTE(review): presumably looked up in a frame's
# globals by the result machinery so that frames from this module can be
# hidden from reported tracebacks — confirm in unittest2.result.
__unittest = True


# Substituted for a diff that is longer than TestCase.maxDiff; the %s is
# filled with the diff's length in characters (see _truncateMessage).
DIFF_OMITTED = ('\nDiff is %s characters long. '
                 'Set self.maxDiff to None to see it.')
23
class SkipTest(Exception):
    """Signal that the running test should be skipped.

    When raised from setUp() or from a test method, TestCase.run() reports
    the test as skipped and uses str() of the exception as the reason.  The
    skip decorators (skip, skipIf, skipUnless) or TestCase.skipTest() are
    usually more convenient than raising this directly.
    """
    pass
31
32class _ExpectedFailure(Exception):
33    """
34    Raise this when a test is expected to fail.
35
36    This is an implementation detail.
37    """
38
39    def __init__(self, exc_info):
40        # can't use super because Python 2.4 exceptions are old style
41        Exception.__init__(self)
42        self.exc_info = exc_info
43
44class _UnexpectedSuccess(Exception):
45    """
46    The test was supposed to fail, but it didn't!
47    """
48
49def _id(obj):
50    return obj
51
def skip(reason):
    """Return a decorator that unconditionally skips the decorated test.

    A TestCase subclass is marked in place.  Anything else is replaced by a
    wrapper that raises SkipTest(reason) when called.  Either way, the
    returned object carries ``__unittest_skip__ = True`` and
    ``__unittest_skip_why__ = reason``.
    """
    def decorator(test_item):
        if isinstance(test_item, type) and issubclass(test_item, TestCase):
            target = test_item
        else:
            @wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            target = skip_wrapper

        target.__unittest_skip__ = True
        target.__unittest_skip_why__ = reason
        return target
    return decorator
67
def skipIf(condition, reason):
    """Skip the decorated test when ``condition`` is true.

    Returns skip(reason) for a truthy condition; otherwise returns a no-op
    decorator that leaves the test unchanged.
    """
    if not condition:
        return _id
    return skip(reason)
75
def skipUnless(condition, reason):
    """Skip the decorated test unless ``condition`` is true.

    Returns a no-op decorator for a truthy condition; otherwise returns
    skip(reason).
    """
    if condition:
        return _id
    return skip(reason)
83
84
def expectedFailure(func):
    """Mark ``func`` as a test that is expected to fail.

    The wrapper converts any Exception raised by ``func`` into
    _ExpectedFailure (carrying the original sys.exc_info()).  If ``func``
    returns normally, it raises _UnexpectedSuccess instead.  TestCase.run()
    turns these into expected-failure / unexpected-success outcomes.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            raise _ExpectedFailure(sys.exc_info())
        else:
            raise _UnexpectedSuccess
    return wrapper
94
95
96class _AssertRaisesContext(object):
97    """A context manager used to implement TestCase.assertRaises* methods."""
98
99    def __init__(self, expected, test_case, expected_regexp=None):
100        self.expected = expected
101        self.failureException = test_case.failureException
102        self.expected_regexp = expected_regexp
103
104    def __enter__(self):
105        return self
106
107    def __exit__(self, exc_type, exc_value, tb):
108        if exc_type is None:
109            try:
110                exc_name = self.expected.__name__
111            except AttributeError:
112                exc_name = str(self.expected)
113            raise self.failureException(
114                "%s not raised" % (exc_name,))
115        if not issubclass(exc_type, self.expected):
116            # let unexpected exceptions pass through
117            return False
118        self.exception = exc_value # store for later retrieval
119        if self.expected_regexp is None:
120            return True
121
122        expected_regexp = self.expected_regexp
123        if isinstance(expected_regexp, basestring):
124            expected_regexp = re.compile(expected_regexp)
125        if not expected_regexp.search(str(exc_value)):
126            raise self.failureException('"%s" does not match "%s"' %
127                     (expected_regexp.pattern, str(exc_value)))
128        return True
129
130
131class _TypeEqualityDict(object):
132
133    def __init__(self, testcase):
134        self.testcase = testcase
135        self._store = {}
136
137    def __setitem__(self, key, value):
138        self._store[key] = value
139
140    def __getitem__(self, key):
141        value = self._store[key]
142        if isinstance(value, basestring):
143            return getattr(self.testcase, value)
144        return value
145
146    def get(self, key, default=None):
147        if key in self._store:
148            return self[key]
149        return default
150
151
152class TestCase(unittest.TestCase):
153    """A class whose instances are single test cases.
154
155    By default, the test code itself should be placed in a method named
156    'runTest'.
157
158    If the fixture may be used for many test cases, create as
159    many test methods as are needed. When instantiating such a TestCase
160    subclass, specify in the constructor arguments the name of the test method
161    that the instance is to execute.
162
163    Test authors should subclass TestCase for their own tests. Construction
164    and deconstruction of the test's environment ('fixture') can be
165    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
166
167    If it is necessary to override the __init__ method, the base class
168    __init__ method must always be called. It is important that subclasses
169    should not change the signature of their __init__ method, since instances
170    of the classes are instantiated automatically by parts of the framework
171    in order to be run.
172    """
173
174    # This attribute determines which exception will be raised when
175    # the instance's assertion methods fail; test methods raising this
176    # exception will be deemed to have 'failed' rather than 'errored'
177
178    failureException = AssertionError
179
180    # This attribute sets the maximum length of a diff in failure messages
181    # by assert methods using difflib. It is looked up as an instance attribute
182    # so can be configured by individual tests if required.
183
184    maxDiff = 80*8
185
186    # This attribute determines whether long messages (including repr of
187    # objects used in assert methods) will be printed on failure in *addition*
188    # to any explicit message passed.
189
190    longMessage = True
191
192    # Attribute used by TestSuite for classSetUp
193
194    _classSetupFailed = False
195
196    def __init__(self, methodName='runTest'):
197        """Create an instance of the class that will use the named test
198           method when executed. Raises a ValueError if the instance does
199           not have a method with the specified name.
200        """
201        self._testMethodName = methodName
202        self._resultForDoCleanups = None
203        try:
204            testMethod = getattr(self, methodName)
205        except AttributeError:
206            raise ValueError("no such test method in %s: %s" % \
207                  (self.__class__, methodName))
208        self._testMethodDoc = testMethod.__doc__
209        self._cleanups = []
210
211        # Map types to custom assertEqual functions that will compare
212        # instances of said type in more detail to generate a more useful
213        # error message.
214        self._type_equality_funcs = _TypeEqualityDict(self)
215        self.addTypeEqualityFunc(dict, 'assertDictEqual')
216        self.addTypeEqualityFunc(list, 'assertListEqual')
217        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
218        self.addTypeEqualityFunc(set, 'assertSetEqual')
219        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
220        self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual')
221
222    def addTypeEqualityFunc(self, typeobj, function):
223        """Add a type specific assertEqual style function to compare a type.
224
225        This method is for use by TestCase subclasses that need to register
226        their own type equality functions to provide nicer error messages.
227
228        Args:
229            typeobj: The data type to call this function on when both values
230                    are of the same type in assertEqual().
231            function: The callable taking two arguments and an optional
232                    msg= argument that raises self.failureException with a
233                    useful error message when the two arguments are not equal.
234        """
235        self._type_equality_funcs[typeobj] = function
236
237    def addCleanup(self, function, *args, **kwargs):
238        """Add a function, with arguments, to be called when the test is
239        completed. Functions added are called on a LIFO basis and are
240        called after tearDown on test failure or success.
241
242        Cleanup items are called even if setUp fails (unlike tearDown)."""
243        self._cleanups.append((function, args, kwargs))
244
245    def setUp(self):
246        "Hook method for setting up the test fixture before exercising it."
247
248    @classmethod
249    def setUpClass(cls):
250        "Hook method for setting up class fixture before running tests in the class."
251
252    @classmethod
253    def tearDownClass(cls):
254        "Hook method for deconstructing the class fixture after running all tests in the class."
255
256    def tearDown(self):
257        "Hook method for deconstructing the test fixture after testing it."
258
259    def countTestCases(self):
260        return 1
261
262    def defaultTestResult(self):
263        return result.TestResult()
264
265    def shortDescription(self):
266        """Returns a one-line description of the test, or None if no
267        description has been provided.
268
269        The default implementation of this method returns the first line of
270        the specified test method's docstring.
271        """
272        doc = self._testMethodDoc
273        return doc and doc.split("\n")[0].strip() or None
274
275
276    def id(self):
277        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
278
279    def __eq__(self, other):
280        if type(self) is not type(other):
281            return NotImplemented
282
283        return self._testMethodName == other._testMethodName
284
285    def __ne__(self, other):
286        return not self == other
287
288    def __hash__(self):
289        return hash((type(self), self._testMethodName))
290
291    def __str__(self):
292        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
293
294    def __repr__(self):
295        return "<%s testMethod=%s>" % \
296               (strclass(self.__class__), self._testMethodName)
297
298    def _addSkip(self, result, reason):
299        addSkip = getattr(result, 'addSkip', None)
300        if addSkip is not None:
301            addSkip(self, reason)
302        else:
303            warnings.warn("Use of a TestResult without an addSkip method is deprecated",
304                          DeprecationWarning, 2)
305            result.addSuccess(self)
306
307    def run(self, result=None):
308        orig_result = result
309        if result is None:
310            result = self.defaultTestResult()
311            startTestRun = getattr(result, 'startTestRun', None)
312            if startTestRun is not None:
313                startTestRun()
314
315        self._resultForDoCleanups = result
316        result.startTest(self)
317
318        testMethod = getattr(self, self._testMethodName)
319
320        if (getattr(self.__class__, "__unittest_skip__", False) or
321            getattr(testMethod, "__unittest_skip__", False)):
322            # If the class or method was skipped.
323            try:
324                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
325                            or getattr(testMethod, '__unittest_skip_why__', ''))
326                self._addSkip(result, skip_why)
327            finally:
328                result.stopTest(self)
329            return
330        try:
331            success = False
332            try:
333                self.setUp()
334            except SkipTest, e:
335                self._addSkip(result, str(e))
336            except Exception:
337                result.addError(self, sys.exc_info())
338            else:
339                try:
340                    testMethod()
341                except self.failureException:
342                    result.addFailure(self, sys.exc_info())
343                except _ExpectedFailure, e:
344                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
345                    if addExpectedFailure is not None:
346                        addExpectedFailure(self, e.exc_info)
347                    else:
348                        warnings.warn("Use of a TestResult without an addExpectedFailure method is deprecated",
349                                      DeprecationWarning)
350                        result.addSuccess(self)
351                except _UnexpectedSuccess:
352                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
353                    if addUnexpectedSuccess is not None:
354                        addUnexpectedSuccess(self)
355                    else:
356                        warnings.warn("Use of a TestResult without an addUnexpectedSuccess method is deprecated",
357                                      DeprecationWarning)
358                        result.addFailure(self, sys.exc_info())
359                except SkipTest, e:
360                    self._addSkip(result, str(e))
361                except Exception:
362                    result.addError(self, sys.exc_info())
363                else:
364                    success = True
365
366                try:
367                    self.tearDown()
368                except Exception:
369                    result.addError(self, sys.exc_info())
370                    success = False
371
372            cleanUpSuccess = self.doCleanups()
373            success = success and cleanUpSuccess
374            if success:
375                result.addSuccess(self)
376        finally:
377            result.stopTest(self)
378            if orig_result is None:
379                stopTestRun = getattr(result, 'stopTestRun', None)
380                if stopTestRun is not None:
381                    stopTestRun()
382
383    def doCleanups(self):
384        """Execute all cleanup functions. Normally called for you after
385        tearDown."""
386        result = self._resultForDoCleanups
387        ok = True
388        while self._cleanups:
389            function, args, kwargs = self._cleanups.pop(-1)
390            try:
391                function(*args, **kwargs)
392            except Exception:
393                ok = False
394                result.addError(self, sys.exc_info())
395        return ok
396
397    def __call__(self, *args, **kwds):
398        return self.run(*args, **kwds)
399
400    def debug(self):
401        """Run the test without collecting errors in a TestResult"""
402        self.setUp()
403        getattr(self, self._testMethodName)()
404        self.tearDown()
405        while self._cleanups:
406            function, args, kwargs = self._cleanups.pop(-1)
407            function(*args, **kwargs)
408
409    def skipTest(self, reason):
410        """Skip this test."""
411        raise SkipTest(reason)
412
413    def fail(self, msg=None):
414        """Fail immediately, with the given message."""
415        raise self.failureException(msg)
416
417    def assertFalse(self, expr, msg=None):
418        "Fail the test if the expression is true."
419        if expr:
420            msg = self._formatMessage(msg, "%s is not False" % safe_repr(expr))
421            raise self.failureException(msg)
422
423    def assertTrue(self, expr, msg=None):
424        """Fail the test unless the expression is true."""
425        if not expr:
426            msg = self._formatMessage(msg, "%s is not True" % safe_repr(expr))
427            raise self.failureException(msg)
428
429    def _formatMessage(self, msg, standardMsg):
430        """Honour the longMessage attribute when generating failure messages.
431        If longMessage is False this means:
432        * Use only an explicit message if it is provided
433        * Otherwise use the standard message for the assert
434
435        If longMessage is True:
436        * Use the standard message
437        * If an explicit message is provided, plus ' : ' and the explicit message
438        """
439        if not self.longMessage:
440            return msg or standardMsg
441        if msg is None:
442            return standardMsg
443        try:
444            return '%s : %s' % (standardMsg, msg)
445        except UnicodeDecodeError:
446            return '%s : %s' % (safe_str(standardMsg), safe_str(msg))
447
448
449    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
450        """Fail unless an exception of class excClass is thrown
451           by callableObj when invoked with arguments args and keyword
452           arguments kwargs. If a different type of exception is
453           thrown, it will not be caught, and the test case will be
454           deemed to have suffered an error, exactly as for an
455           unexpected exception.
456
457           If called with callableObj omitted or None, will return a
458           context object used like this::
459
460                with self.assertRaises(SomeException):
461                    do_something()
462
463           The context manager keeps a reference to the exception as
464           the 'exception' attribute. This allows you to inspect the
465           exception after the assertion::
466
467               with self.assertRaises(SomeException) as cm:
468                   do_something()
469               the_exception = cm.exception
470               self.assertEqual(the_exception.error_code, 3)
471        """
472        if callableObj is None:
473            return _AssertRaisesContext(excClass, self)
474        try:
475            callableObj(*args, **kwargs)
476        except excClass:
477            return
478
479        if hasattr(excClass,'__name__'):
480            excName = excClass.__name__
481        else:
482            excName = str(excClass)
483        raise self.failureException, "%s not raised" % excName
484
485    def _getAssertEqualityFunc(self, first, second):
486        """Get a detailed comparison function for the types of the two args.
487
488        Returns: A callable accepting (first, second, msg=None) that will
489        raise a failure exception if first != second with a useful human
490        readable error message for those types.
491        """
492        #
493        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
494        # and vice versa.  I opted for the conservative approach in case
495        # subclasses are not intended to be compared in detail to their super
496        # class instances using a type equality func.  This means testing
497        # subtypes won't automagically use the detailed comparison.  Callers
498        # should use their type specific assertSpamEqual method to compare
499        # subclasses if the detailed comparison is desired and appropriate.
500        # See the discussion in http://bugs.python.org/issue2578.
501        #
502        if type(first) is type(second):
503            asserter = self._type_equality_funcs.get(type(first))
504            if asserter is not None:
505                return asserter
506
507        return self._baseAssertEqual
508
509    def _baseAssertEqual(self, first, second, msg=None):
510        """The default assertEqual implementation, not type specific."""
511        if not first == second:
512            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
513            msg = self._formatMessage(msg, standardMsg)
514            raise self.failureException(msg)
515
516    def assertEqual(self, first, second, msg=None):
517        """Fail if the two objects are unequal as determined by the '=='
518           operator.
519        """
520        assertion_func = self._getAssertEqualityFunc(first, second)
521        assertion_func(first, second, msg=msg)
522
523    def assertNotEqual(self, first, second, msg=None):
524        """Fail if the two objects are equal as determined by the '=='
525           operator.
526        """
527        if not first != second:
528            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
529                                                           safe_repr(second)))
530            raise self.failureException(msg)
531
532    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
533        """Fail if the two objects are unequal as determined by their
534           difference rounded to the given number of decimal places
535           (default 7) and comparing to zero, or by comparing that the
536           between the two objects is more than the given delta.
537
538           Note that decimal places (from zero) are usually not the same
539           as significant digits (measured from the most signficant digit).
540
541           If the two objects compare equal then they will automatically
542           compare almost equal.
543        """
544        if first == second:
545            # shortcut
546            return
547        if delta is not None and places is not None:
548            raise TypeError("specify delta or places not both")
549
550        if delta is not None:
551            if abs(first - second) <= delta:
552                return
553
554            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
555                                                        safe_repr(second),
556                                                        safe_repr(delta))
557        else:
558            if places is None:
559                places = 7
560
561            if round(abs(second-first), places) == 0:
562                return
563
564            standardMsg = '%s != %s within %r places' % (safe_repr(first),
565                                                          safe_repr(second),
566                                                          places)
567        msg = self._formatMessage(msg, standardMsg)
568        raise self.failureException(msg)
569
570    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
571        """Fail if the two objects are equal as determined by their
572           difference rounded to the given number of decimal places
573           (default 7) and comparing to zero, or by comparing that the
574           between the two objects is less than the given delta.
575
576           Note that decimal places (from zero) are usually not the same
577           as significant digits (measured from the most signficant digit).
578
579           Objects that are equal automatically fail.
580        """
581        if delta is not None and places is not None:
582            raise TypeError("specify delta or places not both")
583        if delta is not None:
584            if not (first == second) and abs(first - second) > delta:
585                return
586            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
587                                                        safe_repr(second),
588                                                        safe_repr(delta))
589        else:
590            if places is None:
591                places = 7
592            if not (first == second) and round(abs(second-first), places) != 0:
593                return
594            standardMsg = '%s == %s within %r places' % (safe_repr(first),
595                                                         safe_repr(second),
596                                                         places)
597
598        msg = self._formatMessage(msg, standardMsg)
599        raise self.failureException(msg)
600
    # Synonyms for assertion methods

    # The plurals are undocumented.  Keep them that way to discourage use.
    # Do not add more.  Do not remove.
    # Going through a deprecation cycle on these would annoy many people.
    #
    # These names are bound when the class body executes, to the function
    # objects defined above.  A subclass that overrides e.g. assertEqual
    # therefore does not change what assertEquals calls on that subclass.
    assertEquals = assertEqual
    assertNotEquals = assertNotEqual
    assertAlmostEquals = assertAlmostEqual
    assertNotAlmostEquals = assertNotAlmostEqual
    # assert_ is a plain alias of assertTrue.
    assert_ = assertTrue
611
612    # These fail* assertion method names are pending deprecation and will
613    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
614    def _deprecate(original_func):
615        def deprecated_func(*args, **kwargs):
616            warnings.warn(
617                ('Please use %s instead.' % original_func.__name__),
618                PendingDeprecationWarning, 2)
619            return original_func(*args, **kwargs)
620        return deprecated_func
621
622    failUnlessEqual = _deprecate(assertEqual)
623    failIfEqual = _deprecate(assertNotEqual)
624    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
625    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
626    failUnless = _deprecate(assertTrue)
627    failUnlessRaises = _deprecate(assertRaises)
628    failIf = _deprecate(assertFalse)
629
    def assertSequenceEqual(self, seq1, seq2,
                            msg=None, seq_type=None, max_diff=80*8):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        On failure the message contains:
        - a summary line with reprs truncated to 30 characters;
        - the first differing (or unindexable) element;
        - any surplus elements in the longer sequence;
        - an ndiff of the pretty-printed sequences.  The ndiff is subject to
          self.maxDiff via _truncateMessage.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional user message.  It is combined with, or replaces,
                    the standard message according to self.longMessage (see
                    _formatMessage).
            max_diff: NOTE(review): not read by this method.  Diff truncation
                    is controlled by self.maxDiff instead; the parameter is
                    presumably kept only for signature compatibility.
        """
        # Enforce the requested container type before looking at contents.
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                            % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                            % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # ``differing`` accumulates the failure description; it stays None
        # only while both operands still look like sequences.
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length.    Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length.    Non-sequence?' % (
                        seq_type_name)

        if differing is None:
            if seq1 == seq2:
                return

            # Summary line, with each repr cut to 30 characters.
            seq1_repr = repr(seq1)
            seq2_repr = repr(seq2)
            if len(seq1_repr) > 30:
                seq1_repr = seq1_repr[:30] + '...'
            if len(seq2_repr) > 30:
                seq2_repr = seq2_repr[:30] + '...'
            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
            differing = '%ss differ: %s != %s\n' % elements

            # Walk the common prefix and report the first element that cannot
            # be indexed or that differs.  Elements are rendered with %s
            # (str), not repr.
            for i in xrange(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                 (i, item1, item2))
                    break
            else:
                # No element of the common prefix differed.  Same-length
                # sequences of different types (e.g. list vs tuple) are then
                # accepted when no seq_type was requested, even though
                # seq1 == seq2 was false.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return

            # Describe surplus elements in the longer sequence.  "First extra
            # element" means the first of the surplus elements.
            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                             'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, seq1[len2]))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                             'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, seq2[len1]))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        standardMsg = differing
        # Line-by-line ndiff of the pretty-printed operands.
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))

        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)
735
736    def _truncateMessage(self, message, diff):
737        max_diff = self.maxDiff
738        if max_diff is None or len(diff) <= max_diff:
739            return message + diff
740        return message + (DIFF_OMITTED % len(diff))
741
742    def assertListEqual(self, list1, list2, msg=None):
743        """A list-specific equality assertion.
744
745        Args:
746            list1: The first list to compare.
747            list2: The second list to compare.
748            msg: Optional message to use on failure instead of a list of
749                    differences.
750
751        """
752        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
753
754    def assertTupleEqual(self, tuple1, tuple2, msg=None):
755        """A tuple-specific equality assertion.
756
757        Args:
758            tuple1: The first tuple to compare.
759            tuple2: The second tuple to compare.
760            msg: Optional message to use on failure instead of a list of
761                    differences.
762        """
763        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
764
765    def assertSetEqual(self, set1, set2, msg=None):
766        """A set-specific equality assertion.
767
768        Args:
769            set1: The first set to compare.
770            set2: The second set to compare.
771            msg: Optional message to use on failure instead of a list of
772                    differences.
773
774        assertSetEqual uses ducktyping to support
775        different types of sets, and is optimized for sets specifically
776        (parameters must support a difference method).
777        """
778        try:
779            difference1 = set1.difference(set2)
780        except TypeError, e:
781            self.fail('invalid type when attempting set difference: %s' % e)
782        except AttributeError, e:
783            self.fail('first argument does not support set difference: %s' % e)
784
785        try:
786            difference2 = set2.difference(set1)
787        except TypeError, e:
788            self.fail('invalid type when attempting set difference: %s' % e)
789        except AttributeError, e:
790            self.fail('second argument does not support set difference: %s' % e)
791
792        if not (difference1 or difference2):
793            return
794
795        lines = []
796        if difference1:
797            lines.append('Items in the first set but not the second:')
798            for item in difference1:
799                lines.append(repr(item))
800        if difference2:
801            lines.append('Items in the second set but not the first:')
802            for item in difference2:
803                lines.append(repr(item))
804
805        standardMsg = '\n'.join(lines)
806        self.fail(self._formatMessage(msg, standardMsg))
807
808    def assertIn(self, member, container, msg=None):
809        """Just like self.assertTrue(a in b), but with a nicer default message."""
810        if member not in container:
811            standardMsg = '%s not found in %s' % (safe_repr(member),
812                                                   safe_repr(container))
813            self.fail(self._formatMessage(msg, standardMsg))
814
815    def assertNotIn(self, member, container, msg=None):
816        """Just like self.assertTrue(a not in b), but with a nicer default message."""
817        if member in container:
818            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
819                                                            safe_repr(container))
820            self.fail(self._formatMessage(msg, standardMsg))
821
822    def assertIs(self, expr1, expr2, msg=None):
823        """Just like self.assertTrue(a is b), but with a nicer default message."""
824        if expr1 is not expr2:
825            standardMsg = '%s is not %s' % (safe_repr(expr1), safe_repr(expr2))
826            self.fail(self._formatMessage(msg, standardMsg))
827
828    def assertIsNot(self, expr1, expr2, msg=None):
829        """Just like self.assertTrue(a is not b), but with a nicer default message."""
830        if expr1 is expr2:
831            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
832            self.fail(self._formatMessage(msg, standardMsg))
833
834    def assertDictEqual(self, d1, d2, msg=None):
835        self.assert_(isinstance(d1, dict), 'First argument is not a dictionary')
836        self.assert_(isinstance(d2, dict), 'Second argument is not a dictionary')
837
838        if d1 != d2:
839            standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
840            diff = ('\n' + '\n'.join(difflib.ndiff(
841                           pprint.pformat(d1).splitlines(),
842                           pprint.pformat(d2).splitlines())))
843            standardMsg = self._truncateMessage(standardMsg, diff)
844            self.fail(self._formatMessage(msg, standardMsg))
845
846    def assertDictContainsSubset(self, expected, actual, msg=None):
847        """Checks whether actual is a superset of expected."""
848        missing = []
849        mismatched = []
850        for key, value in expected.iteritems():
851            if key not in actual:
852                missing.append(key)
853            elif value != actual[key]:
854                mismatched.append('%s, expected: %s, actual: %s' %
855                                  (safe_repr(key), safe_repr(value),
856                                   safe_repr(actual[key])))
857
858        if not (missing or mismatched):
859            return
860
861        standardMsg = ''
862        if missing:
863            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
864                                                    missing)
865        if mismatched:
866            if standardMsg:
867                standardMsg += '; '
868            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
869
870        self.fail(self._formatMessage(msg, standardMsg))
871
872    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
873        """An unordered sequence specific comparison. It asserts that
874        expected_seq and actual_seq contain the same elements. It is
875        the equivalent of::
876
877            self.assertEqual(sorted(expected_seq), sorted(actual_seq))
878
879        Raises with an error message listing which elements of expected_seq
880        are missing from actual_seq and vice versa if any.
881
882        Asserts that each element has the same count in both sequences.
883        Example:
884            - [0, 1, 1] and [1, 0, 1] compare equal.
885            - [0, 0, 1] and [0, 1] compare unequal.
886        """
887        try:
888            expected = sorted(expected_seq)
889            actual = sorted(actual_seq)
890        except TypeError:
891            # Unsortable items (example: set(), complex(), ...)
892            expected = list(expected_seq)
893            actual = list(actual_seq)
894            missing, unexpected = unorderable_list_difference(
895                expected, actual, ignore_duplicate=False
896            )
897        else:
898            return self.assertSequenceEqual(expected, actual, msg=msg)
899
900        errors = []
901        if missing:
902            errors.append('Expected, but missing:\n    %s' %
903                           safe_repr(missing))
904        if unexpected:
905            errors.append('Unexpected, but present:\n    %s' %
906                           safe_repr(unexpected))
907        if errors:
908            standardMsg = '\n'.join(errors)
909            self.fail(self._formatMessage(msg, standardMsg))
910
911    def assertMultiLineEqual(self, first, second, msg=None):
912        """Assert that two multi-line strings are equal."""
913        self.assert_(isinstance(first, basestring), (
914                'First argument is not a string'))
915        self.assert_(isinstance(second, basestring), (
916                'Second argument is not a string'))
917
918        if first != second:
919            standardMsg = '%s != %s' % (safe_repr(first, True), safe_repr(second, True))
920            diff = '\n' + ''.join(difflib.ndiff(first.splitlines(True),
921                                                       second.splitlines(True)))
922            standardMsg = self._truncateMessage(standardMsg, diff)
923            self.fail(self._formatMessage(msg, standardMsg))
924
925    def assertLess(self, a, b, msg=None):
926        """Just like self.assertTrue(a < b), but with a nicer default message."""
927        if not a < b:
928            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
929            self.fail(self._formatMessage(msg, standardMsg))
930
931    def assertLessEqual(self, a, b, msg=None):
932        """Just like self.assertTrue(a <= b), but with a nicer default message."""
933        if not a <= b:
934            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
935            self.fail(self._formatMessage(msg, standardMsg))
936
937    def assertGreater(self, a, b, msg=None):
938        """Just like self.assertTrue(a > b), but with a nicer default message."""
939        if not a > b:
940            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
941            self.fail(self._formatMessage(msg, standardMsg))
942
943    def assertGreaterEqual(self, a, b, msg=None):
944        """Just like self.assertTrue(a >= b), but with a nicer default message."""
945        if not a >= b:
946            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
947            self.fail(self._formatMessage(msg, standardMsg))
948
949    def assertIsNone(self, obj, msg=None):
950        """Same as self.assertTrue(obj is None), with a nicer default message."""
951        if obj is not None:
952            standardMsg = '%s is not None' % (safe_repr(obj),)
953            self.fail(self._formatMessage(msg, standardMsg))
954
955    def assertIsNotNone(self, obj, msg=None):
956        """Included for symmetry with assertIsNone."""
957        if obj is None:
958            standardMsg = 'unexpectedly None'
959            self.fail(self._formatMessage(msg, standardMsg))
960
961    def assertIsInstance(self, obj, cls, msg=None):
962        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
963        default message."""
964        if not isinstance(obj, cls):
965            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
966            self.fail(self._formatMessage(msg, standardMsg))
967
968    def assertNotIsInstance(self, obj, cls, msg=None):
969        """Included for symmetry with assertIsInstance."""
970        if isinstance(obj, cls):
971            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
972            self.fail(self._formatMessage(msg, standardMsg))
973
974    def assertRaisesRegexp(self, expected_exception, expected_regexp,
975                           callable_obj=None, *args, **kwargs):
976        """Asserts that the message in a raised exception matches a regexp.
977
978        Args:
979            expected_exception: Exception class expected to be raised.
980            expected_regexp: Regexp (re pattern object or string) expected
981                    to be found in error message.
982            callable_obj: Function to be called.
983            args: Extra args.
984            kwargs: Extra kwargs.
985        """
986        if callable_obj is None:
987            return _AssertRaisesContext(expected_exception, self, expected_regexp)
988        try:
989            callable_obj(*args, **kwargs)
990        except expected_exception, exc_value:
991            if isinstance(expected_regexp, basestring):
992                expected_regexp = re.compile(expected_regexp)
993            if not expected_regexp.search(str(exc_value)):
994                raise self.failureException('"%s" does not match "%s"' %
995                         (expected_regexp.pattern, str(exc_value)))
996        else:
997            if hasattr(expected_exception, '__name__'):
998                excName = expected_exception.__name__
999            else:
1000                excName = str(expected_exception)
1001            raise self.failureException, "%s not raised" % excName
1002
1003
1004    def assertRegexpMatches(self, text, expected_regexp, msg=None):
1005        """Fail the test unless the text matches the regular expression."""
1006        if isinstance(expected_regexp, basestring):
1007            expected_regexp = re.compile(expected_regexp)
1008        if not expected_regexp.search(text):
1009            msg = msg or "Regexp didn't match"
1010            msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
1011            raise self.failureException(msg)
1012
1013    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1014        """Fail the test if the text matches the regular expression."""
1015        if isinstance(unexpected_regexp, basestring):
1016            unexpected_regexp = re.compile(unexpected_regexp)
1017        match = unexpected_regexp.search(text)
1018        if match:
1019            msg = msg or "Regexp matched"
1020            msg = '%s: %r matches %r in %r' % (msg,
1021                                               text[match.start():match.end()],
1022                                               unexpected_regexp.pattern,
1023                                               text)
1024            raise self.failureException(msg)
1025
class FunctionTestCase(TestCase):
    """A TestCase that runs a plain callable as its test.

    Lets existing test functions be slotted into the unittest framework.
    Optional set-up and tear-down callables may be supplied. As with any
    TestCase, the tear-down ('tearDown') function will always be called if
    the set-up ('setUp') function ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._testFunc = testFunc
        self._description = description

    def setUp(self):
        # Run the user-supplied set-up callable, if one was given.
        before = self._setUpFunc
        if before is not None:
            before()

    def tearDown(self):
        # Run the user-supplied tear-down callable, if one was given.
        after = self._tearDownFunc
        if after is not None:
            after()

    def runTest(self):
        self._testFunc()

    def id(self):
        # Identify the test by the wrapped function's name.
        return self._testFunc.__name__

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented

        # Compare the four defining attributes in order, stopping at the
        # first falsy comparison (same result as an `and` chain).
        outcome = True
        for attr in ('_setUpFunc', '_tearDownFunc', '_testFunc',
                     '_description'):
            outcome = getattr(self, attr) == getattr(other, attr)
            if not outcome:
                break
        return outcome

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        identity = (type(self), self._setUpFunc, self._tearDownFunc,
                    self._testFunc, self._description)
        return hash(identity)

    def __str__(self):
        owner = strclass(self.__class__)
        return "%s (%s)" % (owner, self._testFunc.__name__)

    def __repr__(self):
        owner = strclass(self.__class__)
        return "<%s testFunc=%s>" % (owner, self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the first docstring line.

        Returns None when there is no description and the wrapped function's
        docstring is missing, empty, or has a blank first line.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        summary = doc.split("\n")[0].strip()
        return summary or None
1085