1"""Test case implementation"""
2
3import collections
4import sys
5import functools
6import difflib
7import pprint
8import re
9import types
10import warnings
11
12from . import result
13from .util import (
14    strclass, safe_repr, unorderable_list_difference,
15    _count_diff_all_purpose, _count_diff_hashable
16)
17
18
# Module marker; presumably looked up by the result/traceback machinery to
# hide this module's frames from failure tracebacks -- verify in result.py.
__unittest = True


# Note appended by TestCase._truncateMessage in place of a diff longer than
# maxDiff; the %s receives the diff's length in characters.
DIFF_OMITTED = ('\nDiff is %s characters long. '
                 'Set self.maxDiff to None to see it.')
24
class SkipTest(Exception):
    """
    Signal that the running test should be reported as skipped.

    Raising it from a test method or from setUp is the low-level mechanism;
    TestCase.skipTest() and the skip/skipIf/skipUnless decorators are the
    usual, higher-level ways to get the same effect.  The exception's string
    form is used as the skip reason.
    """
33
class _ExpectedFailure(Exception):
    """
    Internal signal: a test wrapped by expectedFailure failed, as intended.

    The ``sys.exc_info()`` triple of the underlying failure travels on the
    ``exc_info`` attribute so TestCase.run can pass it to
    ``addExpectedFailure``.  This is an implementation detail.
    """

    def __init__(self, exc_info):
        # The base initialiser is deliberately called with no arguments, so
        # ``args`` stays empty; the payload lives only on ``exc_info``.
        super(_ExpectedFailure, self).__init__()
        self.exc_info = exc_info
44
class _UnexpectedSuccess(Exception):
    """
    Internal signal: a test wrapped by expectedFailure completed without
    raising, i.e. the expected failure did not happen.
    """
50
def _id(obj):
    """No-op decorator: give *obj* back untouched.

    skipIf/skipUnless return this when the test should run normally.
    """
    return obj
53
def skip(reason):
    """
    Unconditionally skip a test.

    Returns a decorator usable on a test function/method or on a whole
    class.  A function is replaced by a ``functools.wraps`` wrapper that
    raises ``SkipTest(reason)`` when called; a class is returned as-is.
    Either way the returned object is tagged with ``__unittest_skip__ = True``
    and ``__unittest_skip_why__ = reason``, which TestCase.run inspects
    before running anything.
    """
    # types.ClassType (old-style classes) only exists on Python 2.  Looking it
    # up with a fallback keeps old-style class support there without making
    # the decorator crash with AttributeError where the attribute is absent.
    class_types = (type, getattr(types, 'ClassType', type))

    def decorator(test_item):
        if not isinstance(test_item, class_types):
            @functools.wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            test_item = skip_wrapper

        test_item.__unittest_skip__ = True
        test_item.__unittest_skip_why__ = reason
        return test_item
    return decorator
69
def skipIf(condition, reason):
    """
    Skip a test if the condition is true.

    Returns the skip(reason) decorator for a truthy *condition*, otherwise
    the no-op decorator _id.
    """
    return skip(reason) if condition else _id
77
def skipUnless(condition, reason):
    """
    Skip a test unless the condition is true.

    Returns the no-op decorator _id for a truthy *condition*, otherwise the
    skip(reason) decorator.
    """
    return _id if condition else skip(reason)
