1# coding: utf-8
2"""
3Module ``mock``
4---------------
5
6Wrapper to unittest.mock reducing the boilerplate when testing asyncio powered
7code.
8
9A mock can behave as a coroutine, as specified in the documentation of
10:class:`~asynctest.mock.Mock`.
11"""
12
13import asyncio
14import asyncio.coroutines
15import contextlib
16import enum
17import functools
18import inspect
19import sys
20import types
21import unittest.mock
22
23
24# From python 3.6, a sentinel object is used to mark coroutines (rather than
25# a boolean) to prevent a mock/proxy object to return a truthy value.
26# see: https://github.com/python/asyncio/commit/ea776a11f632a975ad3ebbb07d8981804aa292db
27try:
28    _is_coroutine = asyncio.coroutines._is_coroutine
29except AttributeError:
30    _is_coroutine = True
31
32
33class _AsyncIterator:
34    """
35    Wraps an iterator in an asynchronous iterator.
36    """
37    def __init__(self, iterator):
38        self.iterator = iterator
39
40    def __aiter__(self):
41        return self
42
43    async def __anext__(self):
44        try:
45            return next(self.iterator)
46        except StopIteration:
47            pass
48        raise StopAsyncIteration
49
50
51# magic methods which must be coroutine functions
52async_magic_coroutines = ("__aenter__", "__aexit__", "__anext__")
53# all magic methods used in an async context
54_async_magics = async_magic_coroutines + ("__aiter__", )
55
56# We use unittest.mock.MagicProxy which works well, but it's not aware that
57# we want __aexit__ to return a falsy value by default.
58# We add the entry in unittest internal dict as it will not change the
59# normal behavior of unittest.
60unittest.mock._return_values["__aexit__"] = False
61
62
def _get_async_iter(mock):
    """
    Build the ``__aiter__`` magic method of a MagicMock.

    The returned function wraps the value configured as the return value of
    ``mock.__aiter__`` in an asynchronous iterator.  When no return value has
    been configured, the resulting iteration is empty.

    ``__aiter__`` could be a coroutine in Python 3.5 and 3.6, hence a
    coroutine flavor is returned when ``mock.__aiter__`` is detected as a
    coroutine function.

    See: https://www.python.org/dev/peps/pep-0525/#id23
    """
    def __aiter__():
        configured = mock.__aiter__._mock_return_value
        values = [] if configured is DEFAULT else configured
        return _AsyncIterator(iter(values))

    wants_coroutine = asyncio.iscoroutinefunction(mock.__aiter__)
    return asyncio.coroutine(__aiter__) if wants_coroutine else __aiter__
88
89
# Register _get_async_iter as the factory for "__aiter__" in unittest.mock's
# internal table of side-effect magic methods.
unittest.mock._side_effect_methods["__aiter__"] = _get_async_iter


# Rebind both names to sets: the union below and the membership tests and set
# operations done elsewhere in this module rely on set semantics.
async_magic_coroutines = set(async_magic_coroutines)
_async_magics = set(_async_magics)

# This changes the behavior of unittest, but the change is minor and is
# probably better than overriding __set/get/del attr__ everywhere.
unittest.mock._all_magics |= _async_magics
99
def _raise(exception):
    """
    Raise ``exception``.

    Function form of the ``raise`` statement, so that raising can be passed
    around as a callable.
    """
    raise exception
102
103
104def _make_native_coroutine(coroutine):
105    """
106    Wrap a coroutine (or any function returning an awaitable) in a native
107    coroutine.
108    """
109    if inspect.iscoroutinefunction(coroutine):
110        # Nothing to do.
111        return coroutine
112
113    @functools.wraps(coroutine)
114    async def wrapper(*args, **kwargs):
115        return await coroutine(*args, **kwargs)
116
117    return wrapper
118
119
def _is_started(patching):
    """
    Tell whether ``patching`` has been started.

    ``_patch_dict`` objects carry their own ``_is_started`` flag; any other
    object is handed over to :func:`unittest.mock._is_started`.
    """
    if not isinstance(patching, _patch_dict):
        return unittest.mock._is_started(patching)

    return patching._is_started
125
126
class FakeInheritanceMeta(type):
    """
    A metaclass which recreates the original inheritance model from
    unittest.mock, as far as ``isinstance()`` is concerned:

    - NonCallableMagicMock and Mock objects are instances of classes derived
      from NonCallableMock,
    - MagicMock objects are instances of classes derived from Mock (but not
      of classes derived from CoroutineMock).
    """
    def __init__(self, name, bases, attrs):
        # NOTE(review): ``attrs`` is updated here, in the metaclass
        # ``__init__``; whether this assignment still affects the class which
        # has already been built should be confirmed.
        attrs['__new__'] = types.MethodType(self.__new, self)
        super().__init__(name, bases, attrs)

    @staticmethod
    def __new(cls, *args, **kwargs):
        # Build a dedicated subclass of ``cls`` (same name and docstring) and
        # allocate the new object as an instance of that subclass.
        new = type(cls.__name__, (cls, ), {'__doc__': cls.__doc__})
        return object.__new__(new, *args, **kwargs)

    def __instancecheck__(cls, obj):
        """
        Return True if ``obj`` is an instance of ``cls``, either through the
        regular check or through the emulated inheritance relations.
        """
        # That's tricky, each type(mock) is actually a subclass of the actual
        # Mock type (see __new__)
        if super().__instancecheck__(obj):
            return True

        _type = type(obj)
        # Classes derived from NonCallableMock accept NonCallableMagicMock
        # and Mock objects.
        if issubclass(cls, NonCallableMock):
            if issubclass(_type, (NonCallableMagicMock, Mock, )):
                return True

        # Classes derived from Mock accept MagicMock objects, unless the
        # class is derived from CoroutineMock.
        if issubclass(cls, Mock) and not issubclass(cls, CoroutineMock):
            if issubclass(_type, (MagicMock, )):
                return True

        return False
161
162
163def _get_is_coroutine(self):
164    return self.__dict__['_mock_is_coroutine']
165
166
167def _set_is_coroutine(self, value):
168    # property setters and getters are overridden by Mock(), we need to
169    # update the dict to add values
170    value = _is_coroutine if bool(value) else False
171    self.__dict__['_mock_is_coroutine'] = value
172
173
174# _mock_add_spec() is the actual private implementation in unittest.mock, we
175# override it to support coroutines in the metaclass.
176def _mock_add_spec(self, spec, *args, **kwargs):
177    unittest.mock.NonCallableMock._mock_add_spec(self, spec, *args, **kwargs)
178
179    _spec_coroutines = []
180    for attr in dir(spec):
181        if asyncio.iscoroutinefunction(getattr(spec, attr)):
182            _spec_coroutines.append(attr)
183
184    self.__dict__['_spec_coroutines'] = _spec_coroutines
185
186
def _get_child_mock(self, *args, **kwargs):
    """
    Create the child mock used for an attribute or a return value.

    A :class:`CoroutineMock` is created when the child is named after
    a coroutine function of the spec, or when it is an async magic method
    which must be a coroutine function on a MagicMock.  Children of
    a CoroutineMock are MagicMock objects.  Otherwise, the class is selected
    like :mod:`unittest.mock` does: non-callable mocks get callable children
    and callable mocks get children of their own base class.
    """
    child_name = kwargs.get("_new_name")
    if child_name in self.__dict__['_spec_coroutines']:
        return CoroutineMock(*args, **kwargs)

    own_type = type(self)

    if issubclass(own_type, MagicMock) and child_name in async_magic_coroutines:
        factory = CoroutineMock
    elif issubclass(own_type, CoroutineMock):
        factory = MagicMock
    elif issubclass(own_type, unittest.mock.CallableMixin):
        # type(self) is the per-instance subclass: use the class it derives
        # from.
        factory = own_type.__mro__[1]
    elif issubclass(own_type, unittest.mock.NonCallableMagicMock):
        factory = MagicMock
    elif issubclass(own_type, NonCallableMock):
        factory = Mock

    return factory(*args, **kwargs)
