1"""Test case implementation"""
2
3import collections
4import sys
5import functools
6import difflib
7import pprint
8import re
9import types
10import warnings
11
12from . import result
13from .util import (
14    strclass, safe_repr, unorderable_list_difference,
15    _count_diff_all_purpose, _count_diff_hashable
16)
17
18
# Module marker flag consulted by the unittest machinery (presumably to
# recognise frames from this module, e.g. when trimming tracebacks -- confirm).
__unittest = True


# Template appended to a failure message in place of a diff that is longer
# than TestCase.maxDiff; the %s receives the diff length.
DIFF_OMITTED = (
    '\nDiff is %s characters long. '
    'Set self.maxDiff to None to see it.'
)
24
class SkipTest(Exception):
    """
    Exception that, when raised from inside a test, marks the test as skipped.

    The skipping decorators (skip, skipIf, skipUnless) and TestCase.skipTest
    raise it on your behalf and are usually more convenient than raising it
    directly.
    """
33
34class _ExpectedFailure(Exception):
35    """
36    Raise this when a test is expected to fail.
37
38    This is an implementation detail.
39    """
40
41    def __init__(self, exc_info):
42        super(_ExpectedFailure, self).__init__()
43        self.exc_info = exc_info
44
45class _UnexpectedSuccess(Exception):
46    """
47    The test was supposed to fail, but it didn't!
48    """
49    pass
50
51def _id(obj):
52    return obj
53
def skip(reason):
    """
    Build a decorator that unconditionally skips a test.

    A decorated class is flagged in place.  Anything else (normally a test
    function) is replaced by a wrapper that raises SkipTest(reason) when
    called, and the wrapper is flagged.  The flags are the
    ``__unittest_skip__`` / ``__unittest_skip_why__`` attributes that
    TestCase.run checks.
    """
    def decorator(test_item):
        if isinstance(test_item, (type, types.ClassType)):
            flagged = test_item
        else:
            @functools.wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            flagged = skip_wrapper

        flagged.__unittest_skip__ = True
        flagged.__unittest_skip_why__ = reason
        return flagged
    return decorator
69
def skipIf(condition, reason):
    """
    Skip the decorated test when *condition* is true; otherwise leave it as is.
    """
    return skip(reason) if condition else _id
77
def skipUnless(condition, reason):
    """
    Skip the decorated test unless *condition* is true.
    """
    return _id if condition else skip(reason)