85
86
def expectedFailure(func):
    """Mark *func* as a test that is expected to fail.

    The returned wrapper converts any Exception raised by *func* into
    _ExpectedFailure (carrying the original sys.exc_info()), and a normal
    return into _UnexpectedSuccess; TestCase.run reports each accordingly.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            raise _ExpectedFailure(sys.exc_info())
        else:
            raise _UnexpectedSuccess
    return wrapper
96
97
class _AssertRaisesContext(object):
    """A context manager used to implement TestCase.assertRaises* methods.

    On exit it checks that an exception of (a subclass of) ``expected`` was
    raised in the ``with`` body, optionally that ``str()`` of it matches
    ``expected_regexp``, and keeps the caught exception as ``exception``.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        """
        Args:
            expected: exception class (or anything issubclass accepts, such
                as a tuple of classes) that must be raised.
            test_case: object whose ``failureException`` is raised when the
                expectation is not met.
            expected_regexp: optional compiled regular expression, or a
                pattern string, that must ``search``-match ``str(exception)``.
        """
        self.expected = expected
        self.failureException = test_case.failureException
        self.expected_regexp = expected_regexp

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Check whatever the ``with`` body raised.

        Returns True to swallow a matching exception and False to let an
        unrelated one propagate; raises ``failureException`` when nothing was
        raised or the message does not match ``expected_regexp``.
        """
        if exc_type is None:
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                # e.g. a tuple of exception classes has no __name__
                exc_name = str(self.expected)
            raise self.failureException(
                "{0} not raised".format(exc_name))
        if not issubclass(exc_type, self.expected):
            # let unexpected exceptions pass through
            return False
        self.exception = exc_value # store for later retrieval
        if self.expected_regexp is None:
            return True

        expected_regexp = self.expected_regexp
        # Accept a plain pattern string as well as a compiled regex object;
        # previously a string crashed here with AttributeError on .search.
        if not hasattr(expected_regexp, 'search'):
            expected_regexp = re.compile(expected_regexp)
        if not expected_regexp.search(str(exc_value)):
            raise self.failureException('"%s" does not match "%s"' %
                     (expected_regexp.pattern, str(exc_value)))
        return True
129
130
class TestCase(object):
    """A class whose instances are single test cases.

    By default, the test code itself should be placed in a method named
    'runTest'.

    If the fixture may be used for many test cases, create as
    many test methods as are needed. When instantiating such a TestCase
    subclass, specify in the constructor arguments the name of the test method
    that the instance is to execute.

    Test authors should subclass TestCase for their own tests. Set-up and
    tear-down of the test's environment ('fixture') can be implemented by
    overriding the 'setUp' and 'tearDown' methods respectively.

    If it is necessary to override the __init__ method, the base class
    __init__ method must always be called. Subclasses must not change the
    signature of their __init__ method, since instances of the classes are
    instantiated automatically by parts of the framework in order to be run.

    When subclassing TestCase, you can set these attributes:
    * failureException: determines which exception will be raised when
        the instance's assertion methods fail; test methods raising this
        exception will be deemed to have 'failed' rather than 'errored'.
    * longMessage: determines whether long messages (including repr of
        objects used in assert methods) will be printed on failure in *addition*
        to any explicit message passed.
    * maxDiff: sets the maximum length of a diff in failure messages
        by assert methods using difflib. It is looked up as an instance
        attribute so can be configured by individual tests if required.
    """

    # Exception class raised by fail() and the assert* methods; run() reports
    # it via addFailure rather than addError.
    failureException = AssertionError

    # False: an explicit msg replaces the standard failure message.
    # True: the explicit msg is appended to it (see _formatMessage).
    longMessage = False

    # Longest diff, in characters, included in a failure message; longer
    # diffs are replaced by the DIFF_OMITTED note.  None means no limit.
    maxDiff = 80*8

    # If a string is longer than _diffThreshold, use normal comparison instead
    # of difflib.  See #11763.
    _diffThreshold = 2**16

    # Attribute used by TestSuite for classSetUp
    # NOTE(review): not touched in this part of the file; presumably set by
    # TestSuite when setUpClass raises -- confirm against suite.py.

    _classSetupFailed = False
177
178    def __init__(self, methodName='runTest'):
179        """Create an instance of the class that will use the named test
180           method when executed. Raises a ValueError if the instance does
181           not have a method with the specified name.
182        """
183        self._testMethodName = methodName
184        self._resultForDoCleanups = None
185        try:
186            testMethod = getattr(self, methodName)
187        except AttributeError:
188            raise ValueError("no such test method in %s: %s" %
189                  (self.__class__, methodName))
190        self._testMethodDoc = testMethod.__doc__
191        self._cleanups = []
192
193        # Map types to custom assertEqual functions that will compare
194        # instances of said type in more detail to generate a more useful
195        # error message.
196        self._type_equality_funcs = {}
197        self.addTypeEqualityFunc(dict, 'assertDictEqual')
198        self.addTypeEqualityFunc(list, 'assertListEqual')
199        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
200        self.addTypeEqualityFunc(set, 'assertSetEqual')
201        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
202        try:
203            self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual')
204        except NameError:
205            # No unicode support in this build
206            pass
207
208    def addTypeEqualityFunc(self, typeobj, function):
209        """Add a type specific assertEqual style function to compare a type.
210
211        This method is for use by TestCase subclasses that need to register
212        their own type equality functions to provide nicer error messages.
213
214        Args:
215            typeobj: The data type to call this function on when both values
216                    are of the same type in assertEqual().
217            function: The callable taking two arguments and an optional
218                    msg= argument that raises self.failureException with a
219                    useful error message when the two arguments are not equal.
220        """
221        self._type_equality_funcs[typeobj] = function
222
223    def addCleanup(self, function, *args, **kwargs):
224        """Add a function, with arguments, to be called when the test is
225        completed. Functions added are called on a LIFO basis and are
226        called after tearDown on test failure or success.
227
228        Cleanup items are called even if setUp fails (unlike tearDown)."""
229        self._cleanups.append((function, args, kwargs))
230
231    def setUp(self):
232        "Hook method for setting up the test fixture before exercising it."
233        pass
234
235    def tearDown(self):
236        "Hook method for deconstructing the test fixture after testing it."
237        pass
238
239    @classmethod
240    def setUpClass(cls):
241        "Hook method for setting up class fixture before running tests in the class."
242
243    @classmethod
244    def tearDownClass(cls):
245        "Hook method for deconstructing the class fixture after running all tests in the class."
246
247    def countTestCases(self):
248        return 1
249
250    def defaultTestResult(self):
251        return result.TestResult()
252
253    def shortDescription(self):
254        """Returns a one-line description of the test, or None if no
255        description has been provided.
256
257        The default implementation of this method returns the first line of
258        the specified test method's docstring.
259        """
260        doc = self._testMethodDoc
261        return doc and doc.split("\n")[0].strip() or None
262
263
264    def id(self):
265        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
266
267    def __eq__(self, other):
268        if type(self) is not type(other):
269            return NotImplemented
270
271        return self._testMethodName == other._testMethodName
272
273    def __ne__(self, other):
274        return not self == other
275
276    def __hash__(self):
277        return hash((type(self), self._testMethodName))
278
279    def __str__(self):
280        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
281
282    def __repr__(self):
283        return "<%s testMethod=%s>" % \
284               (strclass(self.__class__), self._testMethodName)
285
286    def _addSkip(self, result, reason):
287        addSkip = getattr(result, 'addSkip', None)
288        if addSkip is not None:
289            addSkip(self, reason)
290        else:
291            warnings.warn("TestResult has no addSkip method, skips not reported",
292                          RuntimeWarning, 2)
293            result.addSuccess(self)
294
295    def run(self, result=None):
296        orig_result = result
297        if result is None:
298            result = self.defaultTestResult()
299            startTestRun = getattr(result, 'startTestRun', None)
300            if startTestRun is not None:
301                startTestRun()
302
303        self._resultForDoCleanups = result
304        result.startTest(self)
305
306        testMethod = getattr(self, self._testMethodName)
307        if (getattr(self.__class__, "__unittest_skip__", False) or
308            getattr(testMethod, "__unittest_skip__", False)):
309            # If the class or method was skipped.
310            try:
311                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
312                            or getattr(testMethod, '__unittest_skip_why__', ''))
313                self._addSkip(result, skip_why)
314            finally:
315                result.stopTest(self)
316            return
317        try:
318            success = False
319            try:
320                self.setUp()
321            except SkipTest as e:
322                self._addSkip(result, str(e))
323            except KeyboardInterrupt:
324                raise
325            except:
326                result.addError(self, sys.exc_info())
327            else:
328                try:
329                    testMethod()
330                except KeyboardInterrupt:
331                    raise
332                except self.failureException:
333                    result.addFailure(self, sys.exc_info())
334                except _ExpectedFailure as e:
335                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
336                    if addExpectedFailure is not None:
337                        addExpectedFailure(self, e.exc_info)
338                    else:
339                        warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
340                                      RuntimeWarning)
341                        result.addSuccess(self)
342                except _UnexpectedSuccess:
343                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
344                    if addUnexpectedSuccess is not None:
345                        addUnexpectedSuccess(self)
346                    else:
347                        warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures",
348                                      RuntimeWarning)
349                        result.addFailure(self, sys.exc_info())
350                except SkipTest as e:
351                    self._addSkip(result, str(e))
352                except:
353                    result.addError(self, sys.exc_info())
354                else:
355                    success = True
356
357                try:
358                    self.tearDown()
359                except KeyboardInterrupt:
360                    raise
361                except:
362                    result.addError(self, sys.exc_info())
363                    success = False
364
365            cleanUpSuccess = self.doCleanups()
366            success = success and cleanUpSuccess
367            if success:
368                result.addSuccess(self)
369        finally:
370            result.stopTest(self)
371            if orig_result is None:
372                stopTestRun = getattr(result, 'stopTestRun', None)
373                if stopTestRun is not None:
374                    stopTestRun()
375
376    def doCleanups(self):
377        """Execute all cleanup functions. Normally called for you after
378        tearDown."""
379        result = self._resultForDoCleanups
380        ok = True
381        while self._cleanups:
382            function, args, kwargs = self._cleanups.pop(-1)
383            try:
384                function(*args, **kwargs)
385            except KeyboardInterrupt:
386                raise
387            except:
388                ok = False
389                result.addError(self, sys.exc_info())
390        return ok
391
392    def __call__(self, *args, **kwds):
393        return self.run(*args, **kwds)
394
395    def debug(self):
396        """Run the test without collecting errors in a TestResult"""
397        self.setUp()
398        getattr(self, self._testMethodName)()
399        self.tearDown()
400        while self._cleanups:
401            function, args, kwargs = self._cleanups.pop(-1)
402            function(*args, **kwargs)
403
404    def skipTest(self, reason):
405        """Skip this test."""
406        raise SkipTest(reason)
407
408    def fail(self, msg=None):
409        """Fail immediately, with the given message."""
410        raise self.failureException(msg)
411
412    def assertFalse(self, expr, msg=None):
413        """Check that the expression is false."""
414        if expr:
415            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
416            raise self.failureException(msg)
417
418    def assertTrue(self, expr, msg=None):
419        """Check that the expression is true."""
420        if not expr:
421            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
422            raise self.failureException(msg)
423
424    def _formatMessage(self, msg, standardMsg):
425        """Honour the longMessage attribute when generating failure messages.
426        If longMessage is False this means:
427        * Use only an explicit message if it is provided
428        * Otherwise use the standard message for the assert
429
430        If longMessage is True:
431        * Use the standard message
432        * If an explicit message is provided, plus ' : ' and the explicit message
433        """
434        if not self.longMessage:
435            return msg or standardMsg
436        if msg is None:
437            return standardMsg
438        try:
439            # don't switch to '{}' formatting in Python 2.X
440            # it changes the way unicode input is handled
441            return '%s : %s' % (standardMsg, msg)
442        except UnicodeDecodeError:
443            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
444
445
446    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
447        """Fail unless an exception of class excClass is raised
448           by callableObj when invoked with arguments args and keyword
449           arguments kwargs. If a different type of exception is
450           raised, it will not be caught, and the test case will be
451           deemed to have suffered an error, exactly as for an
452           unexpected exception.
453
454           If called with callableObj omitted or None, will return a
455           context object used like this::
456
457                with self.assertRaises(SomeException):
458                    do_something()
459
460           The context manager keeps a reference to the exception as
461           the 'exception' attribute. This allows you to inspect the
462           exception after the assertion::
463
464               with self.assertRaises(SomeException) as cm:
465                   do_something()
466               the_exception = cm.exception
467               self.assertEqual(the_exception.error_code, 3)
468        """
469        context = _AssertRaisesContext(excClass, self)
470        if callableObj is None:
471            return context
472        with context:
473            callableObj(*args, **kwargs)
474
475    def _getAssertEqualityFunc(self, first, second):
476        """Get a detailed comparison function for the types of the two args.
477
478        Returns: A callable accepting (first, second, msg=None) that will
479        raise a failure exception if first != second with a useful human
480        readable error message for those types.
481        """
482        #
483        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
484        # and vice versa.  I opted for the conservative approach in case
485        # subclasses are not intended to be compared in detail to their super
486        # class instances using a type equality func.  This means testing
487        # subtypes won't automagically use the detailed comparison.  Callers
488        # should use their type specific assertSpamEqual method to compare
489        # subclasses if the detailed comparison is desired and appropriate.
490        # See the discussion in http://bugs.python.org/issue2578.
491        #
492        if type(first) is type(second):
493            asserter = self._type_equality_funcs.get(type(first))
494            if asserter is not None:
495                if isinstance(asserter, basestring):
496                    asserter = getattr(self, asserter)
497                return asserter
498
499        return self._baseAssertEqual
500
501    def _baseAssertEqual(self, first, second, msg=None):
502        """The default assertEqual implementation, not type specific."""
503        if not first == second:
504            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
505            msg = self._formatMessage(msg, standardMsg)
506            raise self.failureException(msg)
507
508    def assertEqual(self, first, second, msg=None):
509        """Fail if the two objects are unequal as determined by the '=='
510           operator.
511        """
512        assertion_func = self._getAssertEqualityFunc(first, second)
513        assertion_func(first, second, msg=msg)
514
515    def assertNotEqual(self, first, second, msg=None):
516        """Fail if the two objects are equal as determined by the '!='
517           operator.
518        """
519        if not first != second:
520            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
521                                                          safe_repr(second)))
522            raise self.failureException(msg)
523
524
525    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
526        """Fail if the two objects are unequal as determined by their
527           difference rounded to the given number of decimal places
528           (default 7) and comparing to zero, or by comparing that the
529           difference between the two objects is more than the given
530           delta.
531
532           Note that decimal places (from zero) are usually not the same
533           as significant digits (measured from the most significant digit).
534
535           If the two objects compare equal then they will automatically
536           compare almost equal.
537        """
538        if first == second:
539            # shortcut
540            return
541        if delta is not None and places is not None:
542            raise TypeError("specify delta or places not both")
543
544        if delta is not None:
545            if abs(first - second) <= delta:
546                return
547
548            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
549                                                        safe_repr(second),
550                                                        safe_repr(delta))
551        else:
552            if places is None:
553                places = 7
554
555            if round(abs(second-first), places) == 0:
556                return
557
558            standardMsg = '%s != %s within %r places' % (safe_repr(first),
559                                                          safe_repr(second),
560                                                          places)
561        msg = self._formatMessage(msg, standardMsg)
562        raise self.failureException(msg)
563
564    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
565        """Fail if the two objects are equal as determined by their
566           difference rounded to the given number of decimal places
567           (default 7) and comparing to zero, or by comparing that the
568           difference between the two objects is less than the given delta.
569
570           Note that decimal places (from zero) are usually not the same
571           as significant digits (measured from the most significant digit).
572
573           Objects that are equal automatically fail.
574        """
575        if delta is not None and places is not None:
576            raise TypeError("specify delta or places not both")
577        if delta is not None:
578            if not (first == second) and abs(first - second) > delta:
579                return
580            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
581                                                        safe_repr(second),
582                                                        safe_repr(delta))
583        else:
584            if places is None:
585                places = 7
586            if not (first == second) and round(abs(second-first), places) != 0:
587                return
588            standardMsg = '%s == %s within %r places' % (safe_repr(first),
589                                                         safe_repr(second),
590                                                         places)
591
592        msg = self._formatMessage(msg, standardMsg)
593        raise self.failureException(msg)
594
    # Synonyms for assertion methods

    # The plurals are undocumented.  Keep them that way to discourage use.
    # Do not add more.  Do not remove.
    # Going through a deprecation cycle on these would annoy many people.
    # Each alias is bound to the very same function object (no warning), so
    # a subclass overriding e.g. assertEqual does NOT change assertEquals.
    assertEquals = assertEqual
    assertNotEquals = assertNotEqual
    assertAlmostEquals = assertAlmostEqual
    assertNotAlmostEquals = assertNotAlmostEqual
    assert_ = assertTrue
605
606    # These fail* assertion method names are pending deprecation and will
607    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
608    def _deprecate(original_func):
609        def deprecated_func(*args, **kwargs):
610            warnings.warn(
611                'Please use {0} instead.'.format(original_func.__name__),
612                PendingDeprecationWarning, 2)
613            return original_func(*args, **kwargs)
614        return deprecated_func
615
616    failUnlessEqual = _deprecate(assertEqual)
617    failIfEqual = _deprecate(assertNotEqual)
618    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
619    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
620    failUnless = _deprecate(assertTrue)
621    failUnlessRaises = _deprecate(assertRaises)
622    failIf = _deprecate(assertFalse)
623
    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional message to use on failure instead of a list of
                    differences (with longMessage set it is appended
                    instead; see _formatMessage).

        Raises:
            failureException: if a sequence is not an instance of seq_type,
                or if the sequences differ.  The message names the first
                differing index (or the surplus elements) and then carries an
                ndiff of the pretty-printed sequences, replaced by a length
                note when longer than maxDiff.

        When seq_type is None, two sequences of different types with equal
        length and pairwise-equal elements are accepted.
        """
        # Enforce seq_type first; a type mismatch fails at once, with no diff.
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # An argument without len() is reported as a non-sequence; the
        # element-wise walk below is then skipped and only the diff follows.
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            differing = 'First %s has no length.    Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length.    Non-sequence?' % (
                        seq_type_name)

        if differing is None:
            if seq1 == seq2:
                return

            # Summary line with reprs clipped to 30 characters; the full
            # contents appear in the diff appended at the end.
            seq1_repr = safe_repr(seq1)
            seq2_repr = safe_repr(seq2)
            if len(seq1_repr) > 30:
                seq1_repr = seq1_repr[:30] + '...'
            if len(seq2_repr) > 30:
                seq2_repr = seq2_repr[:30] + '...'
            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
            differing = '%ss differ: %s != %s\n' % elements

            # Walk the common prefix and stop at the first index that differs
            # or cannot be indexed.
            for i in xrange(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                 (i, safe_repr(item1), safe_repr(item2)))
                    break
            else:
                # Loop finished without a break: the common prefix matches.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return

            # Describe any surplus elements on the longer side ("first extra
            # element" means the first of the surplus ones).
            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                             'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, safe_repr(seq1[len2])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                             'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, safe_repr(seq2[len1])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        # Append a line diff of the pretty-printed sequences, subject to the
        # maxDiff limit applied by _truncateMessage.
        standardMsg = differing
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))
        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)