207
208
class MockMetaMixin(FakeInheritanceMeta):
    """
    Metaclass installing the asynctest overrides (``_mock_add_spec``,
    ``_get_child_mock`` and a fake ``__code__``) in the namespace of classes
    which don't inherit from a class already built by this metaclass.
    """
    def __new__(meta, name, base, namespace):
        already_set_up = any(isinstance(parent, meta) for parent in base)
        if not already_set_up:
            # A fake code object without flags ensures that
            # inspect.iscoroutinefunction() doesn't return True when testing
            # a mock.
            fake_code = unittest.mock.NonCallableMock(spec_set=types.CodeType)
            fake_code.co_flags = 0

            namespace['_mock_add_spec'] = _mock_add_spec
            namespace['_get_child_mock'] = _get_child_mock
            namespace['__code__'] = fake_code

        return super().__new__(meta, name, base, namespace)
224
225
class IsCoroutineArgMeta(MockMetaMixin):
    """
    Metaclass adding the ``is_coroutine`` property (and the read-only
    ``_is_coroutine`` one) to classes which don't inherit from a class
    already built by this metaclass.
    """
    def __new__(meta, name, base, namespace):
        if any(isinstance(parent, meta) for parent in base):
            # A base class already provides everything.
            return super().__new__(meta, name, base, namespace)

        namespace['_asynctest_get_is_coroutine'] = _get_is_coroutine
        namespace['_asynctest_set_is_coroutine'] = _set_is_coroutine
        namespace['is_coroutine'] = property(
            _get_is_coroutine, _set_is_coroutine,
            doc="True if the object mocked is a coroutine")
        namespace['_is_coroutine'] = property(_get_is_coroutine)

        inherited_setattr = namespace.get("__setattr__", base[0].__setattr__)

        def __setattr__(self, attrname, value):
            # Only ``is_coroutine`` is intercepted, any other attribute goes
            # through the __setattr__ which was in place before.
            if attrname != 'is_coroutine':
                return inherited_setattr(self, attrname, value)

            self._asynctest_set_is_coroutine(value)

        namespace['__setattr__'] = __setattr__

        return super().__new__(meta, name, base, namespace)
247
248
class AsyncMagicMixin:
    """
    Add support for async magic methods to :class:`MagicMock` and
    :class:`NonCallableMagicMock`.

    Actually, it's a shameless copy-paste of :class:`unittest.mock.MagicMixin`:
        when added to our classes, it will just do exactly what its
        :mod:`unittest` counterpart does, but for magic methods. It adds some
        behavior but should be compatible with future additions of
        :class:`MagicMock`.
    """
    # Magic methods are invoked as type(obj).__magic__(obj), as seen in
    # PEP-343 (with) and PEP-492 (async with)
    def __init__(self, *args, **kwargs):
        self._mock_set_async_magics()  # make magic work for kwargs in init
        unittest.mock._safe_super(AsyncMagicMixin, self).__init__(*args, **kwargs)
        self._mock_set_async_magics()  # fix magic broken by upper level init

    def _mock_set_async_magics(self):
        """
        Install a MagicProxy on ``type(self)`` for each supported async magic
        method which is not defined on that type yet.

        When the mock is restricted by ``_mock_methods``, only the async
        magic methods listed there are installed, the other ones being
        removed if they are defined on ``type(self)``.
        """
        these_magics = _async_magics

        if getattr(self, "_mock_methods", None) is not None:
            these_magics = _async_magics.intersection(self._mock_methods)
            remove_magics = _async_magics - these_magics

            for entry in remove_magics:
                if entry in type(self).__dict__:
                    # remove unneeded magic methods
                    delattr(self, entry)

        # don't overwrite existing attributes if called a second time
        these_magics = these_magics - set(type(self).__dict__)

        _type = type(self)
        for entry in these_magics:
            setattr(_type, entry, unittest.mock.MagicProxy(entry, self))

    def mock_add_spec(self, *args, **kwargs):
        """
        Add a spec to the mock like :class:`unittest.mock.MagicMock` does,
        then update the async magic methods accordingly.
        """
        unittest.mock.MagicMock.mock_add_spec(self, *args, **kwargs)
        self._mock_set_async_magics()

    def __setattr__(self, name, value):
        """
        Set an attribute, with a specific handling of async magic methods:
        they are also set on ``type(self)``, since magic methods are looked
        up on the type.

        Any other attribute (or an async magic method excluded by
        ``_mock_methods``) is set by the next ``__setattr__`` in the MRO.
        """
        _mock_methods = getattr(self, '_mock_methods', None)
        if _mock_methods is None or name in _mock_methods:
            if name in _async_magics:
                if not unittest.mock._is_instance_mock(value):
                    # value is not a mock: install it on the type...
                    setattr(type(self), name,
                            unittest.mock._get_method(name, value))
                    original = value

                    # ... and store on the instance a function calling it
                    # with this mock as first argument.
                    def value(*args, **kwargs):
                        return original(self, *args, **kwargs)
                else:
                    # value is a mock: attach it as a child of this mock
                    # and expose it on the type.
                    unittest.mock._check_and_set_parent(self, value, None, name)
                    setattr(type(self), name, value)
                    self._mock_children[name] = value

                return object.__setattr__(self, name, value)

        unittest.mock._safe_super(AsyncMagicMixin, self).__setattr__(name, value)