85
86
def expectedFailure(func):
    """
    Mark a test as expected to fail.

    The wrapper turns any Exception raised by *func* into _ExpectedFailure
    (carrying ``sys.exc_info()``), and raises _UnexpectedSuccess when *func*
    returns normally.  TestCase.run maps both onto the TestResult.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            raise _ExpectedFailure(sys.exc_info())
        else:
            raise _UnexpectedSuccess
    return wrapper
96
97
98class _AssertRaisesContext(object):
99    """A context manager used to implement TestCase.assertRaises* methods."""
100
101    def __init__(self, expected, test_case, expected_regexp=None):
102        self.expected = expected
103        self.failureException = test_case.failureException
104        self.expected_regexp = expected_regexp
105
106    def __enter__(self):
107        return self
108
109    def __exit__(self, exc_type, exc_value, tb):
110        if exc_type is None:
111            try:
112                exc_name = self.expected.__name__
113            except AttributeError:
114                exc_name = str(self.expected)
115            raise self.failureException(
116                "{0} not raised".format(exc_name))
117        if not issubclass(exc_type, self.expected):
118            # let unexpected exceptions pass through
119            return False
120        self.exception = exc_value # store for later retrieval
121        if self.expected_regexp is None:
122            return True
123
124        expected_regexp = self.expected_regexp
125        if isinstance(expected_regexp, basestring):
126            expected_regexp = re.compile(expected_regexp)
127        if not expected_regexp.search(str(exc_value)):
128            raise self.failureException('"%s" does not match "%s"' %
129                     (expected_regexp.pattern, str(exc_value)))
130        return True
131
132
133class TestCase(object):
134    """A class whose instances are single test cases.
135
136    By default, the test code itself should be placed in a method named
137    'runTest'.
138
139    If the fixture may be used for many test cases, create as
140    many test methods as are needed. When instantiating such a TestCase
141    subclass, specify in the constructor arguments the name of the test method
142    that the instance is to execute.
143
144    Test authors should subclass TestCase for their own tests. Construction
145    and deconstruction of the test's environment ('fixture') can be
146    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
147
148    If it is necessary to override the __init__ method, the base class
149    __init__ method must always be called. It is important that subclasses
150    should not change the signature of their __init__ method, since instances
151    of the classes are instantiated automatically by parts of the framework
152    in order to be run.
153    """
154
155    # This attribute determines which exception will be raised when
156    # the instance's assertion methods fail; test methods raising this
157    # exception will be deemed to have 'failed' rather than 'errored'
158
159    failureException = AssertionError
160
161    # This attribute determines whether long messages (including repr of
162    # objects used in assert methods) will be printed on failure in *addition*
163    # to any explicit message passed.
164
165    longMessage = False
166
167    # This attribute sets the maximum length of a diff in failure messages
168    # by assert methods using difflib. It is looked up as an instance attribute
169    # so can be configured by individual tests if required.
170
171    maxDiff = 80*8
172
173    # If a string is longer than _diffThreshold, use normal comparison instead
174    # of difflib.  See #11763.
175    _diffThreshold = 2**16
176
177    # Attribute used by TestSuite for classSetUp
178
179    _classSetupFailed = False
180
181    def __init__(self, methodName='runTest'):
182        """Create an instance of the class that will use the named test
183           method when executed. Raises a ValueError if the instance does
184           not have a method with the specified name.
185        """
186        self._testMethodName = methodName
187        self._resultForDoCleanups = None
188        try:
189            testMethod = getattr(self, methodName)
190        except AttributeError:
191            raise ValueError("no such test method in %s: %s" %
192                  (self.__class__, methodName))
193        self._testMethodDoc = testMethod.__doc__
194        self._cleanups = []
195
196        # Map types to custom assertEqual functions that will compare
197        # instances of said type in more detail to generate a more useful
198        # error message.
199        self._type_equality_funcs = {}
200        self.addTypeEqualityFunc(dict, 'assertDictEqual')
201        self.addTypeEqualityFunc(list, 'assertListEqual')
202        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
203        self.addTypeEqualityFunc(set, 'assertSetEqual')
204        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
205        try:
206            self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual')
207        except NameError:
208            # No unicode support in this build
209            pass
210
211    def addTypeEqualityFunc(self, typeobj, function):
212        """Add a type specific assertEqual style function to compare a type.
213
214        This method is for use by TestCase subclasses that need to register
215        their own type equality functions to provide nicer error messages.
216
217        Args:
218            typeobj: The data type to call this function on when both values
219                    are of the same type in assertEqual().
220            function: The callable taking two arguments and an optional
221                    msg= argument that raises self.failureException with a
222                    useful error message when the two arguments are not equal.
223        """
224        self._type_equality_funcs[typeobj] = function
225
226    def addCleanup(self, function, *args, **kwargs):
227        """Add a function, with arguments, to be called when the test is
228        completed. Functions added are called on a LIFO basis and are
229        called after tearDown on test failure or success.
230
231        Cleanup items are called even if setUp fails (unlike tearDown)."""
232        self._cleanups.append((function, args, kwargs))
233
234    def setUp(self):
235        "Hook method for setting up the test fixture before exercising it."
236        pass
237
238    def tearDown(self):
239        "Hook method for deconstructing the test fixture after testing it."
240        pass
241
242    @classmethod
243    def setUpClass(cls):
244        "Hook method for setting up class fixture before running tests in the class."
245
246    @classmethod
247    def tearDownClass(cls):
248        "Hook method for deconstructing the class fixture after running all tests in the class."
249
250    def countTestCases(self):
251        return 1
252
253    def defaultTestResult(self):
254        return result.TestResult()
255
256    def shortDescription(self):
257        """Returns a one-line description of the test, or None if no
258        description has been provided.
259
260        The default implementation of this method returns the first line of
261        the specified test method's docstring.
262        """
263        doc = self._testMethodDoc
264        return doc and doc.split("\n")[0].strip() or None
265
266
267    def id(self):
268        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
269
270    def __eq__(self, other):
271        if type(self) is not type(other):
272            return NotImplemented
273
274        return self._testMethodName == other._testMethodName
275
276    def __ne__(self, other):
277        return not self == other
278
279    def __hash__(self):
280        return hash((type(self), self._testMethodName))
281
282    def __str__(self):
283        return "%s (%s)" % (self._testMethodName, strclass(self.__class__))
284
285    def __repr__(self):
286        return "<%s testMethod=%s>" % \
287               (strclass(self.__class__), self._testMethodName)
288
289    def _addSkip(self, result, reason):
290        addSkip = getattr(result, 'addSkip', None)
291        if addSkip is not None:
292            addSkip(self, reason)
293        else:
294            warnings.warn("TestResult has no addSkip method, skips not reported",
295                          RuntimeWarning, 2)
296            result.addSuccess(self)
297
298    def run(self, result=None):
299        orig_result = result
300        if result is None:
301            result = self.defaultTestResult()
302            startTestRun = getattr(result, 'startTestRun', None)
303            if startTestRun is not None:
304                startTestRun()
305
306        self._resultForDoCleanups = result
307        result.startTest(self)
308
309        testMethod = getattr(self, self._testMethodName)
310        if (getattr(self.__class__, "__unittest_skip__", False) or
311            getattr(testMethod, "__unittest_skip__", False)):
312            # If the class or method was skipped.
313            try:
314                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
315                            or getattr(testMethod, '__unittest_skip_why__', ''))
316                self._addSkip(result, skip_why)
317            finally:
318                result.stopTest(self)
319            return
320        try:
321            success = False
322            try:
323                self.setUp()
324            except SkipTest as e:
325                self._addSkip(result, str(e))
326            except KeyboardInterrupt:
327                raise
328            except:
329                result.addError(self, sys.exc_info())
330            else:
331                try:
332                    testMethod()
333                except KeyboardInterrupt:
334                    raise
335                except self.failureException:
336                    result.addFailure(self, sys.exc_info())
337                except _ExpectedFailure as e:
338                    addExpectedFailure = getattr(result, 'addExpectedFailure', None)
339                    if addExpectedFailure is not None:
340                        addExpectedFailure(self, e.exc_info)
341                    else:
342                        warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
343                                      RuntimeWarning)
344                        result.addSuccess(self)
345                except _UnexpectedSuccess:
346                    addUnexpectedSuccess = getattr(result, 'addUnexpectedSuccess', None)
347                    if addUnexpectedSuccess is not None:
348                        addUnexpectedSuccess(self)
349                    else:
350                        warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failures",
351                                      RuntimeWarning)
352                        result.addFailure(self, sys.exc_info())
353                except SkipTest as e:
354                    self._addSkip(result, str(e))
355                except:
356                    result.addError(self, sys.exc_info())
357                else:
358                    success = True
359
360                try:
361                    self.tearDown()
362                except KeyboardInterrupt:
363                    raise
364                except:
365                    result.addError(self, sys.exc_info())
366                    success = False
367
368            cleanUpSuccess = self.doCleanups()
369            success = success and cleanUpSuccess
370            if success:
371                result.addSuccess(self)
372        finally:
373            result.stopTest(self)
374            if orig_result is None:
375                stopTestRun = getattr(result, 'stopTestRun', None)
376                if stopTestRun is not None:
377                    stopTestRun()
378
379    def doCleanups(self):
380        """Execute all cleanup functions. Normally called for you after
381        tearDown."""
382        result = self._resultForDoCleanups
383        ok = True
384        while self._cleanups:
385            function, args, kwargs = self._cleanups.pop(-1)
386            try:
387                function(*args, **kwargs)
388            except KeyboardInterrupt:
389                raise
390            except:
391                ok = False
392                result.addError(self, sys.exc_info())
393        return ok
394
395    def __call__(self, *args, **kwds):
396        return self.run(*args, **kwds)
397
398    def debug(self):
399        """Run the test without collecting errors in a TestResult"""
400        self.setUp()
401        getattr(self, self._testMethodName)()
402        self.tearDown()
403        while self._cleanups:
404            function, args, kwargs = self._cleanups.pop(-1)
405            function(*args, **kwargs)
406
407    def skipTest(self, reason):
408        """Skip this test."""
409        raise SkipTest(reason)
410
411    def fail(self, msg=None):
412        """Fail immediately, with the given message."""
413        raise self.failureException(msg)
414
415    def assertFalse(self, expr, msg=None):
416        """Check that the expression is false."""
417        if expr:
418            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
419            raise self.failureException(msg)
420
421    def assertTrue(self, expr, msg=None):
422        """Check that the expression is true."""
423        if not expr:
424            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
425            raise self.failureException(msg)
426
427    def _formatMessage(self, msg, standardMsg):
428        """Honour the longMessage attribute when generating failure messages.
429        If longMessage is False this means:
430        * Use only an explicit message if it is provided
431        * Otherwise use the standard message for the assert
432
433        If longMessage is True:
434        * Use the standard message
435        * If an explicit message is provided, plus ' : ' and the explicit message
436        """
437        if not self.longMessage:
438            return msg or standardMsg
439        if msg is None:
440            return standardMsg
441        try:
442            # don't switch to '{}' formatting in Python 2.X
443            # it changes the way unicode input is handled
444            return '%s : %s' % (standardMsg, msg)
445        except UnicodeDecodeError:
446            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
447
448
449    def assertRaises(self, excClass, callableObj=None, *args, **kwargs):
450        """Fail unless an exception of class excClass is raised
451           by callableObj when invoked with arguments args and keyword
452           arguments kwargs. If a different type of exception is
453           raised, it will not be caught, and the test case will be
454           deemed to have suffered an error, exactly as for an
455           unexpected exception.
456
457           If called with callableObj omitted or None, will return a
458           context object used like this::
459
460                with self.assertRaises(SomeException):
461                    do_something()
462
463           The context manager keeps a reference to the exception as
464           the 'exception' attribute. This allows you to inspect the
465           exception after the assertion::
466
467               with self.assertRaises(SomeException) as cm:
468                   do_something()
469               the_exception = cm.exception
470               self.assertEqual(the_exception.error_code, 3)
471        """
472        context = _AssertRaisesContext(excClass, self)
473        if callableObj is None:
474            return context
475        with context:
476            callableObj(*args, **kwargs)
477
478    def _getAssertEqualityFunc(self, first, second):
479        """Get a detailed comparison function for the types of the two args.
480
481        Returns: A callable accepting (first, second, msg=None) that will
482        raise a failure exception if first != second with a useful human
483        readable error message for those types.
484        """
485        #
486        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
487        # and vice versa.  I opted for the conservative approach in case
488        # subclasses are not intended to be compared in detail to their super
489        # class instances using a type equality func.  This means testing
490        # subtypes won't automagically use the detailed comparison.  Callers
491        # should use their type specific assertSpamEqual method to compare
492        # subclasses if the detailed comparison is desired and appropriate.
493        # See the discussion in http://bugs.python.org/issue2578.
494        #
495        if type(first) is type(second):
496            asserter = self._type_equality_funcs.get(type(first))
497            if asserter is not None:
498                if isinstance(asserter, basestring):
499                    asserter = getattr(self, asserter)
500                return asserter
501
502        return self._baseAssertEqual
503
504    def _baseAssertEqual(self, first, second, msg=None):
505        """The default assertEqual implementation, not type specific."""
506        if not first == second:
507            standardMsg = '%s != %s' % (safe_repr(first), safe_repr(second))
508            msg = self._formatMessage(msg, standardMsg)
509            raise self.failureException(msg)
510
511    def assertEqual(self, first, second, msg=None):
512        """Fail if the two objects are unequal as determined by the '=='
513           operator.
514        """
515        assertion_func = self._getAssertEqualityFunc(first, second)
516        assertion_func(first, second, msg=msg)
517
518    def assertNotEqual(self, first, second, msg=None):
519        """Fail if the two objects are equal as determined by the '!='
520           operator.
521        """
522        if not first != second:
523            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
524                                                          safe_repr(second)))
525            raise self.failureException(msg)
526
527
528    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
529        """Fail if the two objects are unequal as determined by their
530           difference rounded to the given number of decimal places
531           (default 7) and comparing to zero, or by comparing that the
532           between the two objects is more than the given delta.
533
534           Note that decimal places (from zero) are usually not the same
535           as significant digits (measured from the most signficant digit).
536
537           If the two objects compare equal then they will automatically
538           compare almost equal.
539        """
540        if first == second:
541            # shortcut
542            return
543        if delta is not None and places is not None:
544            raise TypeError("specify delta or places not both")
545
546        if delta is not None:
547            if abs(first - second) <= delta:
548                return
549
550            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
551                                                        safe_repr(second),
552                                                        safe_repr(delta))
553        else:
554            if places is None:
555                places = 7
556
557            if round(abs(second-first), places) == 0:
558                return
559
560            standardMsg = '%s != %s within %r places' % (safe_repr(first),
561                                                          safe_repr(second),
562                                                          places)
563        msg = self._formatMessage(msg, standardMsg)
564        raise self.failureException(msg)
565
566    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
567        """Fail if the two objects are equal as determined by their
568           difference rounded to the given number of decimal places
569           (default 7) and comparing to zero, or by comparing that the
570           between the two objects is less than the given delta.
571
572           Note that decimal places (from zero) are usually not the same
573           as significant digits (measured from the most signficant digit).
574
575           Objects that are equal automatically fail.
576        """
577        if delta is not None and places is not None:
578            raise TypeError("specify delta or places not both")
579        if delta is not None:
580            if not (first == second) and abs(first - second) > delta:
581                return
582            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
583                                                        safe_repr(second),
584                                                        safe_repr(delta))
585        else:
586            if places is None:
587                places = 7
588            if not (first == second) and round(abs(second-first), places) != 0:
589                return
590            standardMsg = '%s == %s within %r places' % (safe_repr(first),
591                                                         safe_repr(second),
592                                                         places)
593
594        msg = self._formatMessage(msg, standardMsg)
595        raise self.failureException(msg)
596
597    # Synonyms for assertion methods
598
599    # The plurals are undocumented.  Keep them that way to discourage use.
600    # Do not add more.  Do not remove.
601    # Going through a deprecation cycle on these would annoy many people.
602    assertEquals = assertEqual
603    assertNotEquals = assertNotEqual
604    assertAlmostEquals = assertAlmostEqual
605    assertNotAlmostEquals = assertNotAlmostEqual
606    assert_ = assertTrue
607
608    # These fail* assertion method names are pending deprecation and will
609    # be a DeprecationWarning in 3.2; http://bugs.python.org/issue2578
610    def _deprecate(original_func):
611        def deprecated_func(*args, **kwargs):
612            warnings.warn(
613                'Please use {0} instead.'.format(original_func.__name__),
614                PendingDeprecationWarning, 2)
615            return original_func(*args, **kwargs)
616        return deprecated_func
617
618    failUnlessEqual = _deprecate(assertEqual)
619    failIfEqual = _deprecate(assertNotEqual)
620    failUnlessAlmostEqual = _deprecate(assertAlmostEqual)
621    failIfAlmostEqual = _deprecate(assertNotAlmostEqual)
622    failUnless = _deprecate(assertTrue)
623    failUnlessRaises = _deprecate(assertRaises)
624    failIf = _deprecate(assertFalse)
625
626    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
627        """An equality assertion for ordered sequences (like lists and tuples).
628
629        For the purposes of this function, a valid ordered sequence type is one
630        which can be indexed, has a length, and has an equality operator.
631
632        Args:
633            seq1: The first sequence to compare.
634            seq2: The second sequence to compare.
635            seq_type: The expected datatype of the sequences, or None if no
636                    datatype should be enforced.
637            msg: Optional message to use on failure instead of a list of
638                    differences.
639        """
640        if seq_type is not None:
641            seq_type_name = seq_type.__name__
642            if not isinstance(seq1, seq_type):
643                raise self.failureException('First sequence is not a %s: %s'
644                                        % (seq_type_name, safe_repr(seq1)))
645            if not isinstance(seq2, seq_type):
646                raise self.failureException('Second sequence is not a %s: %s'
647                                        % (seq_type_name, safe_repr(seq2)))
648        else:
649            seq_type_name = "sequence"
650
651        differing = None
652        try:
653            len1 = len(seq1)
654        except (TypeError, NotImplementedError):
655            differing = 'First %s has no length.    Non-sequence?' % (
656                    seq_type_name)
657
658        if differing is None:
659            try:
660                len2 = len(seq2)
661            except (TypeError, NotImplementedError):
662                differing = 'Second %s has no length.    Non-sequence?' % (
663                        seq_type_name)
664
665        if differing is None:
666            if seq1 == seq2:
667                return
668
669            seq1_repr = safe_repr(seq1)
670            seq2_repr = safe_repr(seq2)
671            if len(seq1_repr) > 30:
672                seq1_repr = seq1_repr[:30] + '...'
673            if len(seq2_repr) > 30:
674                seq2_repr = seq2_repr[:30] + '...'
675            elements = (seq_type_name.capitalize(), seq1_repr, seq2_repr)
676            differing = '%ss differ: %s != %s\n' % elements
677
678            for i in xrange(min(len1, len2)):
679                try:
680                    item1 = seq1[i]
681                except (TypeError, IndexError, NotImplementedError):
682                    differing += ('\nUnable to index element %d of first %s\n' %
683                                 (i, seq_type_name))
684                    break
685
686                try:
687                    item2 = seq2[i]
688                except (TypeError, IndexError, NotImplementedError):
689                    differing += ('\nUnable to index element %d of second %s\n' %
690                                 (i, seq_type_name))
691                    break
692
693                if item1 != item2:
694                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
695                                 (i, item1, item2))
696                    break
697            else:
698                if (len1 == len2 and seq_type is None and
699                    type(seq1) != type(seq2)):
700                    # The sequences are the same, but have differing types.
701                    return
702
703            if len1 > len2:
704                differing += ('\nFirst %s contains %d additional '
705                             'elements.\n' % (seq_type_name, len1 - len2))
706                try:
707                    differing += ('First extra element %d:\n%s\n' %
708                                  (len2, seq1[len2]))
709                except (TypeError, IndexError, NotImplementedError):
710                    differing += ('Unable to index element %d '
711                                  'of first %s\n' % (len2, seq_type_name))
712            elif len1 < len2:
713                differing += ('\nSecond %s contains %d additional '
714                             'elements.\n' % (seq_type_name, len2 - len1))
715                try:
716                    differing += ('First extra element %d:\n%s\n' %
717                                  (len1, seq2[len1]))
718                except (TypeError, IndexError, NotImplementedError):
719                    differing += ('Unable to index element %d '
720                                  'of second %s\n' % (len1, seq_type_name))
721        standardMsg = differing
722        diffMsg = '\n' + '\n'.join(
723            difflib.ndiff(pprint.pformat(seq1).splitlines(),
724                          pprint.pformat(seq2).splitlines()))
725        standardMsg = self._truncateMessage(standardMsg, diffMsg)
726        msg = self._formatMessage(msg, standardMsg)
727        self.fail(msg)
728
729    def _truncateMessage(self, message, diff):
730        max_diff = self.maxDiff
731        if max_diff is None or len(diff) <= max_diff:
732            return message + diff
733        return message + (DIFF_OMITTED % len(diff))
734
735    def assertListEqual(self, list1, list2, msg=None):
736        """A list-specific equality assertion.
737
738        Args:
739            list1: The first list to compare.
740            list2: The second list to compare.
741            msg: Optional message to use on failure instead of a list of
742                    differences.
743
744        """
745        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
746
747    def assertTupleEqual(self, tuple1, tuple2, msg=None):
748        """A tuple-specific equality assertion.
749
750        Args:
751            tuple1: The first tuple to compare.
752            tuple2: The second tuple to compare.
753            msg: Optional message to use on failure instead of a list of
754                    differences.
755        """
756        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
757
758    def assertSetEqual(self, set1, set2, msg=None):
759        """A set-specific equality assertion.
760
761        Args:
762            set1: The first set to compare.
763            set2: The second set to compare.
764            msg: Optional message to use on failure instead of a list of
765                    differences.
766
767        assertSetEqual uses ducktyping to support different types of sets, and
768        is optimized for sets specifically (parameters must support a
769        difference method).
770        """
771        try:
772            difference1 = set1.difference(set2)
773        except TypeError, e:
774            self.fail('invalid type when attempting set difference: %s' % e)
775        except AttributeError, e:
776            self.fail('first argument does not support set difference: %s' % e)
777
778        try:
779            difference2 = set2.difference(set1)
780        except TypeError, e:
781            self.fail('invalid type when attempting set difference: %s' % e)
782        except AttributeError, e:
783            self.fail('second argument does not support set difference: %s' % e)
784
785        if not (difference1 or difference2):
786            return
787
788        lines = []
789        if difference1:
790            lines.append('Items in the first set but not the second:')
791            for item in difference1:
792                lines.append(repr(item))
793        if difference2:
794            lines.append('Items in the second set but not the first:')
795            for item in difference2:
796                lines.append(repr(item))
797
798        standardMsg = '\n'.join(lines)
799        self.fail(self._formatMessage(msg, standardMsg))
800
801    def assertIn(self, member, container, msg=None):
802        """Just like self.assertTrue(a in b), but with a nicer default message."""
803        if member not in container:
804            standardMsg = '%s not found in %s' % (safe_repr(member),
805                                                  safe_repr(container))
806            self.fail(self._formatMessage(msg, standardMsg))
807
808    def assertNotIn(self, member, container, msg=None):
809        """Just like self.assertTrue(a not in b), but with a nicer default message."""
810        if member in container:
811            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
812                                                        safe_repr(container))
813            self.fail(self._formatMessage(msg, standardMsg))
814
815    def assertIs(self, expr1, expr2, msg=None):
816        """Just like self.assertTrue(a is b), but with a nicer default message."""
817        if expr1 is not expr2:
818            standardMsg = '%s is not %s' % (safe_repr(expr1),
819                                             safe_repr(expr2))
820            self.fail(self._formatMessage(msg, standardMsg))
821
822    def assertIsNot(self, expr1, expr2, msg=None):
823        """Just like self.assertTrue(a is not b), but with a nicer default message."""
824        if expr1 is expr2:
825            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
826            self.fail(self._formatMessage(msg, standardMsg))
827
828    def assertDictEqual(self, d1, d2, msg=None):
829        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
830        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
831
832        if d1 != d2:
833            standardMsg = '%s != %s' % (safe_repr(d1, True), safe_repr(d2, True))
834            diff = ('\n' + '\n'.join(difflib.ndiff(
835                           pprint.pformat(d1).splitlines(),
836                           pprint.pformat(d2).splitlines())))
837            standardMsg = self._truncateMessage(standardMsg, diff)
838            self.fail(self._formatMessage(msg, standardMsg))
839
840    def assertDictContainsSubset(self, expected, actual, msg=None):
841        """Checks whether actual is a superset of expected."""
842        missing = []
843        mismatched = []
844        for key, value in expected.iteritems():
845            if key not in actual:
846                missing.append(key)
847            elif value != actual[key]:
848                mismatched.append('%s, expected: %s, actual: %s' %
849                                  (safe_repr(key), safe_repr(value),
850                                   safe_repr(actual[key])))
851
852        if not (missing or mismatched):
853            return
854
855        standardMsg = ''
856        if missing:
857            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
858                                                    missing)
859        if mismatched:
860            if standardMsg:
861                standardMsg += '; '
862            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
863
864        self.fail(self._formatMessage(msg, standardMsg))
865
866    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
867        """An unordered sequence specific comparison. It asserts that
868        actual_seq and expected_seq have the same element counts.
869        Equivalent to::
870
871            self.assertEqual(Counter(iter(actual_seq)),
872                             Counter(iter(expected_seq)))
873
874        Asserts that each element has the same count in both sequences.
875        Example:
876            - [0, 1, 1] and [1, 0, 1] compare equal.
877            - [0, 0, 1] and [0, 1] compare unequal.
878        """
879        first_seq, second_seq = list(expected_seq), list(actual_seq)
880        with warnings.catch_warnings():
881            if sys.py3kwarning:
882                # Silence Py3k warning raised during the sorting
883                for _msg in ["(code|dict|type) inequality comparisons",
884                             "builtin_function_or_method order comparisons",
885                             "comparing unequal types"]:
886                    warnings.filterwarnings("ignore", _msg, DeprecationWarning)
887            try:
888                first = collections.Counter(first_seq)
889                second = collections.Counter(second_seq)
890            except TypeError:
891                # Handle case with unhashable elements
892                differences = _count_diff_all_purpose(first_seq, second_seq)
893            else:
894                if first == second:
895                    return
896                differences = _count_diff_hashable(first_seq, second_seq)
897
898        if differences:
899            standardMsg = 'Element counts were not equal:\n'
900            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
901            diffMsg = '\n'.join(lines)
902            standardMsg = self._truncateMessage(standardMsg, diffMsg)
903            msg = self._formatMessage(msg, standardMsg)
904            self.fail(msg)
905
906    def assertMultiLineEqual(self, first, second, msg=None):
907        """Assert that two multi-line strings are equal."""
908        self.assertIsInstance(first, basestring,
909                'First argument is not a string')
910        self.assertIsInstance(second, basestring,
911                'Second argument is not a string')
912
913        if first != second:
914            # don't use difflib if the strings are too long
915            if (len(first) > self._diffThreshold or
916                len(second) > self._diffThreshold):
917                self._baseAssertEqual(first, second, msg)
918            firstlines = first.splitlines(True)
919            secondlines = second.splitlines(True)
920            if len(firstlines) == 1 and first.strip('\r\n') == first:
921                firstlines = [first + '\n']
922                secondlines = [second + '\n']
923            standardMsg = '%s != %s' % (safe_repr(first, True),
924                                        safe_repr(second, True))
925            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
926            standardMsg = self._truncateMessage(standardMsg, diff)
927            self.fail(self._formatMessage(msg, standardMsg))
928
929    def assertLess(self, a, b, msg=None):
930        """Just like self.assertTrue(a < b), but with a nicer default message."""
931        if not a < b:
932            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
933            self.fail(self._formatMessage(msg, standardMsg))
934
935    def assertLessEqual(self, a, b, msg=None):
936        """Just like self.assertTrue(a <= b), but with a nicer default message."""
937        if not a <= b:
938            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
939            self.fail(self._formatMessage(msg, standardMsg))
940
941    def assertGreater(self, a, b, msg=None):
942        """Just like self.assertTrue(a > b), but with a nicer default message."""
943        if not a > b:
944            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
945            self.fail(self._formatMessage(msg, standardMsg))
946
947    def assertGreaterEqual(self, a, b, msg=None):
948        """Just like self.assertTrue(a >= b), but with a nicer default message."""
949        if not a >= b:
950            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
951            self.fail(self._formatMessage(msg, standardMsg))
952
953    def assertIsNone(self, obj, msg=None):
954        """Same as self.assertTrue(obj is None), with a nicer default message."""
955        if obj is not None:
956            standardMsg = '%s is not None' % (safe_repr(obj),)
957            self.fail(self._formatMessage(msg, standardMsg))
958
959    def assertIsNotNone(self, obj, msg=None):
960        """Included for symmetry with assertIsNone."""
961        if obj is None:
962            standardMsg = 'unexpectedly None'
963            self.fail(self._formatMessage(msg, standardMsg))
964
965    def assertIsInstance(self, obj, cls, msg=None):
966        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
967        default message."""
968        if not isinstance(obj, cls):
969            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
970            self.fail(self._formatMessage(msg, standardMsg))
971
972    def assertNotIsInstance(self, obj, cls, msg=None):
973        """Included for symmetry with assertIsInstance."""
974        if isinstance(obj, cls):
975            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
976            self.fail(self._formatMessage(msg, standardMsg))
977
978    def assertRaisesRegexp(self, expected_exception, expected_regexp,
979                           callable_obj=None, *args, **kwargs):
980        """Asserts that the message in a raised exception matches a regexp.
981
982        Args:
983            expected_exception: Exception class expected to be raised.
984            expected_regexp: Regexp (re pattern object or string) expected
985                    to be found in error message.
986            callable_obj: Function to be called.
987            args: Extra args.
988            kwargs: Extra kwargs.
989        """
990        context = _AssertRaisesContext(expected_exception, self, expected_regexp)
991        if callable_obj is None:
992            return context
993        with context:
994            callable_obj(*args, **kwargs)
995
996    def assertRegexpMatches(self, text, expected_regexp, msg=None):
997        """Fail the test unless the text matches the regular expression."""
998        if isinstance(expected_regexp, basestring):
999            expected_regexp = re.compile(expected_regexp)
1000        if not expected_regexp.search(text):
1001            msg = msg or "Regexp didn't match"
1002            msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
1003            raise self.failureException(msg)
1004
1005    def assertNotRegexpMatches(self, text, unexpected_regexp, msg=None):
1006        """Fail the test if the text matches the regular expression."""
1007        if isinstance(unexpected_regexp, basestring):
1008            unexpected_regexp = re.compile(unexpected_regexp)
1009        match = unexpected_regexp.search(text)
1010        if match:
1011            msg = msg or "Regexp matched"
1012            msg = '%s: %r matches %r in %r' % (msg,
1013                                               text[match.start():match.end()],
1014                                               unexpected_regexp.pattern,
1015                                               text)
1016            raise self.failureException(msg)
1017
1018
class FunctionTestCase(TestCase):
    """Adapt a plain callable into a TestCase.

    Lets existing test functions run inside the unittest framework. Optional
    set-up and tear-down callables may be supplied; as with TestCase, the
    tear-down ('tearDown') is always called if the set-up ('setUp') ran
    successfully. An optional description overrides the function's docstring
    in shortDescription().
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super(FunctionTestCase, self).__init__()
        self._testFunc = testFunc
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._description = description

    def setUp(self):
        # Run the user-supplied set-up callable, if one was given.
        hook = self._setUpFunc
        if hook is not None:
            hook()

    def tearDown(self):
        # Run the user-supplied tear-down callable, if one was given.
        hook = self._tearDownFunc
        if hook is not None:
            hook()

    def runTest(self):
        self._testFunc()

    def id(self):
        return self._testFunc.__name__

    def __eq__(self, other):
        # Only compare against instances of this same class (or subclasses);
        # otherwise let Python try the reflected comparison.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (self._setUpFunc == other._setUpFunc and
                self._tearDownFunc == other._tearDownFunc and
                self._testFunc == other._testFunc and
                self._description == other._description)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hash exactly the fields __eq__ compares, plus the concrete type.
        key = (type(self), self._setUpFunc, self._tearDownFunc,
               self._testFunc, self._description)
        return hash(key)

    def __str__(self):
        class_name = strclass(self.__class__)
        return "%s (%s)" % (class_name, self._testFunc.__name__)

    def __repr__(self):
        class_name = strclass(self.__class__)
        return "<%s tec=%s>" % (class_name, self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the docstring's first line.

        Returns None when there is no description and the function has no
        docstring, or when the docstring's first line is blank.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        first_line = doc.split("\n")[0].strip()
        return first_line or None
1078