726
727    def _truncateMessage(self, message, diff):
728        max_diff = self.maxDiff
729        if max_diff is None or len(diff) <= max_diff:
730            return message + diff
731        return message + (DIFF_OMITTED % len(diff))
732
733    def assertListEqual(self, list1, list2, msg=None):
734        """A list-specific equality assertion.
735
736        Args:
737            list1: The first list to compare.
738            list2: The second list to compare.
739            msg: Optional message to use on failure instead of a list of
740                    differences.
741
742        """
743        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
744
745    def assertTupleEqual(self, tuple1, tuple2, msg=None):
746        """A tuple-specific equality assertion.
747
748        Args:
749            tuple1: The first tuple to compare.
750            tuple2: The second tuple to compare.
751            msg: Optional message to use on failure instead of a list of
752                    differences.
753        """
754        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
755
756    def assertSetEqual(self, set1, set2, msg=None):
757        """A set-specific equality assertion.
758
759        Args:
760            set1: The first set to compare.
761            set2: The second set to compare.
762            msg: Optional message to use on failure instead of a list of
763                    differences.
764
765        assertSetEqual uses ducktyping to support different types of sets, and
766        is optimized for sets specifically (parameters must support a
767        difference method).
768        """
769        try:
770            difference1 = set1.difference(set2)
771        except TypeError, e:
772            self.fail('invalid type when attempting set difference: %s' % e)
773        except AttributeError, e:
774            self.fail('first argument does not support set difference: %s' % e)
775
776        try:
777            difference2 = set2.difference(set1)
778        except TypeError, e:
779            self.fail('invalid type when attempting set difference: %s' % e)
780        except AttributeError, e:
781            self.fail('second argument does not support set difference: %s' % e)
782
783        if not (difference1 or difference2):
784            return
785
786        lines = []
787        if difference1:
788            lines.append('Items in the first set but not the second:')
789            for item in difference1:
790                lines.append(repr(item))
791        if difference2:
792            lines.append('Items in the second set but not the first:')
793            for item in difference2:
794                lines.append(repr(item))
795
796        standardMsg = '\n'.join(lines)
797        self.fail(self._formatMessage(msg, standardMsg))
798
799    def assertIn(self, member, container, msg=None):
800        """Just like self.assertTrue(a in b), but with a nicer default message."""
801        if member not in container:
802            standardMsg = '%s not found in %s' % (safe_repr(member),
803                                                  safe_repr(container))
804            self.fail(self._formatMessage(msg, standardMsg))
805
806    def assertNotIn(self, member, container, msg=None):
807        """Just like self.assertTrue(a not in b), but with a nicer default message."""
808        if member in container:
809            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
810                                                        safe_repr(container))
811            self.fail(self._formatMessage(msg, standardMsg))
812
813    def assertIs(self, expr1, expr2, msg=None):
814        """Just like self.assertTrue(a is b), but with a nicer default message."""
815        if expr1 is not expr2:
816            standardMsg = '%s is not %s' % (safe_repr(expr1),
817                                             safe_repr(expr2))
818            self.fail(self._formatMessage(msg, standardMsg))
819
820    def assertIsNot(self, expr1, expr2, msg=None):
821        """Just like self.assertTrue(a is not b), but with a nicer default message."""
822        if expr1 is expr2:
823            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
824            self.fail(self._formatMessage(msg, standardMsg))
825
826    def assertDictEqual(self, d1, d2, msg=None):
827        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
828        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
829
830        if d1 != d2:
831            standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
832            diff = ('\n' + '\n'.join(difflib.ndiff(
833                           pprint.pformat(d1).splitlines(),
834                           pprint.pformat(d2).splitlines())))
835            standardMsg = self._truncateMessage(standardMsg, diff)
836            self.fail(self._formatMessage(msg, standardMsg))
837
838    def assertDictContainsSubset(self, expected, actual, msg=None):
839        """Checks whether actual is a superset of expected."""
840        missing = []
841        mismatched = []
842        for key, value in expected.iteritems():
843            if key not in actual:
844                missing.append(key)
845            elif value != actual[key]:
846                mismatched.append('%s, expected: %s, actual: %s' %
847                                  (safe_repr(key), safe_repr(value),
848                                   safe_repr(actual[key])))
849
850        if not (missing or mismatched):
851            return
852
853        standardMsg = ''
854        if missing:
855            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
856                                                    missing)
857        if mismatched:
858            if standardMsg:
859                standardMsg += '; '
860            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
861
862        self.fail(self._formatMessage(msg, standardMsg))
863
864    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
865        """An unordered sequence specific comparison. It asserts that
866        actual_seq and expected_seq have the same element counts.
867        Equivalent to::
868
869            self.assertEqual(Counter(iter(actual_seq)),
870                             Counter(iter(expected_seq)))
871
872        Asserts that each element has the same count in both sequences.
873        Example:
874            - [0, 1, 1] and [1, 0, 1] compare equal.
875            - [0, 0, 1] and [0, 1] compare unequal.
876        """
877        first_seq, second_seq = list(expected_seq), list(actual_seq)
878        with warnings.catch_warnings():
879            if sys.py3kwarning:
880                # Silence Py3k warning raised during the sorting
881                for _msg in ["(code|dict|type) inequality comparisons",
882                             "builtin_function_or_method order comparisons",
883                             "comparing unequal types"]:
884                    warnings.filterwarnings("ignore", _msg, DeprecationWarning)
885            try:
886                first = collections.Counter(first_seq)
887                second = collections.Counter(second_seq)
888            except TypeError:
889                # Handle case with unhashable elements
890                differences = _count_diff_all_purpose(first_seq, second_seq)
891            else:
892                if first == second:
893                    return
894                differences = _count_diff_hashable(first_seq, second_seq)
895
896        if differences:
897            standardMsg = 'Element counts were not equal:\n'
898            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
899            diffMsg = '\n'.join(lines)
900            standardMsg = self._truncateMessage(standardMsg, diffMsg)
901            msg = self._formatMessage(msg, standardMsg)
902            self.fail(msg)
903
904    def assertMultiLineEqual(self, first, second, msg=None):
905        """Assert that two multi-line strings are equal."""
906        self.assertIsInstance(first, basestring,
907                'First argument is not a string')
908        self.assertIsInstance(second, basestring,
909                'Second argument is not a string')
910
911        if first != second:
912            # don't use difflib if the strings are too long
913            if (len(first) > self._diffThreshold or
914                len(second) > self._diffThreshold):
915                self._baseAssertEqual(first, second, msg)
916            firstlines = first.splitlines(True)
917            secondlines = second.splitlines(True)
918            if len(firstlines) == 1 and first.strip('\r\n') == first:
919                firstlines = [first + '\n']
920                secondlines = [second + '\n']
921            standardMsg = '%s != %s' % (safe_repr(first, True),
922                                        safe_repr(second, True))
923            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
924            standardMsg = self._truncateMessage(standardMsg, diff)
925            self.fail(self._formatMessage(msg, standardMsg))
926
927    def assertLess(self, a, b, msg=None):
928        """Just like self.assertTrue(a < b), but with a nicer default message."""
929        if not a < b:
930            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
931            self.fail(self._formatMessage(msg, standardMsg))
932
933    def assertLessEqual(self, a, b, msg=None):
934        """Just like self.assertTrue(a <= b), but with a nicer default message."""
935        if not a <= b:
936            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
937            self.fail(self._formatMessage(msg, standardMsg))
938
939    def assertGreater(self, a, b, msg=None):
940        """Just like self.assertTrue(a > b), but with a nicer default message."""
941        if not a > b:
942            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
943            self.fail(self._formatMessage(msg, standardMsg))
944
945    def assertGreaterEqual(self, a, b, msg=None):
946        """Just like self.assertTrue(a >= b), but with a nicer default message."""
947        if not a >= b:
948            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
949            self.fail(self._formatMessage(msg, standardMsg))
950
951    def assertIsNone(self, obj, msg=None):
952        """Same as self.assertTrue(obj is None), with a nicer default message."""
953        if obj is not None:
954            standardMsg = '%s is not None' % (safe_repr(obj),)
955            self.fail(self._formatMessage(msg, standardMsg))
956
957    def assertIsNotNone(self, obj, msg=None):
958        """Included for symmetry with assertIsNone."""
959        if obj is None:
960            standardMsg = 'unexpectedly None'
961            self.fail(self._formatMessage(msg, standardMsg))
962
963    def assertIsInstance(self, obj, cls, msg=None):
964        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
965        default message."""
966        if not isinstance(obj, cls):
967            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
968            self.fail(self._formatMessage(msg, standardMsg))
969
970    def assertNotIsInstance(self, obj, cls, msg=None):
971        """Included for symmetry with assertIsInstance."""
972        if isinstance(obj, cls):
973            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
974            self.fail(self._formatMessage(msg, standardMsg))
975
976    def assertRaisesRegexp(self, expected_exception, expected_regexp,
977                           callable_obj=None, *args, **kwargs):
978        """Asserts that the message in a raised exception matches a regexp.
979
980        Args:
981            expected_exception: Exception class expected to be raised.
982            expected_regexp: Regexp (re pattern object or string) expected
983                    to be found in error message.
984            callable_obj: Function to be called.
985            args: Extra args.
986            kwargs: Extra kwargs.
987        """
988        if expected_regexp is not None:
989            expected_regexp = re.compile(expected_regexp)
990        context = _AssertRaisesContext(expected_exception, self, expected_regexp)
991        if callable_obj is None:
992            return context
993        with context:
994            callable_obj(*args, **kwargs)
995
996    def assertRegexpMatches(self, text, expected_regexp, msg=None):
997        """Fail the test unless the text matches the regular expression."""
998        if isinstance(expected_regexp, basestring):
999            expected_regexp = re.compile(expected_regexp)
1000        if not expected_regexp.search(text):
1001            msg = msg or "Regexp didn't match"
1002            msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
1003            raise self.failureException(msg)
1004
1005    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1006        """Fail the test if the text matches the regular expression."""
1007        if isinstance(unexpected_regexp, basestring):
1008            unexpected_regexp = re.compile(unexpected_regexp)
1009        match = unexpected_regexp.search(text)
1010        if match:
1011            msg = msg or "Regexp matched"
1012            msg = '%s: %r matches %r in %r' % (msg,
1013                                               text[match.start():match.end()],
1014                                               unexpected_regexp.pattern,
1015                                               text)
1016            raise self.failureException(msg)
1017
1018
class FunctionTestCase(TestCase):
    """A test case that wraps a plain test function.

    This lets pre-existing test functions run inside the unittest framework.
    Optional set-up and tidy-up callables may also be supplied. As with
    TestCase, the tidy-up ('tearDown') function will always be called if the
    set-up ('setUp') function ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        # Every callable and the description are stored as given. They are
        # read by the hooks, identity and description methods below.
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._testFunc = testFunc
        self._description = description

    def setUp(self):
        """Call the user-supplied set-up callable, if any."""
        hook = self._setUpFunc
        if hook is not None:
            hook()

    def tearDown(self):
        """Call the user-supplied tidy-up callable, if any."""
        hook = self._tearDownFunc
        if hook is not None:
            hook()

    def runTest(self):
        """Run the wrapped test function with no arguments."""
        self._testFunc()

    def id(self):
        """Return the wrapped function's ``__name__``."""
        return self._testFunc.__name__

    def __eq__(self, other):
        # Instances of a different class defer to the other operand.
        if not isinstance(other, self.__class__):
            return NotImplemented

        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Hashes the same fields that __eq__ compares, plus the concrete type.
        key = (type(self), self._setUpFunc, self._tearDownFunc,
               self._testFunc, self._description)
        return hash(key)

    def __str__(self):
        class_name = strclass(self.__class__)
        return "%s (%s)" % (class_name, self._testFunc.__name__)

    def __repr__(self):
        class_name = strclass(self.__class__)
        return "<%s tec=%s>" % (class_name, self._testFunc)

    def shortDescription(self):
        """Return the explicit description if one was given.

        Otherwise return the stripped first line of the wrapped function's
        docstring, or None when there is no docstring or that line is empty.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        first_line = doc.split("\n")[0].strip()
        return first_line or None
1078