309
310
311# Notes about unittest.mock:
312#  - MagicMock > Mock > NonCallableMock (where ">" means inherits from)
313#  - when a mock instance is created, a new class (type) is created
314#    dynamically,
315#  - we *must* use magic or object's internals when we want to add our own
316#    properties, and often override __getattr__/__setattr__ which are used
317#    in unittest.mock.NonCallableMock.
class NonCallableMock(unittest.mock.NonCallableMock,
                      metaclass=IsCoroutineArgMeta):
    """
    A :class:`unittest.mock.NonCallableMock` which can stand for a coroutine
    function.

    When ``is_coroutine`` is ``True``, the :class:`NonCallableMock` object
    behaves so that :func:`asyncio.iscoroutinefunction` returns ``True`` when
    called with the mock as parameter.

    When ``spec`` or ``spec_set`` is defined, getting an attribute whose
    matching spec attribute is a coroutine function returns
    a :class:`~asynctest.CoroutineMock` instead of
    a :class:`~asynctest.Mock`.

    The test author can also specify a wrapped object with ``wraps``. In this
    case, the :class:`~asynctest.Mock` object behavior is the same as with an
    :class:`unittest.mock.Mock` object: the wrapped object may have methods
    defined as coroutine functions.

    See :class:`unittest.mock.NonCallableMock`
    """
    def __init__(self, spec=None, wraps=None, name=None, spec_set=None,
                 is_coroutine=None, parent=None, **kwargs):
        # is_coroutine is specific to asynctest: everything else is handed
        # over to unittest.mock.
        kwargs.update(spec=spec, wraps=wraps, name=name, spec_set=spec_set,
                      parent=parent)
        super().__init__(**kwargs)

        self._asynctest_set_is_coroutine(is_coroutine)
346
347
class NonCallableMagicMock(AsyncMagicMixin, unittest.mock.NonCallableMagicMock,
                           metaclass=IsCoroutineArgMeta):
    """
    The non-callable flavor of :class:`~asynctest.MagicMock`.
    """
    def __init__(self, spec=None, wraps=None, name=None, spec_set=None,
                 is_coroutine=None, parent=None, **kwargs):
        # is_coroutine is specific to asynctest: everything else is handed
        # over to the parent classes.
        kwargs.update(spec=spec, wraps=wraps, name=name, spec_set=spec_set,
                      parent=parent)
        super().__init__(**kwargs)

        self._asynctest_set_is_coroutine(is_coroutine)
360
361
class Mock(unittest.mock.Mock, metaclass=MockMetaMixin):
    """
    Enhance :class:`unittest.mock.Mock` so it returns
    a :class:`~asynctest.CoroutineMock` object instead of
    a :class:`~asynctest.Mock` object where a method on a ``spec`` or
    ``spec_set`` object is a coroutine function.

    For instance:

    >>> class Foo:
    ...     @asyncio.coroutine
    ...     def foo(self):
    ...         pass
    ...
    ...     def bar(self):
    ...         pass

    >>> type(asynctest.mock.Mock(Foo()).foo)
    <class 'asynctest.mock.CoroutineMock'>

    >>> type(asynctest.mock.Mock(Foo()).bar)
    <class 'asynctest.mock.Mock'>

    The test author can also specify a wrapped object with ``wraps``. In this
    case, the :class:`~asynctest.Mock` object behavior is the same as with an
    :class:`unittest.mock.Mock` object: the wrapped object may have methods
    defined as coroutine functions.

    If you want to mock a coroutine function, use :class:`CoroutineMock`
    instead.

    See :class:`~asynctest.NonCallableMock` for details about :mod:`asynctest`
    features, and :mod:`unittest.mock` for the comprehensive documentation
    about mocking.
    """
397
398
class MagicMock(AsyncMagicMixin, unittest.mock.MagicMock,
                metaclass=MockMetaMixin):
    """
    Enhance :class:`unittest.mock.MagicMock` so it returns
    a :class:`~asynctest.CoroutineMock` object instead of
    a :class:`~asynctest.Mock` object where a method on a ``spec`` or
    ``spec_set`` object is a coroutine function.

    If you want to mock a coroutine function, use :class:`CoroutineMock`
    instead.

    :class:`MagicMock` allows mocking ``__aenter__``, ``__aexit__``,
    ``__aiter__`` and ``__anext__``.

    When mocking an asynchronous iterator, you can set the
    ``return_value`` of ``__aiter__`` to an iterable to define the list of
    values to be returned during iteration.

    You cannot mock ``__await__``. If you want to mock an object implementing
    ``__await__``, :class:`CoroutineMock` will likely be sufficient.

    See :class:`~asynctest.Mock`.

    .. versionadded:: 0.11

        support of asynchronous iterators and asynchronous context managers.
    """
