1"""Copyright 2011-2015 Herman Sheremetyev, Slavek Kabrda. All rights reserved.
2
3Redistribution and use in source and binary forms, with or without modification,
4are permitted provided that the following conditions are met:
5
6   1. Redistributions of source code must retain the above copyright notice,
7      this list of conditions and the following disclaimer.
8
9   2. Redistributions in binary form must reproduce the above copyright notice,
10      this list of conditions and the following disclaimer in the documentation
11      and/or other materials provided with the distribution.
12
13THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
14WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
15MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
16EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
17INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
18LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
19PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
20LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
21OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
22ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23"""
24
25
# from flexmock import * is evil, keep it from doing any damage:
# a wildcard import of this module only exposes the ``flexmock`` name.
__all__ = ['flexmock']
__version__ = '0.10.10'
29
30
31import inspect
32import re
33import sys
34import types
35
# Call-count modifiers: the keys of Expectation.expected_calls and the
# possible values of Expectation.modifier.
AT_LEAST = 'at least'
AT_MOST = 'at most'
EXACTLY = 'exactly'
# Names of flexmock's own Mock API methods; should_receive() refuses to
# replace them.
UPDATED_ATTRS = ['should_receive', 'should_call', 'new_instances']
# Attributes the ``type`` metaclass exposes that an ordinary class object
# does not list in dir() (e.g. ``mro``). Presumably used elsewhere to filter
# class attributes -- TODO confirm at the use sites.
DEFAULT_CLASS_ATTRIBUTES = [attr for attr in dir(type)
                            if attr not in dir(type('', (object,), {}))]
# A compiled (empty) pattern; presumably kept so that type(RE_TYPE) can be used
# to recognise regex objects elsewhere -- verify against callers.
RE_TYPE = re.compile('')
# Method descriptor types; presumably compared against an Expectation's
# method_type elsewhere -- verify against callers.
SPECIAL_METHODS = (classmethod, staticmethod)
44
45
class FlexmockError(Exception):
    """Base class of the flexmock-specific errors defined in this module."""
48
49
class MockBuiltinError(Exception):
    """Error presumably signalling a problem mocking a builtin.

    Note: unlike the other errors here it does not derive from FlexmockError.
    """
52
53
class MethodSignatureError(FlexmockError):
    """Raised when expected arguments do not fit the mocked callable's signature."""
56
57
class ExceptionClassError(FlexmockError):
    """FlexmockError subtype, presumably for an unexpected exception class."""
60
61
class ExceptionMessageError(FlexmockError):
    """FlexmockError subtype, presumably for an unexpected exception message."""
64
65
class StateError(FlexmockError):
    """FlexmockError subtype, presumably for use of flexmock in an invalid state."""
68
69
class MethodCallError(FlexmockError):
    """Raised by Expectation.verify() when call-count expectations are not met."""
72
73
class CallOrderError(FlexmockError):
    """Raised when an ordered() expectation is called out of turn."""
76
77
class ReturnValue(object):
    """One canned outcome of a mocked call: either a value or an exception.

    When ``raises`` is set, ``value`` holds the arguments meant for the
    exception (a ``{'kargs': ..., 'kwargs': ...}`` dict) rather than a
    return value.
    """

    def __init__(self, value=None, raises=None):
        self.value = value
        self.raises = raises

    def __str__(self):
        """Render the outcome for error messages."""
        if self.raises:
            return '%s(%s)' % (self.raises, _arg_to_str(self.value))
        payload = self.value
        if isinstance(payload, tuple):
            if len(payload) == 1:
                # a 1-tuple is shown as its sole element
                return '%s' % _arg_to_str(payload[0])
            rendered = [_arg_to_str(item) for item in payload]
            return '(%s)' % ', '.join(rendered)
        return '%s' % _arg_to_str(payload)
93
94
class FullArgSpec(object):
    """Uniform attribute view over an argument specification.

    Python 2's ``inspect.getargspec`` yields 4 fields (args, varargs, keywords,
    defaults) while Python 3's ``inspect.getfullargspec`` yields 7 (adding
    kwonlyargs, kwonlydefaults, annotations). A 4-field spec is padded with
    no keyword-only arguments, no keyword-only defaults and no annotations so
    both expose the same attributes.
    """

    def __init__(self, spec):
        if len(spec) == 4:  # legacy getargspec() result
            spec += ([], None, {})
        (args, varargs, keywords, defaults,
         kwonlyargs, kwonlydefaults, annotations) = spec
        self.args = args
        self.varargs = varargs
        self.keywords = keywords
        self.defaults = defaults
        self.kwonlyargs = kwonlyargs
        self.kwonlydefaults = kwonlydefaults
        self.annotations = annotations
102
103
class FlexmockContainer(object):
    """Global registry of mocked objects and their expectations.

    All state is class-level and therefore shared by every caller:

      - flexmock_objects: maps each mock to the list of its expectations,
        in the order they were added
      - properties: maps an object to attribute names that
        teardown_properties() deletes from it
      - ordered: expectations marked ordered() that are still waiting to run
      - last: the ordered expectation most recently taken off ``ordered``
    """
    flexmock_objects = {}
    properties = {}
    ordered = []
    last = None

    @classmethod
    def reset(cls):
        """Drop every registered mock, teardown property and ordering state."""
        cls.flexmock_objects = {}
        cls.properties = {}
        cls.ordered = []
        cls.last = None

    @classmethod
    def get_flexmock_expectation(cls, obj, name=None, args=None):
        """Retrieves an existing matching expectation.

        Args:
          - obj: the mock whose expectations are searched
          - name: attribute name; a falsy name never matches
          - args: None (no arguments), a single positional argument or a tuple
            of them, or a dict with 'kargs' and 'kwargs' keys (a non-tuple
            'kargs' entry is wrapped into a 1-tuple in place)

        Returns:
          - the matching Expectation, or None

        Raises:
          - CallOrderError if the match is ordered and called out of turn
        """
        if args is None:
            args = {'kargs': (), 'kwargs': {}}
        elif not isinstance(args, dict):
            args = {'kargs': args, 'kwargs': {}}
        if not isinstance(args['kargs'], tuple):
            args['kargs'] = (args['kargs'],)
        if not name or obj not in cls.flexmock_objects:
            return None
        found = None
        for candidate in reversed(cls.flexmock_objects[obj]):
            if not (candidate.name == name and candidate.match_args(args)):
                continue
            # A still-pending ordered match overrides any earlier pick (so the
            # earliest-defined pending one wins); otherwise the most recently
            # defined unordered match is kept.
            if candidate in cls.ordered or (not candidate._ordered and not found):
                found = candidate
        if found and found._ordered:
            cls._verify_call_order(found, args)
        return found

    @classmethod
    def _verify_call_order(cls, expectation, args):
        """Raise CallOrderError unless ``expectation`` is the one due next.

        Each check consumes the head of ``ordered``; once that list is empty
        only the most recently consumed expectation is accepted.
        """
        if cls.ordered:
            due = cls.ordered.pop(0)
            cls.last = due
        else:
            due = cls.last
        if expectation is due:
            return
        raise CallOrderError(
            '%s called before %s' %
            (_format_args(expectation.name, args),
             _format_args(due.name, due.args)))

    @classmethod
    def add_expectation(cls, obj, expectation):
        """Append ``expectation`` to the list registered for ``obj``."""
        cls.flexmock_objects.setdefault(obj, []).append(expectation)

    @classmethod
    def add_teardown_property(cls, obj, name):
        """Schedule attribute ``name`` of ``obj`` for deletion at teardown."""
        cls.properties.setdefault(obj, []).append(name)

    @classmethod
    def teardown_properties(cls):
        """Delete every attribute registered via add_teardown_property()."""
        for obj, names in cls.properties.items():
            for attr_name in names:
                delattr(obj, attr_name)