426
427
428class _AwaitEvent:
429    def __init__(self, mock):
430        self._mock = mock
431        self._condition = None
432
433    @asyncio.coroutine
434    def wait(self, skip=0):
435        """
436        Wait for await.
437
438        :param skip: How many awaits will be skipped.
439                     As a result, the mock should be awaited at least
440                     ``skip + 1`` times.
441        """
442        def predicate(mock):
443            return mock.await_count > skip
444
445        return (yield from self.wait_for(predicate))
446
447    @asyncio.coroutine
448    def wait_next(self, skip=0):
449        """
450        Wait for the next await.
451
452        Unlike :meth:`wait` that counts any await, mock has to be awaited once
453        more, disregarding to the current
454        :attr:`asynctest.CoroutineMock.await_count`.
455
456        :param skip: How many awaits will be skipped.
457                     As a result, the mock should be awaited at least
458                     ``skip + 1`` more times.
459        """
460        await_count = self._mock.await_count
461
462        def predicate(mock):
463            return mock.await_count > await_count + skip
464
465        return (yield from self.wait_for(predicate))
466
467    @asyncio.coroutine
468    def wait_for(self, predicate):
469        """
470        Wait for a given predicate to become True.
471
472        :param predicate: A callable that receives mock which result
473                          will be interpreted as a boolean value.
474                          The final predicate value is the return value.
475        """
476        condition = self._get_condition()
477
478        try:
479            yield from condition.acquire()
480
481            def _predicate():
482                return predicate(self._mock)
483
484            return (yield from condition.wait_for(_predicate))
485        finally:
486            condition.release()
487
488    @asyncio.coroutine
489    def _notify(self):
490        condition = self._get_condition()
491
492        try:
493            yield from condition.acquire()
494            condition.notify_all()
495        finally:
496            condition.release()
497
498    def _get_condition(self):
499        """
500        Creation of condition is delayed, to minimize the change of using the
501        wrong loop.
502
503        A user may create a mock with _AwaitEvent before selecting the
504        execution loop.  Requiring a user to delay creation is error-prone and
505        inflexible. Instead, condition is created when user actually starts to
506        use the mock.
507        """
508        # No synchronization is needed:
509        #   - asyncio is thread unsafe
510        #   - there are no awaits here, method will be executed without
511        #   switching asyncio context.
512        if self._condition is None:
513            self._condition = asyncio.Condition()
514
515        return self._condition
516
517    def __bool__(self):
518        return self._mock.await_count != 0
519
520
class CoroutineMock(Mock):
    """
    Enhance :class:`~asynctest.mock.Mock` with features allowing to mock
    a coroutine function.

    The :class:`~asynctest.CoroutineMock` object will behave so the object is
    recognized as coroutine function, and the result of a call as a coroutine:

    >>> mock = CoroutineMock()
    >>> asyncio.iscoroutinefunction(mock)
    True
    >>> asyncio.iscoroutine(mock())
    True


    The result of ``mock()`` is a coroutine which will have the outcome of
    ``side_effect`` or ``return_value``:

    - if ``side_effect`` is a function, the coroutine will return the result
      of that function,
    - if ``side_effect`` is an exception, the coroutine will raise the
      exception,
    - if ``side_effect`` is an iterable, the coroutine will return the next
      value of the iterable, however, if the sequence of result is exhausted,
      ``StopIteration`` is raised immediately,
    - if ``side_effect`` is not defined, the coroutine will return the value
      defined by ``return_value``, hence, by default, the coroutine returns
      a new :class:`~asynctest.CoroutineMock` object.

    If the outcome of ``side_effect`` or ``return_value`` is a coroutine, the
    mock coroutine obtained when the mock object is called will be this
    coroutine itself (and not a coroutine returning a coroutine).

    The test author can also specify a wrapped object with ``wraps``. In this
    case, the :class:`~asynctest.Mock` object behavior is the same as with an
    :class:`unittest.mock.Mock` object: the wrapped object may have methods
    defined as coroutine functions.
    """
    #: Property which is set when the mock is awaited. Its ``wait`` and
    #: ``wait_next`` coroutine methods can be used to synchronize execution.
    #:
    #: .. versionadded:: 0.12
    awaited = unittest.mock._delegating_property('awaited')
    #: Number of times the mock has been awaited.
    #:
    #: .. versionadded:: 0.12
    await_count = unittest.mock._delegating_property('await_count')
    #: Call object of the last await, or ``None`` if never awaited.
    await_args = unittest.mock._delegating_property('await_args')
    #: List of the call objects of all the awaits, in order.
    await_args_list = unittest.mock._delegating_property('await_args_list')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # asyncio.iscoroutinefunction() checks this property to say if an
        # object is a coroutine
        # It is set through __dict__ because when spec_set is True, this
        # attribute is likely undefined.
        self.__dict__['_is_coroutine'] = _is_coroutine
        self.__dict__['_mock_awaited'] = _AwaitEvent(self)
        self.__dict__['_mock_await_count'] = 0
        self.__dict__['_mock_await_args'] = None
        self.__dict__['_mock_await_args_list'] = unittest.mock._CallList()

    def _mock_call(_mock_self, *args, **kwargs):
        """
        Record the call like a regular mock and return a coroutine which,
        once awaited, produces the outcome of the call and records the await.
        """
        try:
            result = super()._mock_call(*args, **kwargs)
        except StopIteration as e:
            # An exhausted iterable side_effect raises immediately; any other
            # StopIteration is raised when the coroutine is awaited.
            side_effect = _mock_self.side_effect
            if side_effect is not None and not callable(side_effect):
                raise

            result = asyncio.coroutine(_raise)(e)
        except BaseException as e:
            # Defer the exception: it is raised when the coroutine is awaited.
            result = asyncio.coroutine(_raise)(e)

        _call = _mock_self.call_args

        @asyncio.coroutine
        def proxy():
            try:
                if inspect.isawaitable(result):
                    return (yield from result)
                else:
                    return result
            finally:
                # The await is recorded whatever its outcome is.
                _mock_self.await_count += 1
                _mock_self.await_args = _call
                _mock_self.await_args_list.append(_call)
                yield from _mock_self.awaited._notify()

        return proxy()

    def assert_awaited(_mock_self):
        """
        Assert that the mock was awaited at least once.

        .. versionadded:: 0.12
        """
        self = _mock_self
        if self.await_count == 0:
            # Parentheses are required: "%" binds tighter than "or", and the
            # fallback name would never be used without them.
            msg = ("Expected '%s' to have been awaited." %
                   (self._mock_name or 'mock'))
            raise AssertionError(msg)

    def assert_awaited_once(_mock_self, *args, **kwargs):
        """
        Assert that the mock was awaited exactly once.

        Positional and keyword arguments are accepted but ignored.

        .. versionadded:: 0.12
        """
        self = _mock_self
        if not self.await_count == 1:
            msg = ("Expected '%s' to have been awaited once. Awaited %s times." %
                   (self._mock_name or 'mock', self.await_count))
            raise AssertionError(msg)

    def assert_awaited_with(_mock_self, *args, **kwargs):
        """
        Assert that the last await was with the specified arguments.

        .. versionadded:: 0.12
        """
        self = _mock_self
        if self.await_args is None:
            expected = self._format_mock_call_signature(args, kwargs)
            raise AssertionError('Expected await: %s\nNot awaited' % (expected,))

        def _error_message():
            msg = self._format_mock_failure_message(args, kwargs)
            return msg

        expected = self._call_matcher((args, kwargs))
        actual = self._call_matcher(self.await_args)
        if expected != actual:
            cause = expected if isinstance(expected, Exception) else None
            raise AssertionError(_error_message()) from cause

    def assert_awaited_once_with(_mock_self, *args, **kwargs):
        """
        Assert that the mock was awaited exactly once and with the specified arguments.

        .. versionadded:: 0.12
        """
        self = _mock_self
        if not self.await_count == 1:
            msg = ("Expected '%s' to be awaited once. Awaited %s times." %
                   (self._mock_name or 'mock', self.await_count))
            raise AssertionError(msg)
        return self.assert_awaited_with(*args, **kwargs)

    def assert_any_await(_mock_self, *args, **kwargs):
        """
        Assert the mock has ever been awaited with the specified arguments.

        .. versionadded:: 0.12
        """
        self = _mock_self
        expected = self._call_matcher((args, kwargs))
        actual = [self._call_matcher(c) for c in self.await_args_list]
        if expected not in actual:
            cause = expected if isinstance(expected, Exception) else None
            expected_string = self._format_mock_call_signature(args, kwargs)
            raise AssertionError(
                '%s await not found' % expected_string
            ) from cause

    def assert_has_awaits(_mock_self, calls, any_order=False):
        """
        Assert the mock has been awaited with the specified calls.
        The :attr:`await_args_list` list is checked for the awaits.

        If `any_order` is False (the default) then the awaits must be
        sequential. There can be extra calls before or after the
        specified awaits.

        If `any_order` is True then the awaits can be in any order, but
        they must all appear in :attr:`await_args_list`.

        .. versionadded:: 0.12
        """
        self = _mock_self
        expected = [self._call_matcher(c) for c in calls]
        # ``expected`` is a list: the cause is the first matcher result which
        # is an exception, if any.
        cause = next((e for e in expected if isinstance(e, Exception)), None)
        all_awaits = unittest.mock._CallList(self._call_matcher(c) for c in self.await_args_list)
        if not any_order:
            if expected not in all_awaits:
                raise AssertionError(
                    'Awaits not found.\nExpected: %r\n'
                    'Actual: %r' % (unittest.mock._CallList(calls), self.await_args_list)
                ) from cause
            return

        all_awaits = list(all_awaits)

        not_found = []
        for kall in expected:
            try:
                all_awaits.remove(kall)
            except ValueError:
                not_found.append(kall)
        if not_found:
            raise AssertionError(
                '%r not all found in await list' % (tuple(not_found),)
            ) from cause

    def assert_not_awaited(_mock_self):
        """
        Assert that the mock was never awaited.

        .. versionadded:: 0.12
        """
        self = _mock_self
        if self.await_count != 0:
            msg = ("Expected '%s' to not have been awaited. Awaited %s times." %
                   (self._mock_name or 'mock', self.await_count))
            raise AssertionError(msg)

    def reset_mock(self, *args, **kwargs):
        """
        See :func:`unittest.mock.Mock.reset_mock()`

        The await-related attributes are reset as well.
        """
        super().reset_mock(*args, **kwargs)
        self.awaited = _AwaitEvent(self)
        self.await_count = 0
        self.await_args = None
        self.await_args_list = unittest.mock._CallList()
747
748
def create_autospec(spec, spec_set=False, instance=False, _parent=None,
                    _name=None, **kwargs):
    """
    Create a mock object using another object as a spec. Attributes on the mock
    will use the corresponding attribute on the spec object as their spec.

    ``spec`` can be a coroutine function, a class or object with coroutine
    functions as attributes.

    If ``spec`` is a coroutine function, and ``instance`` is not ``False``, a
    :exc:`RuntimeError` is raised.

    Extra keyword arguments are passed to the constructor of the mock.

    .. versionadded:: 0.12
    """
    # A list instance is specced with its type.
    if unittest.mock._is_list(spec):
        spec = type(spec)

    is_type = isinstance(spec, type)
    is_coroutine_func = asyncio.iscoroutinefunction(spec)

    # Arguments of the mock constructor: spec or spec_set, then the
    # arguments provided by the caller.
    _kwargs = {'spec': spec}
    if spec_set:
        _kwargs = {'spec_set': spec}
    elif spec is None:
        # None is mocked with a normal mock, without a spec
        _kwargs = {}
    if _kwargs and instance:
        _kwargs['_spec_as_instance'] = True

    _kwargs.update(kwargs)

    # Select the class of the mock according to the nature of the spec.
    Klass = MagicMock
    if inspect.isdatadescriptor(spec):
        _kwargs = {}
    elif is_coroutine_func:
        if instance:
            raise RuntimeError("Instance can not be True when create_autospec "
                               "is mocking a coroutine function")
        Klass = CoroutineMock
    elif not unittest.mock._callable(spec):
        Klass = NonCallableMagicMock
    elif is_type and instance and not unittest.mock._instance_callable(spec):
        Klass = NonCallableMagicMock

    # A "name" keyword argument takes precedence over _name.
    _name = _kwargs.pop('name', _name)

    _new_name = _name
    if _parent is None:
        _new_name = ''

    mock = Klass(parent=_parent, _new_parent=_parent, _new_name=_new_name,
                 name=_name, **_kwargs)

    if isinstance(spec, unittest.mock.FunctionTypes):
        wrapped_mock = mock
        # _set_signature returns an object wrapping the mock, not the mock
        # itself.
        mock = unittest.mock._set_signature(mock, spec)
        if is_coroutine_func:
            # Can't wrap the mock with asyncio.coroutine because it doesn't
            # detect a CoroWrapper as an awaitable in debug mode.
            # It is safe to do so because the mock object wrapped by
            # _set_signature returns the result of the CoroutineMock itself,
            # which is a Coroutine (as defined in CoroutineMock._mock_call)
            mock._is_coroutine = _is_coroutine
            mock.awaited = _AwaitEvent(mock)
            mock.await_count = 0
            mock.await_args = None
            mock.await_args_list = unittest.mock._CallList()

            # Expose the await assertions of the wrapped CoroutineMock on
            # the wrapper.
            for a in ('assert_awaited',
                      'assert_awaited_once',
                      'assert_awaited_with',
                      'assert_awaited_once_with',
                      'assert_any_await',
                      'assert_has_awaits',
                      'assert_not_awaited'):
                setattr(mock, a, getattr(wrapped_mock, a))
    else:
        unittest.mock._check_signature(spec, mock, is_type, instance)

    if _parent is not None and not instance:
        _parent._mock_children[_name] = mock

    # Calling a mocked class returns a mock specced as an instance of the
    # class, unless the caller provided a return value.
    if is_type and not instance and 'return_value' not in kwargs:
        mock.return_value = create_autospec(spec, spec_set, instance=True,
                                            _name='()', _parent=mock)

    # Spec the (non magic) attributes of the mock after the attributes of
    # the spec.
    for entry in dir(spec):
        if unittest.mock._is_magic(entry):
            continue
        try:
            original = getattr(spec, entry)
        except AttributeError:
            # listed by dir() but not retrievable: ignored
            continue

        # From here, kwargs holds the arguments of the child mock and no
        # longer the arguments received by create_autospec().
        kwargs = {'spec': original}
        if spec_set:
            kwargs = {'spec_set': original}

        if not isinstance(original, unittest.mock.FunctionTypes):
            # Not a function: only the state needed to spec the child is
            # stored.
            new = unittest.mock._SpecState(original, spec_set, mock, entry,
                                           instance)
            mock._mock_children[entry] = new
        else:
            parent = mock
            if isinstance(spec, unittest.mock.FunctionTypes):
                parent = mock.mock

            skipfirst = unittest.mock._must_skip(spec, entry, is_type)
            kwargs['_eat_self'] = skipfirst
            # Coroutine functions are mocked with a CoroutineMock.
            if asyncio.iscoroutinefunction(original):
                child_klass = CoroutineMock
            else:
                child_klass = MagicMock
            new = child_klass(parent=parent, name=entry, _new_name=entry,
                              _new_parent=parent, **kwargs)
            mock._mock_children[entry] = new
            unittest.mock._check_signature(original, new, skipfirst=skipfirst)

        if isinstance(new, unittest.mock.FunctionTypes):
            setattr(mock, entry, new)

    return mock
873
874
def mock_open(mock=None, read_data=''):
    """
    Build a mock suitable to stand in for :func:`open()`, whether it is
    called directly or used as a context manager.

    This delegates to :func:`unittest.mock.mock_open`, only changing the
    class of the mock created when none is provided.

    :param mock: mock object to configure. When omitted,
                 a :class:`~asynctest.MagicMock` object specced on
                 :func:`open()` is created.

    :param read_data: string that :func:`read()` and :func:`readlines()` of
                      the file handle will return. Defaults to an empty
                      string.
    """
    target = MagicMock(name='open', spec=open) if mock is None else mock
    return unittest.mock.mock_open(target, read_data)