169
170
class Expectation(object):
    """Holds expectations about methods.

    The information contained in the Expectation object includes method name,
    its argument list, return values, and any exceptions that the method might
    raise.

    Attribute access is customized (see ``__getattribute__``): the modifiers
    ``once``, ``twice``, ``never``, ``at_least``, ``at_most``, ``ordered``,
    ``one_by_one`` and ``mock`` take effect on plain attribute access, and
    ``__call__`` returns the expectation itself so they may also be written
    with trailing parentheses.
    """

    def __init__(self, mock, name=None, return_value=None, original=None, method_type=None):
        """Create an expectation for attribute ``name`` of ``mock``.

        Args:
          - mock: the object whose attribute this expectation replaces
          - name: name of the replaced attribute
          - return_value: initial return value; None records no return value
          - original: the replaced attribute, restored by reset(); None means
            it is not recorded yet
          - method_type: e.g. staticmethod when the original is a static method
        """
        self.name = name
        self.modifier = EXACTLY  # which expected_calls slot times() fills
        if original is not None:
            self.original = original
        self.args = None  # None matches any arguments (see match_args)
        self.method_type = method_type
        self.argspec = None
        value = ReturnValue(return_value)
        self.return_values = return_values = []
        self._replace_with = None
        if return_value is not None:
            return_values.append(value)
        self.times_called = 0
        self.expected_calls = {
            EXACTLY: None,
            AT_LEAST: None,
            AT_MOST: None}
        self.runnable = lambda: True
        self._mock = mock
        self._pass_thru = False
        self._ordered = False
        self._one_by_one = False
        self._verified = False
        self._callable = True
        self._local_override = False

    def __str__(self):
        """Render as ``name(args) -> (return values)`` for error messages."""
        return '%s -> (%s)' % (_format_args(self.name, self.args),
                               ', '.join(['%s' % x for x in self.return_values]))

    def __call__(self):
        """Return self, so modifiers such as ``once`` may be written ``once()``."""
        return self

    def __getattribute__(self, name):
        """Apply modifier "properties" on plain attribute access.

        ``once``/``twice``/``never`` call times(1)/times(2)/times(0);
        ``at_least``, ``at_most``, ``ordered``, ``one_by_one`` and ``mock``
        call the method of the same name. Other names are looked up normally.
        """
        if name == 'once':
            return _getattr(self, 'times')(1)
        elif name == 'twice':
            return _getattr(self, 'times')(2)
        elif name == 'never':
            return _getattr(self, 'times')(0)
        elif name in ('at_least', 'at_most', 'ordered', 'one_by_one'):
            return _getattr(self, name)()
        elif name == 'mock':
            return _getattr(self, 'mock')()
        else:
            return _getattr(self, name)

    def __getattr__(self, name):
        """Reset this expectation, then raise AttributeError for ``name``."""
        self.__raise(
            AttributeError,
            "'%s' object has no attribute '%s'" % (self.__class__.__name__, name))

    def _get_runnable(self):
        """Ugly hack to get the name of when() condition from the source code.

        Returns the text inside ``when(...)`` or the function name after
        ``def`` in the condition's source, falling back to ``'condition'``.
        """
        name = 'condition'
        try:
            source = inspect.getsource(self.runnable)
            if 'when(' in source:
                name = source.split('when(')[1].split(')')[0]
            elif 'def ' in source:
                name = source.split('def ')[1].split('(')[0]
        except Exception:  # couldn't get the source, oh well
            pass
        return name

    def _verify_signature_match(self, *kargs, **kwargs):
        """Check that the expected arguments can bind to the original callable.

        Skipped for Mock objects. Raises MethodSignatureError when there are
        too few or too many positional arguments, a name is given both
        positionally and by keyword, a keyword is unknown, or a required
        keyword-only argument is missing.
        """
        if isinstance(self._mock, Mock):
            return  # no sense in enforcing this for fake objects
        allowed = self.argspec
        args_len = len(allowed.args)

        # self is the first expected argument
        has_self = allowed.args and allowed.args[0] == "self"
        # Builtin methods take `self` as the first argument but `inspect.ismethod` returns False
        # so we need to check for them explicitly
        is_builtin_method = isinstance(self.original, types.BuiltinMethodType) and has_self
        # Methods take `self` if not a staticmethod
        is_method = inspect.ismethod(self.original) and self.method_type is not staticmethod
        # Class init takes `self`
        is_class = inspect.isclass(self.original)
        # When calling class methods or instance methods on a class method takes `cls`
        is_class_method = (
            inspect.isfunction(self.original)
            and inspect.isclass(self._mock)
            and self.method_type is not staticmethod
        )
        if is_builtin_method or is_method or is_class or is_class_method:
            # Do not count `self` or `cls`.
            args_len -= 1
        minimum = args_len - (allowed.defaults and len(allowed.defaults) or 0)
        maximum = None
        if allowed.varargs is None and allowed.keywords is None:
            maximum = args_len
        total_positional = len(
            kargs + tuple(a for a in kwargs if a in allowed.args))
        named_optionals = [a for a in kwargs
                           if allowed.defaults
                           if a in allowed.args[len(allowed.args) - len(allowed.defaults):]]
        if allowed.defaults and total_positional == minimum and named_optionals:
            minimum += len(named_optionals)
        if total_positional < minimum:
            raise MethodSignatureError(
                '%s requires at least %s arguments, expectation provided %s' %
                (self.name, minimum, total_positional))
        if maximum is not None and total_positional > maximum:
            raise MethodSignatureError(
                '%s requires at most %s arguments, expectation provided %s' %
                (self.name, maximum, total_positional))
        if args_len == len(kargs) and any(a for a in kwargs if a in allowed.args):
            raise MethodSignatureError(
                '%s already given as positional arguments to %s' %
                ([a for a in kwargs if a in allowed.args], self.name))
        if (not allowed.keywords and
                any(a for a in kwargs if a not in allowed.args + allowed.kwonlyargs)):
            raise MethodSignatureError(
                '%s is not a valid keyword argument to %s' %
                ([a for a in kwargs
                  if a not in (allowed.args + allowed.kwonlyargs)][0], self.name))
        # check that kwonlyargs that don't have default value specified are provided
        required_kwonlyargs = [a for a in allowed.kwonlyargs
                               if a not in (allowed.kwonlydefaults or {})]
        missing_kwonlyargs = [a for a in required_kwonlyargs if a not in kwargs]
        if missing_kwonlyargs:
            raise MethodSignatureError(
                '%s requires keyword-only argument(s) "%s"' %
                (self.name, '", "'.join(missing_kwonlyargs)))

    def _update_original(self, name, obj):
        """Record ``obj``'s current attribute ``name`` as the original.

        The raw ``__dict__`` entry is preferred so descriptors (e.g.
        staticmethod/classmethod objects) are kept unwrapped.
        """
        if hasattr(obj, '__dict__') and name in obj.__dict__:
            self.original = obj.__dict__[name]
        else:
            self.original = getattr(obj, name)
        self._update_argspec()

    def _update_argspec(self):
        """Refresh ``argspec`` from the recorded original, when possible."""
        original = self.__dict__.get('original')
        if original:
            try:
                if sys.version_info < (3, 0):
                    self.argspec = FullArgSpec(inspect.getargspec(original))
                else:
                    self.argspec = FullArgSpec(inspect.getfullargspec(original))
            except TypeError:
                # built-in function: fall back to stupid processing and hope the
                # builtins don't change signature
                pass

    def _normalize_named_args(self, *kargs, **kwargs):
        """Express call arguments in a canonical, comparable form.

        When the signature is known and every positional argument has a
        declared name, positionals are folded into the keyword dict and the
        result is ``{'kargs': (), 'kwargs': {...}}``. Otherwise (no argspec, or
        more positionals than named parameters) the arguments are returned
        unchanged as ``{'kargs': kargs, 'kwargs': kwargs}``.
        """
        argspec = self.argspec
        default = {'kargs': kargs, 'kwargs': kwargs}
        if not argspec:
            return default
        if inspect.ismethod(self.original):
            # bound method: the first declared argument (self/cls) is implicit
            args = argspec.args[1:]
        else:
            args = argspec.args
        if len(kargs) > len(args):
            # some positionals (e.g. *args) have no name to map to
            return default
        # copy so the untouched ``default`` form is never polluted
        named = dict(kwargs)
        for arg_name, value in zip(args, kargs):
            named[arg_name] = value
        return {'kargs': (), 'kwargs': named}

    def __raise(self, exception, message):
        """Safe internal raise implementation.

        In case we're patching builtins, it's important to reset the
        expectation before raising any exceptions or else things like
        open() might be stubbed out and the resulting runner errors are very
        difficult to diagnose.
        """
        self.reset()
        raise exception(message)

    def match_args(self, given_args):
        """Check if the set of given arguments matches this expectation.

        Args:
          - given_args: dict with 'kargs' (tuple) and 'kwargs' (dict) keys

        Returns:
          - True when no arguments were specified via with_args() or every
            argument matches (see _arguments_match), False otherwise
        """
        expected_args = self.args
        given_args = self._normalize_named_args(
            *given_args['kargs'], **given_args['kwargs'])
        if (expected_args == given_args or expected_args is None):
            return True
        if (len(given_args['kargs']) != len(expected_args['kargs']) or
            len(given_args['kwargs']) != len(expected_args['kwargs']) or
            (sorted(given_args['kwargs'].keys()) !=
             sorted(expected_args['kwargs'].keys()))):
            return False
        for i, arg in enumerate(given_args['kargs']):
            if not _arguments_match(arg, expected_args['kargs'][i]):
                return False
        for k, v in given_args['kwargs'].items():
            if not _arguments_match(v, expected_args['kwargs'][k]):
                return False
        return True

    def mock(self):
        """Return the mock associated with this expectation."""
        return self._mock

    def with_args(self, *kargs, **kwargs):
        """Override the arguments used to match this expectation's method.

        Args:
          - kargs: optional positional arguments
          - kwargs: optional keyword arguments

        Returns:
          - self, i.e. can be chained with other Expectation methods

        Raises:
          - FlexmockError for attribute stubs
          - MethodSignatureError if the arguments don't fit the original
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use with_args() with attribute stubs")
        self._update_argspec()
        if self.argspec:
            # do this outside try block as TypeError is way too general and catches
            # unrelated errors in the verify signature code
            self._verify_signature_match(*kargs, **kwargs)
            self.args = self._normalize_named_args(*kargs, **kwargs)
        else:
            self.args = {'kargs': kargs, 'kwargs': kwargs}
        return self

    def and_return(self, *values):
        """Override the return value of this expectation's method.

        When and_return is given multiple times, each value provided is returned
        on successive invocations of the method. It is also possible to mix
        and_return with and_raise in the same manner to alternate between returning
        a value and raising an exception on different method invocations.

        When combined with the one_by_one property, value is treated as a list of
        values to be returned in the order specified by successive calls to this
        method rather than a single list to be returned each time.

        For attribute stubs the attribute is simply set to the value.

        Args:
          - values: optional list of return values, defaults to None if not given

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not values:
            value = None
        elif len(values) == 1:
            value = values[0]
        else:
            value = values

        if not self._callable:
            _setattr(self._mock, self.name, value)
            return self

        return_values = _getattr(self, 'return_values')
        if not _getattr(self, '_one_by_one'):
            value = ReturnValue(value)
            return_values.append(value)
        else:
            try:
                return_values.extend([ReturnValue(v) for v in value])
            except TypeError:
                # not iterable: treat it as a single return value
                return_values.append(ReturnValue(value))
        return self

    def times(self, number):
        """Number of times this expectation's method is expected to be called.

        There are also 3 aliases for the times() method:

          - once() -> times(1)
          - twice() -> times(2)
          - never() -> times(0)

        The count is stored under the current modifier (exactly by default, or
        at least / at most after at_least / at_most).

        Args:
          - number: int

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use times() with attribute stubs")
        expected_calls = _getattr(self, 'expected_calls')
        modifier = _getattr(self, 'modifier')
        expected_calls[modifier] = number
        return self

    def one_by_one(self):
        """Modifies the return value to be treated as a list of return values.

        Each value in the list is returned on successive invocations of the method.
        Values already recorded are split into their items; and_raise()
        entries and non-iterable values are kept as they are.

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use one_by_one() with attribute stubs")
        if not self._one_by_one:
            self._one_by_one = True
            return_values = _getattr(self, 'return_values')
            saved_values = return_values[:]
            self.return_values = return_values = []
            for value in saved_values:
                if value.raises:
                    # and_raise() entry: its value holds exception arguments,
                    # not a sequence of return values
                    return_values.append(value)
                    continue
                try:
                    # materialize first so a failure can't leave partial items
                    items = list(value.value)
                except TypeError:
                    return_values.append(value)
                else:
                    return_values.extend([ReturnValue(item) for item in items])
        return self

    def at_least(self):
        """Modifies the associated times() expectation.

        When given, an exception will only be raised if the method is called less
        than times() specified. Does nothing if times() is not given.

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use at_least() with attribute stubs")
        expected_calls = _getattr(self, 'expected_calls')
        modifier = _getattr(self, 'modifier')
        if expected_calls[AT_LEAST] is not None or modifier == AT_LEAST:
            self.__raise(FlexmockError, 'cannot use at_least modifier twice')
        if modifier == AT_MOST and expected_calls[AT_MOST] is None:
            self.__raise(FlexmockError, 'cannot use at_least with at_most unset')
        self.modifier = AT_LEAST
        return self

    def at_most(self):
        """Modifies the associated "times" expectation.

        When given, an exception will only be raised if the method is called more
        than times() specified. Does nothing if times() is not given.

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use at_most() with attribute stubs")
        expected_calls = _getattr(self, 'expected_calls')
        modifier = _getattr(self, 'modifier')
        if expected_calls[AT_MOST] is not None or modifier == AT_MOST:
            self.__raise(FlexmockError, 'cannot use at_most modifier twice')
        if modifier == AT_LEAST and expected_calls[AT_LEAST] is None:
            self.__raise(FlexmockError, 'cannot use at_most with at_least unset')
        self.modifier = AT_MOST
        return self

    def ordered(self):
        """Makes the expectation respect the order of should_receive statements.

        An exception will be raised if methods are called out of order, determined
        by order of should_receive calls in the test.

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use ordered() with attribute stubs")
        self._ordered = True
        FlexmockContainer.ordered.append(self)
        return self

    def when(self, func):
        """Sets an outside resource to be checked before executing the method.

        Args:
          - func: function to call to check if the method should be executed

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use when() with attribute stubs")
        if not hasattr(func, '__call__'):
            self.__raise(FlexmockError, 'when() parameter must be callable')
        self.runnable = func
        return self

    def and_raise(self, exception, *kargs, **kwargs):
        """Specifies the exception to be raised when this expectation is met.

        Args:
          - exception: class or instance of the exception
          - kargs: optional positional arguments to pass to the exception
          - kwargs: optional keyword arguments to pass to the exception

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use and_raise() with attribute stubs")
        args = {'kargs': kargs, 'kwargs': kwargs}
        return_values = _getattr(self, 'return_values')
        return_values.append(ReturnValue(raises=exception, value=args))
        return self

    def replace_with(self, function):
        """Gives a function to run instead of the mocked out one.

        Passing the recorded original turns the expectation into a
        pass-through (spy).

        Args:
          - function: callable

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(FlexmockError, "can't use replace_with() with attribute/property stubs")
        replace_with = _getattr(self, '_replace_with')
        original = self.__dict__.get('original')
        if replace_with:
            self.__raise(FlexmockError, 'replace_with cannot be specified twice')
        if function == original:
            self._pass_thru = True
        self._replace_with = function
        return self

    def and_yield(self, *kargs):
        """Specifies the list of items to be yielded on successive method calls.

        In effect, the mocked object becomes a generator.

        Returns:
          - self, i.e. can be chained with other Expectation methods
        """
        if not self._callable:
            self.__raise(
                FlexmockError, "can't use and_yield() with attribute stubs")
        return self.and_return(iter(kargs))

    def verify(self, final=True):
        """Verify that this expectation has been met.

        A failure is reported only once per expectation.

        Args:
          final: boolean, True if no further calls to this method expected
                 (skip checking at_least expectations when False)

        Raises:
          MethodCallError Exception
        """
        failed, message = self._verify_number_of_calls(final)
        if failed and not self._verified:
            self._verified = True
            self.__raise(
                MethodCallError,
                '%s expected to be called %s times, called %s times' %
                (_format_args(self.name, self.args), message, self.times_called))

    def _verify_number_of_calls(self, final):
        """Compare times_called against expected_calls.

        Returns:
          - (failed, message) where message describes the expected count,
            e.g. 'exactly 2' or 'at least 1 and at most 3'
        """
        failed = False
        message = ''
        expected_calls = _getattr(self, 'expected_calls')
        times_called = _getattr(self, 'times_called')
        if expected_calls[EXACTLY] is not None:
            message = 'exactly %s' % expected_calls[EXACTLY]
            if final:
                if times_called != expected_calls[EXACTLY]:
                    failed = True
            else:
                # mid-test only exceeding the count is already a failure
                if times_called > expected_calls[EXACTLY]:
                    failed = True
        else:
            if final and expected_calls[AT_LEAST] is not None:
                message = 'at least %s' % expected_calls[AT_LEAST]
                if times_called < expected_calls[AT_LEAST]:
                    failed = True
            if expected_calls[AT_MOST] is not None:
                if message:
                    message += ' and '
                message += 'at most %s' % expected_calls[AT_MOST]
                if times_called > expected_calls[AT_MOST]:
                    failed = True
        return failed, message

    def reset(self):
        """Returns the methods overriden by this expectation to their originals.

        Nothing is done for Mock objects or when no original was recorded.
        Falsy originals (0, '', None, False, ...) are restored as well.
        """
        _mock = _getattr(self, '_mock')
        if isinstance(_mock, Mock):
            return
        if 'original' not in self.__dict__:
            return
        original = self.__dict__['original']
        # name may be unicode but pypy demands dict keys to be str
        name = str(_getattr(self, 'name'))
        if (hasattr(_mock, '__dict__') and
                name in _mock.__dict__ and
                self._local_override):
            # the attribute was added to the object itself by flexmock
            delattr(_mock, name)
        elif (hasattr(_mock, '__dict__') and
                name in _mock.__dict__ and
                type(_mock.__dict__) is dict):
            _mock.__dict__[name] = original
        elif self.method_type == staticmethod and sys.version_info < (3, 0):
            # on some Python 2 implementations (e.g. pypy), just assigning
            # the original staticmethod would make it a normal method,
            # thus an additional "self" argument would be passed to it,
            # we need to explicitly cast it to staticmethod
            setattr(_mock, name, staticmethod(original))
        else:
            setattr(_mock, name, original)
677
678
679class Mock(object):
680    """Fake object class returned by the flexmock() function."""
681
682    def __init__(self, **kwargs):
683        """Mock constructor.
684
685        Args:
686          - kwargs: dict of attribute/value pairs used to initialize the mock object
687        """
688        self._object = self
689        for attr, value in kwargs.items():
690            if type(value) is property:
691                setattr(self.__class__, attr, value)
692            else:
693                setattr(self, attr, value)
694
695    def __enter__(self):
696        return self._object
697
698    def __exit__(self, type, value, traceback):
699        pass
700
701    def __call__(self, *kargs, **kwargs):
702        """Hack to make Expectation.mock() work with parens."""
703        return self
704
705    def __iter__(self):
706        """Makes the mock object iterable.
707
708        Call the instance's version of __iter__ if available, otherwise yield self.
709        """
710        if (hasattr(self, '__dict__') and type(self.__dict__) is dict and
711                '__iter__' in self.__dict__):
712            for item in self.__dict__['__iter__'](self):
713                yield item
714        else:
715            yield self
716
717    def should_receive(self, name):
718        """Replaces the specified attribute with a fake.
719
720        Args:
721          - name: string name of the attribute to replace
722
723        Returns:
724          - Expectation object which can be used to modify the expectations
725            on the fake attribute
726        """
727        if name in UPDATED_ATTRS:
728            raise FlexmockError('unable to replace flexmock methods')
729        chained_methods = None
730        obj = _getattr(self, '_object')
731        if '.' in name:
732            name, chained_methods = name.split('.', 1)
733        name = _update_name_if_private(obj, name)
734        _ensure_object_has_named_attribute(obj, name)
735        if chained_methods:
736            if (not isinstance(obj, Mock) and not hasattr(getattr(obj, name), '__call__')):
737                return_value = _create_partial_mock(getattr(obj, name))
738            else:
739                return_value = Mock()
740            self._create_expectation(obj, name, return_value)
741            return return_value.should_receive(chained_methods)
742        else:
743            return self._create_expectation(obj, name)
744
745    def should_call(self, name):
746        """Creates a spy.
747
748        This means that the original method will be called rather than the fake
749        version. However, we can still keep track of how many times it's called and
750        with what arguments, and apply expectations accordingly.
751
752        should_call is meaningless/not allowed for non-callable attributes.
753
754        Args:
755          - name: string name of the method
756
757        Returns:
758          - Expectation object
759        """
760        expectation = self.should_receive(name)
761        return expectation.replace_with(expectation.__dict__.get('original'))
762
763    def new_instances(self, *kargs):
764        """Overrides __new__ method on the class to return custom objects.
765
766        Alias for should_receive('__new__').and_return(kargs).one_by_one
767
768        Args:
769          - kargs: objects to return on each successive call to __new__
770
771        Returns:
772          - Expectation object
773        """
774        if _isclass(self._object):
775            return self.should_receive('__new__').and_return(kargs).one_by_one
776        else:
777            raise FlexmockError('new_instances can only be called on a class mock')
778
779    def _create_expectation(self, obj, name, return_value=None):
780        if self not in FlexmockContainer.flexmock_objects:
781            FlexmockContainer.flexmock_objects[self] = []
782        expectation = self._save_expectation(name, return_value)
783        FlexmockContainer.add_expectation(self, expectation)
784        if _isproperty(obj, name):
785            self._update_property(expectation, name, return_value)
786        elif (isinstance(obj, Mock) or
787              hasattr(getattr(obj, name), '__call__') or
788              _isclass(getattr(obj, name))):
789            self._update_method(expectation, name)
790        else:
791            self._update_attribute(expectation, name, return_value)
792        return expectation
793
794    def _save_expectation(self, name, return_value=None):
795        if name in [x.name for x in
796                    FlexmockContainer.flexmock_objects[self]]:
797            expectation = [x for x in FlexmockContainer.flexmock_objects[self]
798                           if x.name == name][0]
799            expectation = Expectation(
800                self._object, name=name, return_value=return_value,
801                original=expectation.__dict__.get('original'),
802                method_type=expectation.__dict__.get("method_type"))
803        else:
804            expectation = Expectation(
805                self._object, name=name, return_value=return_value)
806        return expectation
807
808    def _update_class_for_magic_builtins(self, obj, name):
809        """Fixes MRO for builtin methods on new-style objects.
810
811        On 2.7+ and 3.2+, replacing magic builtins on instances of new-style
812        classes has no effect as the one attached to the class takes precedence.
813        To work around it, we update the class' method to check if the instance
814        in question has one in its own __dict__ and call that instead.
815        """
816        if not (name.startswith('__') and name.endswith('__') and len(name) > 4):
817            return
818        original = getattr(obj.__class__, name)
819
820        def updated(self, *kargs, **kwargs):
821            if (hasattr(self, '__dict__') and type(self.__dict__) is dict and
822                    name in self.__dict__):
823                return self.__dict__[name](*kargs, **kwargs)
824            else:
825                return original(self, *kargs, **kwargs)
826        setattr(obj.__class__, name, updated)
827        if _get_code(updated) != _get_code(original):
828            self._create_placeholder_mock_for_proper_teardown(
829                obj.__class__, name, original)
830
831    def _create_placeholder_mock_for_proper_teardown(self, obj, name, original):
832        """Ensures that the given function is replaced on teardown."""
833        mock = Mock()
834        mock._object = obj
835        expectation = Expectation(obj, name=name, original=original)
836        FlexmockContainer.add_expectation(mock, expectation)
837
    def _update_method(self, expectation, name):
        """Install a mock method for ``name`` on the mocked object.

        The first time around it records the original attribute and its
        method type on the expectation. It then binds the mock function where
        needed and sets it on the object, storing whether that was a local
        override. For mocked instances it also patches the class so magic
        methods dispatch to the instance-level mock.

        Args:
          - expectation: Expectation that receives original/method_type info
          - name: name of the method being mocked
        """
        method_instance = self._create_mock_method(name)
        obj = self._object
        # Only capture the original once; the expectation may already carry
        # one (presumably copied over by _save_expectation for repeated mocks).
        if _hasattr(obj, name) and not hasattr(expectation, "original"):
            expectation._update_original(name, obj)
            expectation.method_type = self._get_method_type(obj, name, expectation.original)
            if expectation.method_type in SPECIAL_METHODS:
                # Keep the attribute exactly as fetched from obj; pass_thru
                # calls it without an explicit self for class/static methods.
                expectation.original_function = getattr(obj, name)
        # Bind the mock to obj for instances, class/static methods and
        # __new__; plain methods mocked on a class stay unbound functions.
        if (
            not _isclass(obj)
            or expectation.method_type in SPECIAL_METHODS
            or name == '__new__'
        ):
            method_instance = types.MethodType(method_instance, obj)
        override = _setattr(obj, name, method_instance)
        expectation._local_override = override
        # Instance-level dunder overrides are ignored by Python unless the
        # class forwards to them, hence the class patch for instances.
        if (override and not _isclass(obj) and not isinstance(obj, Mock) and
                hasattr(obj.__class__, name)):
            self._update_class_for_magic_builtins(obj, name)