893
894
# Aliases of the matching unittest.mock objects. DEFAULT is the sentinel
# compared (by identity) against the ``new`` argument of the patchers below.
ANY = unittest.mock.ANY
DEFAULT = unittest.mock.sentinel.DEFAULT
897
898
899def _update_new_callable(patcher, new, new_callable):
900    if new == DEFAULT and not new_callable:
901        original = patcher.get_original()[0]
902        if isinstance(original, (classmethod, staticmethod)):
903            # the original object is the raw descriptor, if it's a classmethod
904            # or a static method, we need to unwrap it
905            original = original.__get__(None, object)
906
907        if asyncio.iscoroutinefunction(original):
908            patcher.new_callable = CoroutineMock
909        else:
910            patcher.new_callable = MagicMock
911
912    return patcher
913
914
# Documented in doc/asynctest.mock.rst
# Values accepted by the ``scope`` argument of the patchers defined in this
# module; LIMITED and GLOBAL are module-level shortcuts to the enum members.
PatchScope = enum.Enum('PatchScope', 'LIMITED GLOBAL')
LIMITED = PatchScope.LIMITED
GLOBAL = PatchScope.GLOBAL
919
920
def _decorate_coroutine_callable(func, new_patching):
    """
    Decorate a generator function or a coroutine function with a patcher.

    :param func: function to decorate. If it already carries a ``patchings``
                 attribute, ``new_patching`` is appended to that list and
                 ``func`` is returned as is.

    :param new_patching: patcher object (providing ``__enter__``,
                         ``__exit__`` and a ``scope`` attribute) to activate
                         around ``func``.

    :return: the decorated function, or ``None`` if ``func`` is neither
             a generator function nor a coroutine function, in which case
             the caller is expected to fall back to the regular decoration.
    """
    if hasattr(func, 'patchings'):
        # func has already been decorated: only register the new patcher.
        func.patchings.append(new_patching)
        return func

    # Python 3.5 returns True for is_generator_func(new_style_coroutine) if
    # there is an "await" statement in the function body, which is wrong. It is
    # fixed in 3.6, but I can't find which commit fixes this.
    # The only way to work correctly with 3.5 and 3.6 seems to use
    # inspect.iscoroutinefunction()
    is_generator_func = inspect.isgeneratorfunction(func)
    is_coroutine_func = asyncio.iscoroutinefunction(func)
    try:
        is_native_coroutine_func = inspect.iscoroutinefunction(func)
    except AttributeError:
        # inspect.iscoroutinefunction() is not available
        is_native_coroutine_func = False

    if not (is_generator_func or is_coroutine_func):
        return None

    # This list is shared with the decorated function (as its "patchings"
    # attribute), so patchers appended later are seen by patched_factory().
    patchings = [new_patching]

    def patched_factory(*args, **kwargs):
        # Enters all the patchers, calls func() and wraps the generator or
        # coroutine object it returns in a _PatchedGenerator.
        extra_args = []
        patchers_to_exit = []
        patch_dict_with_limited_scope = []

        exc_info = tuple()
        try:
            for patching in patchings:
                arg = patching.__enter__()
                if patching.scope == LIMITED:
                    # LIMITED patchers are exited in the finally clause
                    # below, GLOBAL ones are left active on success.
                    patchers_to_exit.append(patching)
                if isinstance(patching, _patch_dict):
                    if patching.scope == GLOBAL:
                        # Let the LIMITED patch.dict objects entered so far
                        # on the same dict know about this GLOBAL one.
                        for limited_patching in patch_dict_with_limited_scope:
                            if limited_patching.in_dict is patching.in_dict:
                                limited_patching._keep_global_patch(patching)
                    else:
                        patch_dict_with_limited_scope.append(patching)
                else:
                    if patching.attribute_name is not None:
                        # the patcher returned a dict of mocks, they are
                        # passed to func() as keyword arguments
                        kwargs.update(arg)
                        if patching.new is DEFAULT:
                            patching.new = arg[patching.attribute_name]
                    elif patching.new is DEFAULT:
                        # the mock created by the patcher is kept so it is
                        # reused when the patcher is entered again, and is
                        # passed to func() as an extra positional argument
                        patching.mock_to_reuse = arg
                        extra_args.append(arg)

            args += tuple(extra_args)
            gen = func(*args, **kwargs)
            return _PatchedGenerator(gen, patchings,
                                     asyncio.iscoroutinefunction(func))
        except BaseException:
            if patching not in patchers_to_exit and _is_started(patching):
                # the patcher may have been started, but an exception
                # raised whilst entering one of its additional_patchers
                patchers_to_exit.append(patching)
            # Pass the exception to __exit__
            exc_info = sys.exc_info()
            # re-raise the exception
            raise
        finally:
            # exc_info is empty unless the except clause above ran
            for patching in reversed(patchers_to_exit):
                patching.__exit__(*exc_info)

    # wrap the factory in a native coroutine or a generator to respect
    # introspection.
    if is_native_coroutine_func:
        # inspect.iscoroutinefunction() returns True
        patched = _make_native_coroutine(patched_factory)
    elif is_generator_func:
        # inspect.isgeneratorfunction() returns True
        def patched_generator(*args, **kwargs):
            return (yield from patched_factory(*args, **kwargs))

        patched = patched_generator

        if is_coroutine_func:
            # asyncio.iscoroutinefunction() returns True
            patched = asyncio.coroutine(patched)
    else:
        patched = patched_factory

    # Expose the list so further decorations only append to it (see the
    # check at the top of this function).
    patched.patchings = patchings
    return functools.wraps(func)(patched)
1007
1008
class _PatchedGenerator(asyncio.coroutines.CoroWrapper):
    """
    Wrapper of the generator or coroutine object returned by a function
    decorated with one or more patchers.

    The patchers with the :const:`LIMITED` scope are entered each time the
    wrapped generator is resumed and exited when it pauses or terminates.
    The patchers with the :const:`GLOBAL` scope are expected to be started
    already; they are stopped once the wrapped generator terminates, is
    closed, or when this wrapper is deleted.
    """
    # Inheriting from asyncio.CoroWrapper gives us a comprehensive wrapper
    # implementing one or more workarounds for cpython bugs
    def __init__(self, gen, patchings, is_coroutine):
        """
        :param gen: generator or coroutine object to wrap.

        :param patchings: list of patchers, each with a ``scope`` attribute.

        :param is_coroutine: stored as ``_is_coroutine`` on the wrapper.
        """
        self.gen = gen
        self._is_coroutine = is_coroutine
        self.__name__ = getattr(gen, '__name__', None)
        self.__qualname__ = getattr(gen, '__qualname__', None)
        self.patchings = patchings
        self.global_patchings = [p for p in patchings if p.scope == GLOBAL]
        self.limited_patchings = [p for p in patchings if p.scope == LIMITED]

        # GLOBAL patches have been started in the _patch/patched() wrapper

    def _limited_patchings_stack(self):
        """
        Enter all the LIMITED patchers and return an
        :class:`~contextlib.ExitStack` which exits them when closed.
        """
        with contextlib.ExitStack() as stack:
            for patching in self.limited_patchings:
                stack.enter_context(patching)

            # Transfer the callbacks to a new stack: if entering a patcher
            # raised, those already entered are exited by the "with" block.
            return stack.pop_all()

    def _stop_global_patchings(self):
        """
        Stop the GLOBAL patchers which are still started, in reverse order.
        """
        for patching in reversed(self.global_patchings):
            if _is_started(patching):
                patching.stop()

    def __repr__(self):
        # The wrapped object is stored in self.gen, no "generator" attribute
        # is ever set on this wrapper.
        return repr(self.gen)

    def __next__(self):
        try:
            with self._limited_patchings_stack():
                return self.gen.send(None)
        except BaseException:
            # the generator/coroutine terminated, stop the patchings
            self._stop_global_patchings()
            raise

    def send(self, value):
        try:
            with self._limited_patchings_stack():
                return super().send(value)
        except BaseException:
            # Same as __next__(): an exception leaving the generator means
            # it terminated, the GLOBAL patchings must not outlive it.
            self._stop_global_patchings()
            raise

    def throw(self, exc, value=None, traceback=None):
        try:
            with self._limited_patchings_stack():
                return self.gen.throw(exc, value, traceback)
        except BaseException:
            # The generator did not handle the exception (or raised another
            # one): it terminated, stop the GLOBAL patchings.
            self._stop_global_patchings()
            raise

    def close(self):
        try:
            with self._limited_patchings_stack():
                return self.gen.close()
        finally:
            self._stop_global_patchings()

    def __del__(self):
        # The generator/coroutine is deleted before it terminated, we must
        # still stop the patchings
        self._stop_global_patchings()