857
858    def _get_method_type(self, obj, name, method):
859        """Get method type of the original method.
860
861        Method type is saved because after mocking the base class, it is difficult to determine
862        the original method type.
863        """
864        if not inspect.isclass(obj) and not hasattr(obj, "__class__"):
865            return type(method)
866
867        method_type = self._get_saved_method_type(obj, name, method)
868        if method_type is not None:
869            return method_type
870        if _is_class_method(method, name):
871            method_type = classmethod
872        elif _is_static_method(obj, name):
873            method_type = staticmethod
874        else:
875            method_type = type(method)
876        setattr(obj, "%s__flexmock__method_type" % name, method_type)
877        return method_type
878
879    def _get_saved_method_type(self, obj, name, method):
880        """Check method type of the original method if it was saved to the class or base class."""
881        bound_to = getattr(method, "__self__", None)
882        if bound_to is not None and inspect.isclass(bound_to):
883            # Check if the method type was saved in a base class
884            for cls in inspect.getmro(bound_to):
885                method_type = vars(cls).get("%s__flexmock__method_type" % name)
886                if method_type:
887                    return method_type
888        elif inspect.isclass(obj):
889            method_type = vars(obj).get("%s__flexmock__method_type" % name)
890            if method_type:
891                return method_type
892        return None
893
894    def _update_attribute(self, expectation, name, return_value=None):
895        obj = self._object
896        expectation._callable = False
897        if _hasattr(obj, name) and not hasattr(expectation, 'original'):
898            expectation._update_original(name, obj)
899        override = _setattr(obj, name, return_value)
900        expectation._local_override = override
901
    def _update_property(self, expectation, name, return_value=None):
        """Make the property ``name`` overridable per instance.

        Replaces the class's property with one that returns ``name`` from the
        instance ``__dict__`` when present (presumably where the mocked value
        gets stored). Otherwise it defers to the original property, which is
        kept on the class under ``_flexmock__<name>``.

        Args:
          - expectation: Expectation for this property
          - name: property name
          - return_value: not used by this method (kept for signature parity
            with _update_attribute)
        """
        new_name = '_flexmock__%s' % name
        obj = self._object
        # Properties live on the class, so always patch the class.
        if not _isclass(obj):
            obj = obj.__class__
        expectation._callable = False
        # Class-level access yields the property object itself, not a value.
        original = getattr(obj, name)

        @property
        def updated(self):
            if (hasattr(self, '__dict__') and type(self.__dict__) is dict and
                    name in self.__dict__):
                return self.__dict__[name]
            else:
                # Goes through the original property stored under new_name.
                return getattr(self, new_name)
        setattr(obj, name, updated)
        if not hasattr(obj, new_name):
            # don't try to double update
            # First patch of this property on this class: stash the original
            # and register it for teardown.
            FlexmockContainer.add_teardown_property(obj, new_name)
            setattr(obj, new_name, original)
            self._create_placeholder_mock_for_proper_teardown(obj, name, original)
923
924    def _create_mock_method(self, name):
925        def _handle_exception_matching(expectation):
926            return_values = _getattr(expectation, 'return_values')
927            if return_values:
928                raised, instance = sys.exc_info()[:2]
929                message = '%s' % instance
930                expected = return_values[0].raises
931                if not expected:
932                    raise
933                args = return_values[0].value
934                expected_instance = expected(*args['kargs'], **args['kwargs'])
935                expected_message = '%s' % expected_instance
936                if _isclass(expected):
937                    if expected is not raised and expected not in raised.__bases__:
938                        raise (ExceptionClassError('expected %s, raised %s' %
939                               (expected, raised)))
940                    if args['kargs'] and type(RE_TYPE) is type(args['kargs'][0]):
941                        if not args['kargs'][0].search(message):
942                            raise (ExceptionMessageError('expected /%s/, raised "%s"' %
943                                   (args['kargs'][0].pattern, message)))
944                    elif expected_message and expected_message != message:
945                        raise (ExceptionMessageError('expected "%s", raised "%s"' %
946                               (expected_message, message)))
947                elif expected is not raised:
948                    raise (ExceptionClassError('expected "%s", raised "%s"' %
949                           (expected, raised)))
950            else:
951                raise
952
953        def match_return_values(expected, received):
954            if not isinstance(expected, tuple):
955                expected = (expected,)
956            if not isinstance(received, tuple):
957                received = (received,)
958            if len(received) != len(expected):
959                return False
960            for i, val in enumerate(received):
961                if not _arguments_match(val, expected[i]):
962                    return False
963            return True
964
965        def pass_thru(expectation, runtime_self, *kargs, **kwargs):
966            return_values = None
967            try:
968                original = _getattr(expectation, 'original')
969                _mock = _getattr(expectation, '_mock')
970                if _isclass(_mock):
971                    if expectation.method_type in SPECIAL_METHODS:
972                        original = _getattr(expectation, 'original_function')
973                        return_values = original(*kargs, **kwargs)
974                    else:
975                        return_values = original(runtime_self, *kargs, **kwargs)
976                else:
977                    return_values = original(*kargs, **kwargs)
978            except:
979                return _handle_exception_matching(expectation)
980            expected_values = _getattr(expectation, 'return_values')
981            if (expected_values and
982                    not match_return_values(expected_values[0].value, return_values)):
983                raise (MethodSignatureError('expected to return %s, returned %s' %
984                       (expected_values[0].value, return_values)))
985            return return_values
986
987        def _handle_matched_expectation(expectation, runtime_self, *kargs, **kwargs):
988            if not expectation.runnable():
989                raise StateError('%s expected to be called when %s is True' %
990                                 (name, expectation._get_runnable()))
991            expectation.times_called += 1
992            expectation.verify(final=False)
993            _pass_thru = _getattr(expectation, '_pass_thru')
994            _replace_with = _getattr(expectation, '_replace_with')
995            if _pass_thru:
996                return pass_thru(expectation, runtime_self, *kargs, **kwargs)
997            elif _replace_with:
998                return _replace_with(*kargs, **kwargs)
999            return_values = _getattr(expectation, 'return_values')
1000            if return_values:
1001                return_value = return_values[0]
1002                del return_values[0]
1003                return_values.append(return_value)
1004            else:
1005                return_value = ReturnValue()
1006            if return_value.raises:
1007                if _isclass(return_value.raises):
1008                    raise return_value.raises(
1009                        *return_value.value['kargs'], **return_value.value['kwargs'])
1010                else:
1011                    raise return_value.raises
1012            else:
1013                return return_value.value
1014
1015        def mock_method(runtime_self, *kargs, **kwargs):
1016            arguments = {'kargs': kargs, 'kwargs': kwargs}
1017            expectation = FlexmockContainer.get_flexmock_expectation(
1018                self, name, arguments)
1019            if expectation:
1020                return _handle_matched_expectation(expectation, runtime_self, *kargs, **kwargs)
1021            # inform the user which expectation(s) for the method were _not_ matched
1022            expectations = [
1023                e for e in reversed(FlexmockContainer.flexmock_objects.get(self, []))
1024                if e.name == name
1025            ]
1026            error_msg = _format_args(name, arguments)
1027            if expectations:
1028                for e in expectations:
1029                    error_msg += '\nDid not match expectation %s' % _format_args(name, e.args)
1030            raise MethodSignatureError(error_msg)
1031
1032        return mock_method
1033
1034
def _arg_to_str(arg):
    """Render a single call argument for error messages.

    Compiled regexes appear as ``/pattern/``, text is wrapped in double
    quotes, and anything else uses its plain ``%s`` form.
    """
    if type(arg) is type(RE_TYPE):
        return '/%s/' % arg.pattern
    # Before Python 3, str and unicode both derive from basestring; on 3.x
    # plain str covers all text.
    text_types = basestring if sys.version_info < (3, 0) else str  # noqa: F821
    if not isinstance(arg, text_types):
        return '%s' % (arg,)
    return '"%s"' % (arg,)