1066
1067
class _patch(unittest.mock._patch):
    """
    Patcher extending :class:`unittest.mock._patch` with a ``scope`` and the
    ability to decorate generator and coroutine functions.
    """
    def __init__(self, *args, scope=GLOBAL, **kwargs):
        super().__init__(*args, **kwargs)
        self.scope = scope
        # Set when decorating a coroutine, so the same mock is installed
        # each time the patcher is entered again.
        self.mock_to_reuse = None

    def copy(self):
        """
        Return a new patcher with the same configuration (scope included),
        with copies of the additional patchers.
        """
        clone = _patch(
            self.getter, self.attribute, self.new, self.spec,
            self.create, self.spec_set,
            self.autospec, self.new_callable, self.kwargs,
            scope=self.scope)
        clone.attribute_name = self.attribute_name
        clone.additional_patchers = [
            extra.copy() for extra in self.additional_patchers
        ]
        return clone

    def __enter__(self):
        reused = self.mock_to_reuse
        if reused is None:
            # regular case: let the patcher build the replacement object
            return self._perform_patch()

        # When patching a coroutine, we reuse the same mock object
        self.target = self.getter()
        self.temp_original, self.is_local = self.get_original()
        setattr(self.target, self.attribute, reused)
        if self.attribute_name is not None:
            for extra in self.additional_patchers:
                extra.__enter__()
        return reused

    def _perform_patch(self):
        """
        Apply the patch, replacing the object installed by
        ``super().__enter__()`` with the result of our own
        :func:`create_autospec` when ``autospec`` is requested.
        """
        # The original must be fetched before the patch is applied.
        original, local = self.get_original()
        result = super().__enter__()

        if not self.autospec:
            # no need to override the default behavior
            return result

        spec = original if self.autospec is True else self.autospec
        new = create_autospec(spec, spec_set=bool(self.spec_set),
                              _name=self.attribute, **self.kwargs)

        self.temp_original = original
        self.is_local = local
        setattr(self.target, self.attribute, new)

        if self.attribute_name is None:
            return new

        # result is the dict of replacement objects, keyed by attribute name
        if self.new is DEFAULT:
            result[self.attribute_name] = new
        return result

    def decorate_callable(self, func):
        wrapped = _decorate_coroutine_callable(func, self)
        # None means func is neither a generator nor a coroutine function
        return super().decorate_callable(func) if wrapped is None else wrapped
1134
1135
def patch(target, new=DEFAULT, spec=None, create=False, spec_set=None,
          autospec=None, new_callable=None, scope=GLOBAL, **kwargs):
    """
    A context manager, function decorator or class decorator which patches the
    target with the value given by the ``new`` argument.

    ``new`` specifies which object will replace the ``target`` when the patch
    is applied. By default, the target will be patched with an instance of
    :class:`~asynctest.CoroutineMock` if it is a coroutine, or
    a :class:`~asynctest.MagicMock` object.

    It is a replacement to :func:`unittest.mock.patch`, but using
    :mod:`asynctest.mock` objects.

    When a generator or a coroutine is patched using the decorator, the patch
    is activated or deactivated according to the ``scope`` argument value:

      * :const:`asynctest.GLOBAL`: the default, enables the patch until the
        generator or the coroutine finishes (returns or raises an exception),

      * :const:`asynctest.LIMITED`: the patch will be activated when the
        generator or coroutine is being executed, and deactivated when it
        yields a value and pauses its execution (with ``yield``, ``yield from``
        or ``await``).

    The behavior differs from :func:`unittest.mock.patch` for generators.

    When used as a context manager, the patch is still active even if the
    generator or coroutine is paused, which may affect concurrent tasks::

        @asyncio.coroutine
        def coro():
            with asynctest.mock.patch("module.function"):
                yield from asyncio.sleep(1)

        @asyncio.coroutine
        def independent_coro():
            assert not isinstance(module.function, asynctest.mock.Mock)

        asyncio.create_task(coro())
        asyncio.create_task(independent_coro())
        # this will raise an AssertionError (coro() is scheduled first)!
        loop.run_forever()

    :param target: string locating the object to patch, as expected by
        :func:`unittest.mock.patch`.

    :param scope: :const:`asynctest.GLOBAL` or :const:`asynctest.LIMITED`,
        controls when the patch is activated on generators and coroutines

    The other parameters (``new``, ``spec``, ``create``, ``spec_set``,
    ``autospec``, ``new_callable`` and the extra keyword arguments) are
    forwarded to the underlying patcher.

    When used as a decorator with a generator based coroutine, the order of
    the decorators matters. The order of the ``@patch()`` decorators is in
    the reverse order of the parameters produced by these patches for the
    patched function. And the ``@asyncio.coroutine`` decorator should be
    the last since ``@patch()`` conceptually patches the coroutine, not
    the function::

        @patch("module.function2")
        @patch("module.function1")
        @asyncio.coroutine
        def test_coro(self, mock_function1, mock_function2):
            yield from asyncio.sleep(1)

    see :func:`unittest.mock.patch()`.

    .. versionadded:: 0.6 patch into generators and coroutines with
                      a decorator.
    """
    getter, attribute = unittest.mock._get_target(target)
    patcher = _patch(getter, attribute, new, spec, create, spec_set, autospec,
                     new_callable, kwargs, scope=scope)

    # Pick CoroutineMock or MagicMock as the default replacement.
    return _update_new_callable(patcher, new, new_callable)