1051
1052
1053def _format_args(name, arguments):
1054    if arguments is None:
1055        arguments = {'kargs': (), 'kwargs': {}}
1056    kargs = ', '.join(_arg_to_str(arg) for arg in arguments['kargs'])
1057    kwargs = ', '.join('%s=%s' % (k, _arg_to_str(v)) for k, v in arguments['kwargs'].items())
1058    if kargs and kwargs:
1059        args = '%s, %s' % (kargs, kwargs)
1060    else:
1061        args = '%s%s' % (kargs, kwargs)
1062    return '%s(%s)' % (name, args)
1063
1064
def _create_partial_mock(obj_or_class, **kwargs):
    """Turn ``obj_or_class`` into a partial mock, or reuse the existing one.

    Args:
      - obj_or_class: object, class or module to mock
      - kwargs: name/value pairs; callables become replace_with()
        implementations, everything else an and_return() value

    Returns:
      - ``obj_or_class`` itself when the flexmock helper methods were attached
        to a non-class object, otherwise the Mock wrapper
    """
    known = [candidate for candidate in FlexmockContainer.flexmock_objects
             if candidate._object is obj_or_class]
    if known:
        mock = known[0]
    else:
        mock = Mock()
        mock._object = obj_or_class
    for attr_name, value in kwargs.items():
        use_replacement = hasattr(value, '__call__')
        expectation = mock.should_receive(attr_name)
        if use_replacement:
            expectation.replace_with(value)
        else:
            expectation.and_return(value)
    if not known:
        # A brand-new mock is registered with the container here.
        FlexmockContainer.add_expectation(mock, Expectation(obj_or_class))
    attached = _attach_flexmock_methods(mock, Mock, obj_or_class)
    if attached and not _isclass(mock._object):
        return mock._object
    return mock
1083
1084
def _attach_flexmock_methods(mock, flexmock_class, obj):
    """Copy the flexmock helper methods (UPDATED_ATTRS) from ``mock`` onto ``obj``.

    Nothing is changed and False is returned when ``obj`` already has one of
    these attributes whose code differs from ``flexmock_class``'s version,
    i.e. ``obj`` defines its own. A TypeError or AttributeError raised along
    the way is reported as MockBuiltinError.

    Returns:
      - True once the helpers have been attached
    """
    try:
        for attr in UPDATED_ATTRS:
            if not hasattr(obj, attr):
                continue
            own_code = _get_code(getattr(obj, attr))
            if own_code is not _get_code(getattr(flexmock_class, attr)):
                return False
        for attr in UPDATED_ATTRS:
            _setattr(obj, attr, getattr(mock, attr))
    except TypeError:
        raise MockBuiltinError(
            'Python does not allow you to mock builtin objects or modules. '
            'Consider wrapping it in a class you can mock instead')
    except AttributeError:
        raise MockBuiltinError(
            'Python does not allow you to mock instances of builtin objects. '
            'Consider wrapping it in a class you can mock instead')
    return True
1102
1103
1104def _get_code(func):
1105    if hasattr(func, 'func_code'):
1106        code = 'func_code'
1107    elif hasattr(func, 'im_func'):
1108        func = func.im_func
1109        code = 'func_code'
1110    else:
1111        code = '__code__'
1112    return getattr(func, code)
1113
1114
def _arguments_match(arg, expected_arg):
    """Return True if the actual ``arg`` satisfies ``expected_arg``.

    ``expected_arg`` matches when it equals ``arg``, when it is a class and
    ``arg`` is an instance of it, or when it is a compiled regex that finds
    a match in ``arg``. A regex never matches an argument it cannot search
    (e.g. an int, None, or bytes against a str pattern).
    """
    if expected_arg == arg:
        return True
    elif _isclass(expected_arg) and isinstance(arg, expected_arg):
        return True
    elif type(RE_TYPE) is type(expected_arg):
        try:
            return bool(expected_arg.search(arg))
        except TypeError:
            # search() only accepts text of the pattern's own kind; anything
            # else is a non-match rather than a crash inside flexmock.
            return False
    else:
        return False