1206
1207
def _patch_object(target, attribute, new=DEFAULT, spec=None, create=False,
                  spec_set=None, autospec=None, new_callable=None,
                  scope=GLOBAL, **kwargs):
    """
    Build a patcher replacing the member ``attribute`` of the object
    ``target``; exposed as ``patch.object``.

    The arguments other than ``target`` and ``attribute`` have the same
    meaning as for :func:`patch`.
    """
    def getter():
        # the object is already known, there is nothing to import
        return target

    patcher = _patch(getter, attribute, new, spec, create, spec_set,
                     autospec, new_callable, kwargs, scope=scope)
    return _update_new_callable(patcher, new, new_callable)
1215
1216
def _patch_multiple(target, spec=None, create=False, spec_set=None,
                    autospec=None, new_callable=None, scope=GLOBAL, **kwargs):
    """
    Build a patcher replacing several attributes of ``target`` at once;
    exposed as ``patch.multiple``.

    :param target: object to patch, or a string which is imported when the
                   patch is applied.

    :param kwargs: maps each attribute name to its replacement object, at
                   least one entry is required.

    :raise ValueError: if no attribute to patch is given.

    :return: the patcher of the first attribute, the patchers of the other
             attributes are stored in its ``additional_patchers`` list.
    """
    if type(target) is str:
        def getter():
            return unittest.mock._importer(target)
    else:
        def getter():
            return target

    if not kwargs:
        raise ValueError('Must supply at least one keyword argument with '
                         'patch.multiple')

    def build(name, replacement):
        # one patcher per attribute, all sharing the same configuration
        made = _patch(getter, name, replacement, spec, create, spec_set,
                      autospec, new_callable, {}, scope=scope)
        made.attribute_name = name
        return made

    def refresh(candidate):
        return _update_new_callable(candidate, candidate.new, new_callable)

    (first_name, first_new), *others = kwargs.items()

    # All the patchers are created before any of them is updated.
    patcher = build(first_name, first_new)
    for name, replacement in others:
        patcher.additional_patchers.append(build(name, replacement))

    patcher = refresh(patcher)
    patcher.additional_patchers = [
        refresh(extra) for extra in patcher.additional_patchers
    ]

    return patcher
1250
1251
class _patch_dict(unittest.mock._patch_dict):
    """
    Patcher of dictionaries (or dictionary like objects) extending
    :class:`unittest.mock._patch_dict` with a ``scope`` and the ability to
    decorate generator and coroutine functions.
    """
    # documentation is in doc/asynctest.mock.rst
    def __init__(self, in_dict, values=(), clear=False, scope=GLOBAL,
                 **kwargs):
        super().__init__(in_dict, values, clear, **kwargs)
        self.scope = scope
        self._is_started = False
        # GLOBAL patchers of the same dict whose values must survive when
        # this patcher is stopped (see _unpatch_dict()).
        self._global_patchings = []

    def _keep_global_patch(self, other_patching):
        """
        Register ``other_patching`` as a patcher whose values are restored,
        if it is still started, when this patcher is stopped.
        """
        self._global_patchings.append(other_patching)

    def decorate_class(self, klass):
        """
        Decorate each callable attribute of ``klass`` whose name starts with
        ``patch.TEST_PREFIX`` with its own patcher, and return ``klass``.
        """
        for attr in dir(klass):
            attr_value = getattr(klass, attr)
            if (attr.startswith(patch.TEST_PREFIX) and
                    hasattr(attr_value, "__call__")):
                # Propagate the scope: without it, the methods of a class
                # decorated with scope=LIMITED would silently be patched
                # with the default GLOBAL scope.
                decorator = _patch_dict(self.in_dict, self.values, self.clear,
                                        scope=self.scope)
                decorated = decorator(attr_value)
                setattr(klass, attr, decorated)
        return klass

    def __call__(self, func):
        if isinstance(func, type):
            return self.decorate_class(func)

        wrapper = _decorate_coroutine_callable(func, self)
        if wrapper is None:
            # neither a generator nor a coroutine function
            return super().__call__(func)
        else:
            return wrapper

    def _patch_dict(self):
        """
        Save the content of the dictionary, then apply the patch to it.
        """
        self._is_started = True

        # Since Python 3.7.3, the moment when a dict specified by a target
        # string has been corrected. (see #115)
        if isinstance(self.in_dict, str):
            self.in_dict = unittest.mock._importer(self.in_dict)

        try:
            self._original = self.in_dict.copy()
        except AttributeError:
            # dict like object with no copy method
            # must support iteration over keys
            self._original = {}
            for key in self.in_dict:
                self._original[key] = self.in_dict[key]

        if self.clear:
            _clear_dict(self.in_dict)

        try:
            self.in_dict.update(self.values)
        except AttributeError:
            # dict like object with no update method
            for key in self.values:
                self.in_dict[key] = self.values[key]

    def _unpatch_dict(self):
        """
        Restore the content of the dictionary saved by :meth:`_patch_dict`,
        then re-apply the values of the registered GLOBAL patchers which are
        still started.
        """
        self._is_started = False

        if self.scope == LIMITED:
            # add to self.values the updated values which where not in
            # the original dict, as the patch may be reactivated
            for key in self.in_dict:
                if (key not in self._original or
                        self._original[key] is not self.in_dict[key]):
                    self.values[key] = self.in_dict[key]

        _clear_dict(self.in_dict)

        originals = [self._original]
        for patching in self._global_patchings:
            if patching._is_started:
                # keep the values of global patches
                originals.append(patching.values)

        for original in originals:
            try:
                self.in_dict.update(original)
            except AttributeError:
                # dict like object with no update method
                for key in original:
                    self.in_dict[key] = original[key]
1336
1337
# Helper of unittest.mock used by _patch_dict to empty the patched dict.
_clear_dict = unittest.mock._clear_dict

# Expose the variants of patch() as attributes of the function, the same way
# they are accessed on unittest.mock.patch.
patch.object = _patch_object
patch.dict = _patch_dict
patch.multiple = _patch_multiple
patch.stopall = unittest.mock._patch_stopall
# Prefix of the attribute names decorated by _patch_dict.decorate_class().
patch.TEST_PREFIX = 'test'
1345
1346
# Aliases of the matching unittest.mock objects.
sentinel = unittest.mock.sentinel
call = unittest.mock.call
PropertyMock = unittest.mock.PropertyMock
1350
1351
def return_once(value, then=None):
    """
    Helper to use with ``side_effect``: the mock returns ``value`` the first
    time it is called, and ``then`` for every call after that.

    When used as a ``side_effect`` value, if one of ``value`` or ``then`` is an
    :class:`Exception` type, an instance of this exception will be raised.

    >>> mock.recv = Mock(side_effect=return_once(b"data"))
    >>> mock.recv()
    b"data"
    >>> repr(mock.recv())
    'None'
    >>> repr(mock.recv())
    'None'

    >>> mock.recv = Mock(side_effect=return_once(b"data", then=BlockingIOError))
    >>> mock.recv()
    b"data"
    >>> mock.recv()
    Traceback BlockingIOError

    :param value: value to be returned once by the mock when called.

    :param then: value returned for any subsequent call.

    .. versionadded:: 0.4
    """
    upcoming = value
    while True:
        yield upcoming
        # every item after the first one is "then", forever
        upcoming = then
1383