1125
1126
1127def _getattr(obj, name):
1128    """Convenience wrapper to work around custom __getattribute__."""
1129    return object.__getattribute__(obj, name)
1130
1131
1132def _setattr(obj, name, value):
1133    """Ensure we use local __dict__ where possible."""
1134    name = str(name)  # name may be unicode but pypy demands dict keys to be str
1135    local_override = False
1136    if hasattr(obj, '__dict__') and type(obj.__dict__) is dict:
1137        if name not in obj.__dict__:
1138            # Overriding attribute locally on an instance.
1139            local_override = True
1140        obj.__dict__[name] = value
1141    else:
1142        if inspect.isclass(obj) and not vars(obj).get(name):
1143            # Overriding derived attribute locally on a child class.
1144            local_override = True
1145        setattr(obj, name, value)
1146    return local_override
1147
1148
def _hasattr(obj, name):
    """hasattr() variant that does not trigger property getters on instances.

    For an instance attribute missing from the instance ``__dict__``, the
    lookup is done on the class instead, which yields the descriptor rather
    than calling it. Names that exist only on ``type`` (e.g. ``__call__``)
    are reported as absent for such instances.
    """
    missing_from_instance = (not _isclass(obj) and hasattr(obj, '__dict__') and
                             name not in obj.__dict__)
    if not missing_from_instance:
        return hasattr(obj, name)
    if name in DEFAULT_CLASS_ATTRIBUTES:
        return False  # avoid false positives for things like __call__
    return hasattr(obj.__class__, name)
1158
1159
1160def _isclass(obj):
1161    """Fixes stupid bug in inspect.isclass from < 2.7."""
1162    if sys.version_info < (2, 7):
1163        return isinstance(obj, (type, types.ClassType))
1164    else:
1165        return inspect.isclass(obj)
1166
1167
def _isproperty(obj, name):
    """Return True if ``name`` resolves to a plain ``property`` on ``obj``'s class.

    Mocks never count. For classes the attribute is read from the class.
    For instances it is only checked when ``name`` is not in the instance
    ``__dict__``, and then it is read from the instance's class.
    """
    if isinstance(obj, Mock):
        return False
    if _isclass(obj):
        return type(getattr(obj, name)) is property
    if hasattr(obj, '__dict__') and name not in obj.__dict__:
        return type(getattr(obj.__class__, name)) is property
    return False
1180
1181
def _update_name_if_private(obj, name):
    """Translate a private ``__name`` into its name-mangled form.

    Python rewrites an identifier with at least two leading underscores and
    at most one trailing underscore, used in a class body, to
    ``_<ClassName><identifier>``. Leading underscores are stripped from the
    class name but not from the identifier, and nothing is rewritten when
    the class name consists only of underscores. Modules are never mangled.

    Args:
      - obj: class or instance the attribute belongs to (modules pass through)
      - name: attribute name as written in the class body

    Returns:
      - the mangled name, or ``name`` unchanged when no mangling applies
    """
    if (name.startswith('__') and not name.endswith('__') and not inspect.ismodule(obj)):
        if _isclass(obj):
            class_name = obj.__name__
        else:
            class_name = obj.__class__.__name__
        stripped_class = class_name.lstrip('_')
        if stripped_class:
            # Keep the identifier's own underscores: '___x' -> '_Cls___x'.
            name = '_%s%s' % (stripped_class, name)
    return name
1190
1191
def _ensure_object_has_named_attribute(obj, name):
    """Raise FlexmockError unless ``obj`` (or any Mock) has attribute ``name``."""
    if isinstance(obj, Mock) or _hasattr(obj, name):
        return
    exc_msg = '%s does not have attribute %s' % (obj, name)
    if name == '__new__':
        exc_msg = 'old-style classes do not have a __new__() method'
    raise FlexmockError(exc_msg)
1198
1199def _is_class_method(method, name):
1200    """Check if a method is a classmethod.
1201
1202    This function checks all the classes in the class method resolution in order
1203    to get the correct result for derived methods as well.
1204    """
1205    bound_to = getattr(method, "__self__", None)
1206    if not inspect.isclass(bound_to):
1207        return False
1208    for cls in inspect.getmro(bound_to):
1209        descriptor = vars(cls).get(name)
1210        if descriptor is not None:
1211            return isinstance(descriptor, classmethod)
1212    return False
1213
1214def _is_static_method(obj, name):
1215    if sys.version_info < (3, 0):
1216        return isinstance(getattr(obj, name), types.FunctionType)
1217    try:
1218        return isinstance(inspect.getattr_static(obj, name), staticmethod)
1219    except AttributeError:
1220        # AttributeError is raised when mocking a proxied object
1221        if hasattr(obj, "__mro__"):
1222            for cls in inspect.getmro(obj):
1223                descriptor = vars(cls).get(name)
1224                if descriptor is not None:
1225                    return isinstance(descriptor, staticmethod)
1226    return False
1227
def flexmock_teardown():
    """Performs flexmock-specific teardown tasks.

    Steps, in order:
      1. call reset() on every recorded expectation, then delete the
         ``<name>__flexmock__method_type`` markers that _get_method_type
         stored on mocked objects;
      2. remove the flexmock helper methods (UPDATED_ATTRS) that were
         attached to partially mocked instances and classes, but only while
         they are still flexmock's own implementations;
      3. call FlexmockContainer.teardown_properties() and reset();
      4. last of all, verify every saved expectation, so a verification
         error cannot interrupt the cleanup above.
    """
    saved = {}
    instances = []
    classes = []
    for mock_object, expectations in FlexmockContainer.flexmock_objects.items():
        # Keep a copy: the container is reset before verification runs.
        saved[mock_object] = expectations[:]
        for expectation in expectations:
            _getattr(expectation, 'reset')()
        for expectation in expectations:
            # Remove method type attributes set by flexmock. This needs to be done after
            # resetting all the expectations because method type is needed in expectation teardown.
            # NOTE(review): on Python 3 every object has __class__, so this
            # condition is always true there.
            if inspect.isclass(mock_object) or hasattr(mock_object, "__class__"):
                try:
                    delattr(mock_object._object, "%s__flexmock__method_type" % expectation.name)
                except (AttributeError, TypeError):
                    pass
    for mock in saved.keys():
        obj = mock._object
        if not isinstance(obj, Mock) and not _isclass(obj):
            instances.append(obj)
        if _isclass(obj):
            classes.append(obj)
    for obj in instances + classes:
        for attr in UPDATED_ATTRS:
            # Delete the helper only if it still has Mock's own code. The
            # direct __dict__ route is tried first; any failure there (no
            # plain __dict__, attr missing, ...) falls back to getattr/delattr.
            try:
                obj_dict = obj.__dict__
                if _get_code(obj_dict[attr]) is _get_code(Mock.__dict__[attr]):
                    del obj_dict[attr]
            except:
                try:
                    if _get_code(getattr(obj, attr)) is _get_code(Mock.__dict__[attr]):
                        delattr(obj, attr)
                except AttributeError:
                    pass
    FlexmockContainer.teardown_properties()
    FlexmockContainer.reset()

    # make sure this is done last to keep exceptions here from breaking
    # any of the previous steps that cleanup all the changes
    for mock_object, expectations in saved.items():
        for expectation in expectations:
            _getattr(expectation, 'verify')()
1271
1272
def flexmock(spec=None, **kwargs):
    """Main entry point into the flexmock API.

    Called without ``spec``, it builds a brand-new fake object. Called with
    an existing object, class or module, it turns that into a partial mock.
    The target is extended with the basic Mock methods, so later
    expectations can be added on it directly without calling flexmock()
    again.

    Examples:
      >>> flexmock(SomeClass)
      >>> SomeClass.should_receive('some_method')

    NOTE: calling flexmock() repeatedly on the same object is safe; an
    already partially mocked object is detected and returned each time.

    Args:
      - spec: object (or class or module) to mock
      - kwargs: method/return_value pairs to attach to the object

    Returns:
      Mock object if no spec is provided. Otherwise return the spec object.
    """
    if spec is None:
        # A throwaway subclass per fake, so properties can be attached to it.
        fake_class = type('MockClass', (Mock,), {})
        return fake_class(**kwargs)
    return _create_partial_mock(spec, **kwargs)
1303
1304
1305# RUNNER INTEGRATION
1306
1307
def _hook_into_pytest():
    """Wrap ``_pytest.runner.call_runtest_hook`` so flexmock_teardown() runs per test.

    For the ``call`` phase, and for any phase that raised, the wrapper runs
    flexmock_teardown() through pytest's CallInfo and returns that CallInfo
    in place of the original one. It copies over the original duration when
    supported and, if the phase itself failed, its excinfo. Does nothing
    when ``_pytest`` cannot be imported.

    NOTE(review): this patches pytest's private ``_pytest`` package, so it
    depends on the pytest version in use.
    """
    try:
        from _pytest import runner
        saved = runner.call_runtest_hook

        def call_runtest_hook(item, when, **kwargs):
            ret = saved(item, when, **kwargs)
            # setup/teardown phases that succeeded are passed through untouched
            if when != 'call' and ret.excinfo is None:
                return ret
            if hasattr(runner.CallInfo, "from_call"):
                teardown = runner.CallInfo.from_call(flexmock_teardown, when=when)
                if hasattr(teardown, "duration"):
                    # CallInfo.duration only available in Pytest 6+
                    teardown.duration = ret.duration
            else:
                # older pytest: CallInfo's constructor runs the callable itself
                teardown = runner.CallInfo(flexmock_teardown, when=when)
                teardown.result = None
            # The test's own failure takes precedence over teardown's outcome.
            if ret.excinfo is not None:
                teardown.excinfo = ret.excinfo
            return teardown
        runner.call_runtest_hook = call_runtest_hook

    except ImportError:
        pass
_hook_into_pytest()
1333
1334
def _hook_into_doctest():
    """Make every DocTestRunner.run() call end with flexmock_teardown()."""
    try:
        from doctest import DocTestRunner
    except ImportError:
        return
    original_run = DocTestRunner.run

    def run(self, test, compileflags=None, out=None, clear_globs=True):
        try:
            return original_run(self, test, compileflags, out, clear_globs)
        finally:
            # Clean up even when the doctest run itself raised.
            flexmock_teardown()
    DocTestRunner.run = run


_hook_into_doctest()
1349
1350
def _patch_test_result(klass):
    """Patches flexmock into any class that inherits unittest.TestResult.

    This seems to work well for majority of test runners. In the case of nose
    it's not even necessary as it doesn't override unittest.TestResults's
    addSuccess and addFailure methods so simply patching unittest works
    out of the box for nose.

    For those that do inherit from unittest.TestResult and override its
    stopTest and addSuccess methods, patching is pretty straightforward
    (numerous examples below).

    The reason we don't simply patch unittest's parent TestResult class
    is stopTest and addSuccess in the child classes tend to add messages
    into the output that we want to override in case flexmock generates
    its own failures.

    Args:
      - klass: TestResult subclass whose addSuccess/stopTest get replaced
    """

    saved_addSuccess = klass.addSuccess
    saved_stopTest = klass.stopTest

    # Defer success reporting: only remember that the test passed, so that
    # stopTest can run flexmock_teardown() first and report a failure instead.
    def addSuccess(self, test):
        self._pre_flexmock_success = True

    def stopTest(self, test):
        if _get_code(saved_stopTest) is not _get_code(stopTest):
            # if parent class was for some reason patched, avoid calling
            # flexmock_teardown() twice and delegate up the class hierarchy
            # this doesn't help if there is a gap and only the parent's
            # parent class was patched, but should cover most screw-ups
            try:
                flexmock_teardown()
                # NOTE(review): this runs even when addSuccess was never
                # recorded (test failed/errored); confirm reporting a success
                # here is intended.
                saved_addSuccess(self, test)
            except:
                # Teardown/verification errors are reported only for tests that
                # otherwise passed; for other tests they are swallowed.
                if hasattr(self, '_pre_flexmock_success'):
                    self.addFailure(test, sys.exc_info())
            if hasattr(self, '_pre_flexmock_success'):
                del self._pre_flexmock_success
        return saved_stopTest(self, test)

    # NOTE(review): stopTest/addSuccess are fresh closures, so these identity
    # checks are always true and the patch is always applied.
    if klass.stopTest is not stopTest:
        klass.stopTest = stopTest

    if klass.addSuccess is not addSuccess:
        klass.addSuccess = addSuccess
1396
1397
def _hook_into_unittest():
    """Patch unittest's text result class. Best-effort: never raises."""
    import unittest

    def patch_text_result():
        try:
            # only valid TestResult class for unittest is TextTestResult
            _patch_test_result(unittest.TextTestResult)
        except AttributeError:
            # Python 2.4 only ships the private _TextTestResult
            _patch_test_result(unittest._TextTestResult)

    try:
        patch_text_result()
    except:  # let's not take any chances
        pass


_hook_into_unittest()
1410
1411
def _hook_into_unittest2():
    """Patch unittest2's TextTestResult (or Django's bundled copy), if present."""
    def load_result_class():
        try:
            from unittest2 import TextTestResult as result_class
        except ImportError:
            # Django has its own copy of unittest2 it uses as fallback
            from django.utils.unittest import TextTestResult as result_class
        return result_class

    try:
        _patch_test_result(load_result_class())
    except:
        pass


_hook_into_unittest2()
1423
1424
def _hook_into_twisted():
    """Patch twisted.trial's reporters in order, stopping at the first failure."""
    try:
        from twisted.trial import reporter
        for reporter_name in ('MinimalReporter', 'TextReporter',
                              'VerboseTextReporter', 'TreeReporter'):
            _patch_test_result(getattr(reporter, reporter_name))
    except:
        pass


_hook_into_twisted()
1435
1436
def _hook_into_subunit():
    """Patch subunit's TestProtocolClient when subunit is installed."""
    try:
        from subunit import TestProtocolClient
        _patch_test_result(TestProtocolClient)
    except:
        pass


_hook_into_subunit()
1444
1445
def _hook_into_zope():
    """Patch zope.testrunner's TestResult when it is available."""
    try:
        import zope.testrunner
        # NOTE(review): relies on zope.testrunner.runner already being
        # imported (it is reached as an attribute, not imported here).
        result_class = zope.testrunner.runner.TestResult
        _patch_test_result(result_class)
    except:
        pass


_hook_into_zope()
1453
1454
def _hook_into_testtools():
    """Patch testtools' TestResult when testtools is installed."""
    try:
        from testtools.testresult import TestResult as testtools_result
        _patch_test_result(testtools_result)
    except:
        pass


_hook_into_testtools()
1462
1463
def _hook_into_teamcity_unittest():
    """Patch TeamCity's unittest result class when running under TeamCity."""
    try:
        import tcunittest
        _patch_test_result(tcunittest.TeamcityTestResult)
    except:
        pass


_hook_into_teamcity_unittest()
1471
1472
1473# Dark magic to make the flexmock module itself callable.
1474# So that you can say:
1475#   import flexmock
1476# instead of:
1477#   from flexmock import flexmock
class _CallableModule(types.ModuleType):
    """Module proxy that makes the ``flexmock`` module itself callable.

    Instantiating it swaps ``sys.modules['flexmock']`` for this proxy. The
    real module is kept in ``_realmod``; ``dir()``, calls and any attribute
    the proxy lacks are forwarded to it.
    """
    def __init__(self):
        super(_CallableModule, self).__init__('flexmock')
        # NOTE(review): assumes this file was imported as the top-level module
        # 'flexmock'; under another name this lookup raises KeyError (or
        # finds an unrelated module).
        self._realmod = sys.modules['flexmock']
        sys.modules['flexmock'] = self
        self.__doc__ = flexmock.__doc__

    def __dir__(self):
        return dir(self._realmod)

    def __call__(self, *args, **kw):
        # ``flexmock(...)`` on the module delegates to the real flexmock().
        return self._realmod.flexmock(*args, **kw)

    def __getattr__(self, attr):
        # Only reached for names not found on the proxy itself.
        return getattr(self._realmod, attr)

_CallableModule()
1495