1import asyncio
2import inspect
3import re
4import unittest
5
6from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
7                           create_autospec, sentinel, _CallList)
8
9
def tearDownModule():
    """Restore asyncio's default event loop policy after this module runs.

    Passing ``None`` to ``set_event_loop_policy`` reinstates the default
    policy, discarding any loop state the tests in this module left behind.
    """
    asyncio.set_event_loop_policy(policy=None)
12
13
class AsyncClass:
    """Spec/patch target mixing coroutine methods with a plain method.

    The tests use it to check which attributes ``unittest.mock`` turns into
    ``AsyncMock`` (the coroutine methods, including the class- and
    static-method variants) and which stay synchronous (``normal_method``).
    """

    def __init__(self):
        """Accept no arguments beyond the instance."""

    async def async_method(self):
        """Coroutine method that does nothing."""

    def normal_method(self):
        """Synchronous method that does nothing."""

    @classmethod
    async def async_class_method(cls):
        """Coroutine classmethod that does nothing."""

    @staticmethod
    async def async_static_method():
        """Coroutine staticmethod that does nothing."""
29
30
class AwaitableClass:
    """Minimal awaitable: ``__await__`` is a generator that yields once."""

    def __await__(self):
        yield None
34
async def async_func():
    """Parameterless coroutine function used as a patch target and spec."""
37
async def async_func_args(a, b, *, c):
    """Coroutine function with two positional args and keyword-only ``c``.

    Autospec-based tests call its mocks as ``(1, 2, c=3)``, so the exact
    signature is part of what they exercise.
    """
40
def normal_func():
    """Plain synchronous function used as a non-async spec."""
43
class NormalClass:
    """Class with a single synchronous method, used as a non-async spec."""

    def a(self):
        """Synchronous method that does nothing."""
47
48
# Dotted import paths for ``patch()`` targets defined in this module.
async_foo_name = __name__ + '.AsyncClass'
normal_foo_name = __name__ + '.NormalClass'
51
52
class AsyncPatchDecoratorTest(unittest.TestCase):
    """``patch`` / ``patch.object`` used as decorators on async targets."""

    def test_is_coroutine_function_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def check(patched):
            self.assertTrue(asyncio.iscoroutinefunction(patched))

        check()

    def test_is_async_patch(self):
        def assert_awaitable_then_run(patched):
            # Calling the mock yields an awaitable; drive it to completion
            # so it is not left un-awaited.
            pending = patched()
            self.assertTrue(inspect.isawaitable(pending))
            asyncio.run(pending)

        @patch.object(AsyncClass, 'async_method')
        def via_object(patched):
            assert_awaitable_then_run(patched)

        @patch(f'{async_foo_name}.async_method')
        def via_dotted_path(patched):
            assert_awaitable_then_run(patched)

        via_object()
        via_dotted_path()

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def check(patched):
            self.assertIsInstance(patched, AsyncMock)

        check()

    def test_is_AsyncMock_patch_staticmethod(self):
        @patch.object(AsyncClass, 'async_static_method')
        def check(patched):
            self.assertIsInstance(patched, AsyncMock)

        check()

    def test_is_AsyncMock_patch_classmethod(self):
        @patch.object(AsyncClass, 'async_class_method')
        def check(patched):
            self.assertIsInstance(patched, AsyncMock)

        check()

    def test_async_def_patch(self):
        # Stacked patch decorators apply bottom-up, so the mock for
        # async_func_args is passed first.
        @patch(f"{__name__}.async_func", return_value=1)
        @patch(f"{__name__}.async_func_args", return_value=2)
        async def check(args_mock, plain_mock):
            self.assertEqual(args_mock._mock_name, "async_func_args")
            self.assertEqual(plain_mock._mock_name, "async_func")

            # While the patches are active the module globals are the mocks.
            self.assertIsInstance(async_func, AsyncMock)
            self.assertIsInstance(async_func_args, AsyncMock)

            self.assertEqual(await async_func(), 1)
            self.assertEqual(await async_func_args(1, 2, c=3), 2)

        asyncio.run(check())
        # After the decorated coroutine finishes, the original is restored.
        self.assertTrue(inspect.iscoroutinefunction(async_func))
112
113
class AsyncPatchCMTest(unittest.TestCase):
    """``patch`` / ``patch.object`` used as context managers on async targets."""

    def test_is_async_function_cm(self):
        def check():
            with patch.object(AsyncClass, 'async_method') as patched:
                self.assertTrue(asyncio.iscoroutinefunction(patched))

        check()

    def test_is_async_cm(self):
        def check():
            with patch.object(AsyncClass, 'async_method') as patched:
                pending = patched()
                self.assertTrue(inspect.isawaitable(pending))
                asyncio.run(pending)

        check()

    def test_is_AsyncMock_cm(self):
        def check():
            with patch.object(AsyncClass, 'async_method') as patched:
                self.assertIsInstance(patched, AsyncMock)

        check()

    def test_async_def_cm(self):
        async def check():
            # The explicit AsyncMock replacement is visible only inside the
            # ``with`` block; afterwards the real coroutine function is back.
            with patch(f"{__name__}.async_func", AsyncMock()):
                self.assertIsInstance(async_func, AsyncMock)
            self.assertTrue(inspect.iscoroutinefunction(async_func))

        asyncio.run(check())
145
146
class AsyncMockTest(unittest.TestCase):
    """Basic introspection properties of ``AsyncMock`` instances."""

    def test_iscoroutinefunction_default(self):
        mock = AsyncMock()
        self.assertTrue(asyncio.iscoroutinefunction(mock))

    def test_iscoroutinefunction_function(self):
        async def foo(): pass
        mock = AsyncMock(foo)
        self.assertTrue(asyncio.iscoroutinefunction(mock))
        self.assertTrue(inspect.iscoroutinefunction(mock))

    def test_isawaitable(self):
        mock = AsyncMock()
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        asyncio.run(m)
        self.assertIn('assert_awaited', dir(mock))

    def test_iscoroutinefunction_normal_function(self):
        # An AsyncMock is a coroutine function even when its spec is a
        # synchronous function.
        def foo(): pass
        mock = AsyncMock(foo)
        self.assertTrue(asyncio.iscoroutinefunction(mock))
        self.assertTrue(inspect.iscoroutinefunction(mock))

    def test_future_isfuture(self):
        # Create the future from an explicit loop instead of installing that
        # loop as the thread's current event loop. The loop is closed right
        # away, and leaving a closed loop installed would leak into later tests.
        loop = asyncio.new_event_loop()
        try:
            fut = loop.create_future()
        finally:
            loop.close()
        mock = AsyncMock(fut)
        self.assertIsInstance(mock, asyncio.Future)
179
180
class AsyncAutospecTest(unittest.TestCase):
    """``autospec`` / ``create_autospec`` behaviour for async callables and classes."""

    def test_is_AsyncMock_patch(self):
        @patch(async_foo_name, autospec=True)
        def check_async(patched_cls):
            self.assertIsInstance(patched_cls.async_method, AsyncMock)
            self.assertIsInstance(patched_cls, MagicMock)

        @patch(async_foo_name, autospec=True)
        def check_sync(patched_cls):
            self.assertIsInstance(patched_cls.normal_method, MagicMock)

        check_async()
        check_sync()

    def test_create_autospec_instance(self):
        # A function spec cannot be autospecced with instance=True.
        with self.assertRaises(RuntimeError):
            create_autospec(async_func, instance=True)

    def test_create_autospec_awaitable_class(self):
        spec_mock = create_autospec(spec=AwaitableClass())
        self.assertIsInstance(create_autospec(spec_mock), AsyncMock)

    def test_create_autospec(self):
        spec = create_autospec(async_func_args)
        pending = spec(1, 2, c=3)

        async def main():
            await pending

        # Calling alone records nothing on the await side.
        self.assertEqual(spec.await_count, 0)
        self.assertIsNone(spec.await_args)
        self.assertEqual(spec.await_args_list, [])
        spec.assert_not_awaited()

        asyncio.run(main())

        self.assertTrue(asyncio.iscoroutinefunction(spec))
        self.assertTrue(asyncio.iscoroutine(pending))
        expected = call(1, 2, c=3)
        self.assertEqual(spec.await_count, 1)
        self.assertEqual(spec.await_args, expected)
        self.assertEqual(spec.await_args_list, [expected])
        spec.assert_awaited_once()
        spec.assert_awaited_once_with(1, 2, c=3)
        spec.assert_awaited_with(1, 2, c=3)
        spec.assert_awaited()

    def test_patch_with_autospec(self):

        async def check():
            with patch(f"{__name__}.async_func_args", autospec=True) as patched:
                pending = patched(1, 2, c=3)
                self.assertIsInstance(patched.mock, AsyncMock)

                self.assertTrue(asyncio.iscoroutinefunction(patched))
                self.assertTrue(asyncio.iscoroutine(pending))
                self.assertTrue(inspect.isawaitable(pending))

                # Nothing has been awaited yet.
                self.assertEqual(patched.await_count, 0)
                self.assertEqual(patched.await_args_list, [])
                self.assertIsNone(patched.await_args)
                patched.assert_not_awaited()

                await pending

            expected = call(1, 2, c=3)
            self.assertEqual(patched.await_count, 1)
            self.assertEqual(patched.await_args, expected)
            self.assertEqual(patched.await_args_list, [expected])
            patched.assert_awaited_once()
            patched.assert_awaited_once_with(1, 2, c=3)
            patched.assert_awaited_with(1, 2, c=3)
            patched.assert_awaited()

            # reset_mock() also clears the await bookkeeping.
            patched.reset_mock()
            self.assertEqual(patched.await_count, 0)
            self.assertIsNone(patched.await_args)
            self.assertEqual(patched.await_args_list, [])

        asyncio.run(check())
259
260
class AsyncSpecTest(unittest.TestCase):
    """Using async callables/classes as ``spec`` selects the right mock types.

    Coroutine-function attributes of the spec come back as ``AsyncMock``;
    plain methods come back as the parent's synchronous mock type.
    """
    def test_spec_normal_methods_on_class(self):
        # Same expectations whether the parent is an AsyncMock or a MagicMock.
        def inner_test(mock_type):
            mock = mock_type(AsyncClass)
            self.assertIsInstance(mock.async_method, AsyncMock)
            self.assertIsInstance(mock.normal_method, MagicMock)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test method types with {mock_type}"):
                inner_test(mock_type)

    def test_spec_normal_methods_on_class_with_mock(self):
        # A plain Mock parent still produces AsyncMock for async attributes.
        mock = Mock(AsyncClass)
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, Mock)

    def test_spec_mock_type_kw(self):
        # With an async spec passed by keyword, the requested mock type is
        # kept (isinstance holds) and calling it still yields an awaitable.
        def inner_test(mock_type):
            async_mock = mock_type(spec=async_func)
            self.assertIsInstance(async_mock, mock_type)
            with self.assertWarns(RuntimeWarning):
                # Will raise a warning because never awaited
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(spec=normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec kwarg with {mock_type}"):
                inner_test(mock_type)

    def test_spec_mock_type_positional(self):
        # Same as test_spec_mock_type_kw, but with the spec passed positionally.
        def inner_test(mock_type):
            async_mock = mock_type(async_func)
            self.assertIsInstance(async_mock, mock_type)
            with self.assertWarns(RuntimeWarning):
                # Will raise a warning because never awaited
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec positional with {mock_type}"):
                inner_test(mock_type)

    def test_spec_as_normal_kw_AsyncMock(self):
        # An AsyncMock stays awaitable even when specced from a sync function.
        mock = AsyncMock(spec=normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        asyncio.run(m)

    def test_spec_as_normal_positional_AsyncMock(self):
        # Positional-spec variant of test_spec_as_normal_kw_AsyncMock.
        mock = AsyncMock(normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        asyncio.run(m)

    def test_spec_async_mock(self):
        # spec=True uses the patched-out coroutine method as the spec.
        @patch.object(AsyncClass, 'async_method', spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_spec_parent_not_async_attribute_is(self):
        # The class itself is sync (MagicMock); its coroutine attribute is not.
        @patch(async_foo_name, spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertIsInstance(mock_method.async_method, AsyncMock)

        test_async()

    def test_target_async_spec_not(self):
        # The spec, not the patch target, decides: a sync spec on an async
        # target gives a non-awaitable MagicMock.
        @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
        def test_async_attribute(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertFalse(inspect.iscoroutine(mock_method))
            self.assertFalse(inspect.isawaitable(mock_method))

        test_async_attribute()

    def test_target_not_async_spec_is(self):
        # Conversely, an async spec on a sync target gives an AsyncMock.
        @patch.object(NormalClass, 'a', spec=async_func)
        def test_attribute_not_async_spec_is(mock_async_func):
            self.assertIsInstance(mock_async_func, AsyncMock)
        test_attribute_not_async_spec_is()

    def test_spec_async_attributes(self):
        # Speccing a sync class's patch with AsyncClass exposes AsyncClass's
        # coroutine methods as AsyncMock attributes.
        @patch(normal_foo_name, spec=AsyncClass)
        def test_async_attributes_coroutines(MockNormalClass):
            self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
            self.assertIsInstance(MockNormalClass, MagicMock)

        test_async_attributes_coroutines()
358
359
class AsyncSpecSetTest(unittest.TestCase):
    """``spec_set`` preserves the async/sync split of the spec's attributes."""

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method', spec_set=True)
        def test_async(async_method):
            self.assertIsInstance(async_method, AsyncMock)

        # The decorated function must actually be invoked. Otherwise the
        # patch is never applied and the assertion above never runs.
        test_async()

    def test_is_async_AsyncMock(self):
        mock = AsyncMock(spec_set=AsyncClass.async_method)
        self.assertTrue(asyncio.iscoroutinefunction(mock))
        self.assertIsInstance(mock, AsyncMock)

    def test_is_child_AsyncMock(self):
        # Children of a spec_set MagicMock follow the spec's async/sync split.
        mock = MagicMock(spec_set=AsyncClass)
        self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
        self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method))
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, MagicMock)
        self.assertIsInstance(mock, MagicMock)

    def test_magicmock_lambda_spec(self):
        # Patching out a MagicMock that is specced with a plain lambda must
        # install a MagicMock replacement.
        mock_obj = MagicMock()
        mock_obj.mock_func = MagicMock(spec=lambda x: x)

        with patch.object(mock_obj, "mock_func") as cm:
            self.assertIsInstance(cm, MagicMock)
385
386
class AsyncArguments(unittest.IsolatedAsyncioTestCase):
    """How ``return_value``, ``side_effect`` and ``wraps`` behave on AsyncMock."""

    async def test_add_return_value(self):
        async def addition(self, var):
            return var + 1

        mock = AsyncMock(addition, return_value=10)
        output = await mock(5)

        self.assertEqual(output, 10)

    async def test_add_side_effect_exception(self):
        async def addition(var):
            return var + 1
        mock = AsyncMock(addition, side_effect=Exception('err'))
        # Match the message so an unrelated Exception cannot satisfy the check.
        with self.assertRaisesRegex(Exception, 'err'):
            await mock(5)

    async def test_add_side_effect_coroutine(self):
        async def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_normal_function(self):
        def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_iterable(self):
        vals = [1, 2, 3]
        mock = AsyncMock(side_effect=vals)
        for item in vals:
            self.assertEqual(await mock(), item)

        # An exhausted side_effect iterable surfaces as StopAsyncIteration.
        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_add_side_effect_exception_iterable(self):
        class SampleException(Exception):
            pass

        vals = [1, SampleException("foo")]
        mock = AsyncMock(side_effect=vals)
        self.assertEqual(await mock(), 1)

        # Exception instances in the iterable are raised, not returned.
        with self.assertRaises(SampleException):
            await mock()

    async def test_return_value_AsyncMock(self):
        # An AsyncMock used as a return value is returned as-is, not awaited.
        value = AsyncMock(return_value=10)
        mock = AsyncMock(return_value=value)
        result = await mock()
        self.assertIs(result, value)

    async def test_return_value_awaitable(self):
        # Awaitable return values are handed back, not awaited.
        fut = asyncio.Future()
        fut.set_result(None)
        mock = AsyncMock(return_value=fut)
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

    async def test_side_effect_awaitable_values(self):
        fut = asyncio.Future()
        fut.set_result(None)

        mock = AsyncMock(side_effect=[fut])
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_side_effect_is_AsyncMock(self):
        effect = AsyncMock(return_value=10)
        mock = AsyncMock(side_effect=effect)

        result = await mock()
        self.assertEqual(result, 10)

    async def test_wraps_coroutine(self):
        value = asyncio.Future()

        ran = False
        async def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_wraps_normal_function(self):
        value = 1

        ran = False
        def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_await_args_list_order(self):
        # await_args_list follows the order of awaiting, while call_args_list
        # follows the order of calling.
        async_mock = AsyncMock()
        mock2 = async_mock(2)
        mock1 = async_mock(1)
        await mock1
        await mock2
        async_mock.assert_has_awaits([call(1), call(2)])
        self.assertEqual(async_mock.await_args_list, [call(1), call(2)])
        self.assertEqual(async_mock.call_args_list, [call(2), call(1)])
508
509
class AsyncMagicMethods(unittest.TestCase):
    """Which mock type backs the async versus the sync magic methods."""

    def test_async_magic_methods_return_async_mocks(self):
        parent = MagicMock()
        for name in ('__aenter__', '__aexit__', '__anext__'):
            self.assertIsInstance(getattr(parent, name), AsyncMock)
        # __aiter__ is actually a synchronous object
        # so should return a MagicMock
        self.assertIsInstance(parent.__aiter__, MagicMock)

    def test_sync_magic_methods_return_magic_mocks(self):
        parent = AsyncMock()
        for name in ('__enter__', '__exit__', '__next__', '__len__'):
            self.assertIsInstance(getattr(parent, name), MagicMock)

    def test_magicmock_has_async_magic_methods(self):
        parent = MagicMock()
        for name in ("__aenter__", "__aexit__", "__anext__"):
            self.assertTrue(hasattr(parent, name))

    def test_asyncmock_has_sync_magic_methods(self):
        parent = AsyncMock()
        for name in ("__enter__", "__exit__", "__next__", "__len__"):
            self.assertTrue(hasattr(parent, name))

    def test_magic_methods_are_async_functions(self):
        parent = MagicMock()
        enter_mock = parent.__aenter__
        exit_mock = parent.__aexit__
        self.assertIsInstance(enter_mock, AsyncMock)
        self.assertIsInstance(exit_mock, AsyncMock)
        # AsyncMocks are also coroutine functions
        self.assertTrue(asyncio.iscoroutinefunction(enter_mock))
        self.assertTrue(asyncio.iscoroutinefunction(exit_mock))
547
class AsyncContextManagerTest(unittest.TestCase):
    """Mocks standing in for objects used with ``async with``."""

    class WithAsyncContextManager:
        # Real async context manager that records entry and exit.
        async def __aenter__(self, *args, **kwargs):
            self.entered = True
            return self

        async def __aexit__(self, *args, **kwargs):
            self.exited = True

    class WithSyncContextManager:
        def __enter__(self, *args, **kwargs):
            return self

        def __exit__(self, *args, **kwargs):
            return None

    class ProductionCode:
        # Example real-world(ish) code
        def __init__(self):
            self.session = None

        async def main(self):
            async with self.session.post('https://python.org') as response:
                return await response.json()

    def test_set_return_value_of_aenter(self):
        def run_with(mock_type):
            code = self.ProductionCode()
            code.session = MagicMock(name='sessionmock')
            context = mock_type(name='magic_cm')
            response = AsyncMock(name='response')
            response.json = AsyncMock(return_value={'json': 123})
            # Whatever __aenter__ returns is what ``async with ... as`` binds.
            context.__aenter__.return_value = response
            code.session.post.return_value = context
            self.assertEqual(asyncio.run(code.main()), {'json': 123})

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test set return value of aenter with {mock_type}"):
                run_with(mock_type)

    def test_mock_supports_async_context_manager(self):
        def run_with(mock_type):
            body_ran = False
            real_cm = self.WithAsyncContextManager()
            cm_mock = mock_type(real_cm)

            async def use_context_manager():
                nonlocal body_ran
                async with cm_mock as bound:
                    body_ran = True
                return bound

            bound_value = asyncio.run(use_context_manager())
            self.assertTrue(body_ran)
            self.assertTrue(cm_mock.__aenter__.called)
            self.assertTrue(cm_mock.__aexit__.called)
            cm_mock.__aenter__.assert_awaited()
            cm_mock.__aexit__.assert_awaited()
            # We mock __aenter__ so it does not return self
            self.assertIsNot(cm_mock, bound_value)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test context manager magics with {mock_type}"):
                run_with(mock_type)

    def test_mock_customize_async_context_manager(self):
        mock_instance = MagicMock(self.WithAsyncContextManager())

        expected = object()
        mock_instance.__aenter__.return_value = expected

        async def use_context_manager():
            async with mock_instance as bound:
                return bound

        self.assertIs(asyncio.run(use_context_manager()), expected)

    def test_mock_customize_async_context_manager_with_coroutine(self):
        seen = {'enter': False, 'exit': False}

        async def enter_coroutine(*args):
            seen['enter'] = True

        async def exit_coroutine(*args):
            seen['exit'] = True

        mock_instance = MagicMock(self.WithAsyncContextManager())

        # Plain coroutine functions can replace the mocked magic methods.
        mock_instance.__aenter__ = enter_coroutine
        mock_instance.__aexit__ = exit_coroutine

        async def use_context_manager():
            async with mock_instance:
                pass

        asyncio.run(use_context_manager())
        self.assertTrue(seen['enter'])
        self.assertTrue(seen['exit'])

    def test_context_manager_raise_exception_by_default(self):
        async def raise_in(context_manager):
            async with context_manager:
                raise TypeError()

        mock_instance = MagicMock(self.WithAsyncContextManager())
        # By default the mocked __aexit__ does not swallow the exception.
        with self.assertRaises(TypeError):
            asyncio.run(raise_in(mock_instance))
663
664
class AsyncIteratorTest(unittest.TestCase):
    """Mocks standing in for async iterators used with ``async for``."""

    class WithAsyncIterator(object):
        # Real async iterator that pops ``items`` from the end; used as a spec.
        def __init__(self):
            self.items = ["foo", "NormalFoo", "baz"]

        def __aiter__(self):
            return self

        async def __anext__(self):
            if not self.items:
                raise StopAsyncIteration
            return self.items.pop()

    def test_aiter_set_return_value(self):
        mock_iter = AsyncMock(name="tester")
        mock_iter.__aiter__.return_value = [1, 2, 3]
        async def main():
            return [i async for i in mock_iter]
        result = asyncio.run(main())
        self.assertEqual(result, [1, 2, 3])

    def test_mock_aiter_and_anext_asyncmock(self):
        def inner_test(mock_type):
            instance = self.WithAsyncIterator()
            mock_instance = mock_type(instance)
            # Check that the mock and the real thing behave the same
            # __aiter__ is not actually async, so not a coroutinefunction
            self.assertFalse(asyncio.iscoroutinefunction(instance.__aiter__))
            self.assertFalse(asyncio.iscoroutinefunction(mock_instance.__aiter__))
            # __anext__ is async
            self.assertTrue(asyncio.iscoroutinefunction(instance.__anext__))
            self.assertTrue(asyncio.iscoroutinefunction(mock_instance.__anext__))

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test aiter and anext coroutine with {mock_type}"):
                inner_test(mock_type)

    def test_mock_async_for(self):
        async def iterate(iterator):
            accumulator = []
            async for item in iterator:
                accumulator.append(item)

            return accumulator

        expected = ["FOO", "BAR", "BAZ"]

        def test_default(mock_type):
            # Without configuration the mocked iterator yields nothing.
            mock_instance = mock_type(self.WithAsyncIterator())
            self.assertEqual(asyncio.run(iterate(mock_instance)), [])

        def test_set_return_value(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = expected[:]
            self.assertEqual(asyncio.run(iterate(mock_instance)), expected)

        def test_set_return_value_iter(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = iter(expected[:])
            self.assertEqual(asyncio.run(iterate(mock_instance)), expected)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"default value with {mock_type}"):
                test_default(mock_type)

            with self.subTest(f"set return_value with {mock_type}"):
                test_set_return_value(mock_type)

            with self.subTest(f"set return_value iterator with {mock_type}"):
                test_set_return_value_iter(mock_type)
739
740
741class AsyncMockAssert(unittest.TestCase):
742    def setUp(self):
743        self.mock = AsyncMock()
744
745    async def _runnable_test(self, *args, **kwargs):
746        await self.mock(*args, **kwargs)
747
748    async def _await_coroutine(self, coroutine):
749        return await coroutine
750
    def test_assert_called_but_not_awaited(self):
        # A call whose coroutine is never awaited satisfies the "called"
        # assertions but fails the "awaited" ones, on both the method mock and
        # its parent.
        mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            # NOTE(review): the coroutine is discarded immediately; this
            # presumably relies on CPython freeing it (and warning) before the
            # block exits.
            mock.async_method()
        self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.assert_awaited()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
764
765    def test_assert_called_then_awaited(self):
766        mock = AsyncMock(AsyncClass)
767        mock_coroutine = mock.async_method()
768        mock.async_method.assert_called()
769        mock.async_method.assert_called_once()
770        mock.async_method.assert_called_once_with()
771        with self.assertRaises(AssertionError):
772            mock.async_method.assert_awaited()
773
774        asyncio.run(self._await_coroutine(mock_coroutine))
775        # Assert we haven't re-called the function
776        mock.async_method.assert_called_once()
777        mock.async_method.assert_awaited()
778        mock.async_method.assert_awaited_once()
779        mock.async_method.assert_awaited_once_with()
780
781    def test_assert_called_and_awaited_at_same_time(self):
782        with self.assertRaises(AssertionError):
783            self.mock.assert_awaited()
784
785        with self.assertRaises(AssertionError):
786            self.mock.assert_called()
787
788        asyncio.run(self._runnable_test())
789        self.mock.assert_called_once()
790        self.mock.assert_awaited_once()
791
    def test_assert_called_twice_and_awaited_once(self):
        # Two calls, only one of which is awaited: the await count tracks
        # awaits, not calls.
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        with self.assertWarns(RuntimeWarning):
            # The first call will be awaited so no warning there
            # But this call will never get awaited, so it will warn here
            mock.async_method()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
        mock.async_method.assert_called()
        asyncio.run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()
805
806    def test_assert_called_once_and_awaited_twice(self):
807        mock = AsyncMock(AsyncClass)
808        coroutine = mock.async_method()
809        mock.async_method.assert_called_once()
810        asyncio.run(self._await_coroutine(coroutine))
811        with self.assertRaises(RuntimeError):
812            # Cannot reuse already awaited coroutine
813            asyncio.run(self._await_coroutine(coroutine))
814        mock.async_method.assert_awaited()
815
816    def test_assert_awaited_but_not_called(self):
817        with self.assertRaises(AssertionError):
818            self.mock.assert_awaited()
819        with self.assertRaises(AssertionError):
820            self.mock.assert_called()
821        with self.assertRaises(TypeError):
822            # You cannot await an AsyncMock, it must be a coroutine
823            asyncio.run(self._await_coroutine(self.mock))
824
825        with self.assertRaises(AssertionError):
826            self.mock.assert_awaited()
827        with self.assertRaises(AssertionError):
828            self.mock.assert_called()
829
    def test_assert_has_calls_not_awaits(self):
        # A call appears in the call history immediately, but not in the
        # await history because its coroutine is never awaited.
        kalls = [call('foo')]
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock('foo')
        self.mock.assert_has_calls(kalls)
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(kalls)
838
    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
        # mock_calls on an unspecced AsyncMock records every call, awaited
        # or not.
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock()
        # Tuple form of a no-argument call on the mock itself (empty name).
        kalls_empty = [('', (), {})]
        self.assertEqual(self.mock.mock_calls, kalls_empty)

        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock('foo')
            self.mock('baz')
        mock_kalls = ([call(), call('foo'), call('baz')])
        self.assertEqual(self.mock.mock_calls, mock_kalls)
852
853    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
854        a_class_mock = AsyncMock(AsyncClass)
855        with self.assertWarns(RuntimeWarning):
856            # Will raise a warning because never awaited
857            a_class_mock.async_method()
858        kalls_empty = [('', (), {})]
859        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
860        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])
861
862        with self.assertWarns(RuntimeWarning):
863            # Will raise a warning because never awaited
864            a_class_mock.async_method(1, 2, 3, a=4, b=5)
865        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
866        mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
867        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
868        self.assertEqual(a_class_mock.mock_calls, mock_kalls)
869
870    def test_async_method_calls_recorded(self):
871        with self.assertWarns(RuntimeWarning):
872            # Will raise warnings because never awaited
873            self.mock.something(3, fish=None)
874            self.mock.something_else.something(6, cake=sentinel.Cake)
875
876        self.assertEqual(self.mock.method_calls, [
877            ("something", (3,), {'fish': None}),
878            ("something_else.something", (6,), {'cake': sentinel.Cake})
879        ],
880            "method calls not recorded correctly")
881        self.assertEqual(self.mock.something_else.method_calls,
882                         [("something", (6,), {'cake': sentinel.Cake})],
883                         "method calls not recorded correctly")
884
885    def test_async_arg_lists(self):
886        def assert_attrs(mock):
887            names = ('call_args_list', 'method_calls', 'mock_calls')
888            for name in names:
889                attr = getattr(mock, name)
890                self.assertIsInstance(attr, _CallList)
891                self.assertIsInstance(attr, list)
892                self.assertEqual(attr, [])
893
894        assert_attrs(self.mock)
895        with self.assertWarns(RuntimeWarning):
896            # Will raise warnings because never awaited
897            self.mock()
898            self.mock(1, 2)
899            self.mock(a=3)
900
901        self.mock.reset_mock()
902        assert_attrs(self.mock)
903
904        a_mock = AsyncMock(AsyncClass)
905        with self.assertWarns(RuntimeWarning):
906            # Will raise warnings because never awaited
907            a_mock.async_method()
908            a_mock.async_method(1, a=3)
909
910        a_mock.reset_mock()
911        assert_attrs(a_mock)
912
913    def test_assert_awaited(self):
914        with self.assertRaises(AssertionError):
915            self.mock.assert_awaited()
916
917        asyncio.run(self._runnable_test())
918        self.mock.assert_awaited()
919
920    def test_assert_awaited_once(self):
921        with self.assertRaises(AssertionError):
922            self.mock.assert_awaited_once()
923
924        asyncio.run(self._runnable_test())
925        self.mock.assert_awaited_once()
926
927        asyncio.run(self._runnable_test())
928        with self.assertRaises(AssertionError):
929            self.mock.assert_awaited_once()
930
931    def test_assert_awaited_with(self):
932        msg = 'Not awaited'
933        with self.assertRaisesRegex(AssertionError, msg):
934            self.mock.assert_awaited_with('foo')
935
936        asyncio.run(self._runnable_test())
937        msg = 'expected await not found'
938        with self.assertRaisesRegex(AssertionError, msg):
939            self.mock.assert_awaited_with('foo')
940
941        asyncio.run(self._runnable_test('foo'))
942        self.mock.assert_awaited_with('foo')
943
944        asyncio.run(self._runnable_test('SomethingElse'))
945        with self.assertRaises(AssertionError):
946            self.mock.assert_awaited_with('foo')
947
948    def test_assert_awaited_once_with(self):
949        with self.assertRaises(AssertionError):
950            self.mock.assert_awaited_once_with('foo')
951
952        asyncio.run(self._runnable_test('foo'))
953        self.mock.assert_awaited_once_with('foo')
954
955        asyncio.run(self._runnable_test('foo'))
956        with self.assertRaises(AssertionError):
957            self.mock.assert_awaited_once_with('foo')
958
959    def test_assert_any_wait(self):
960        with self.assertRaises(AssertionError):
961            self.mock.assert_any_await('foo')
962
963        asyncio.run(self._runnable_test('baz'))
964        with self.assertRaises(AssertionError):
965            self.mock.assert_any_await('foo')
966
967        asyncio.run(self._runnable_test('foo'))
968        self.mock.assert_any_await('foo')
969
970        asyncio.run(self._runnable_test('SomethingElse'))
971        self.mock.assert_any_await('foo')
972
973    def test_assert_has_awaits_no_order(self):
974        calls = [call('foo'), call('baz')]
975
976        with self.assertRaises(AssertionError) as cm:
977            self.mock.assert_has_awaits(calls)
978        self.assertEqual(len(cm.exception.args), 1)
979
980        asyncio.run(self._runnable_test('foo'))
981        with self.assertRaises(AssertionError):
982            self.mock.assert_has_awaits(calls)
983
984        asyncio.run(self._runnable_test('foo'))
985        with self.assertRaises(AssertionError):
986            self.mock.assert_has_awaits(calls)
987
988        asyncio.run(self._runnable_test('baz'))
989        self.mock.assert_has_awaits(calls)
990
991        asyncio.run(self._runnable_test('SomethingElse'))
992        self.mock.assert_has_awaits(calls)
993
994    def test_assert_has_awaits_ordered(self):
995        calls = [call('foo'), call('baz')]
996        with self.assertRaises(AssertionError):
997            self.mock.assert_has_awaits(calls, any_order=True)
998
999        asyncio.run(self._runnable_test('baz'))
1000        with self.assertRaises(AssertionError):
1001            self.mock.assert_has_awaits(calls, any_order=True)
1002
1003        asyncio.run(self._runnable_test('bamf'))
1004        with self.assertRaises(AssertionError):
1005            self.mock.assert_has_awaits(calls, any_order=True)
1006
1007        asyncio.run(self._runnable_test('foo'))
1008        self.mock.assert_has_awaits(calls, any_order=True)
1009
1010        asyncio.run(self._runnable_test('qux'))
1011        self.mock.assert_has_awaits(calls, any_order=True)
1012
1013    def test_assert_not_awaited(self):
1014        self.mock.assert_not_awaited()
1015
1016        asyncio.run(self._runnable_test())
1017        with self.assertRaises(AssertionError):
1018            self.mock.assert_not_awaited()
1019
1020    def test_assert_has_awaits_not_matching_spec_error(self):
1021        async def f(x=None): pass
1022
1023        self.mock = AsyncMock(spec=f)
1024        asyncio.run(self._runnable_test(1))
1025
1026        with self.assertRaisesRegex(
1027                AssertionError,
1028                '^{}$'.format(
1029                    re.escape('Awaits not found.\n'
1030                              'Expected: [call()]\n'
1031                              'Actual: [call(1)]'))) as cm:
1032            self.mock.assert_has_awaits([call()])
1033        self.assertIsNone(cm.exception.__cause__)
1034
1035        with self.assertRaisesRegex(
1036                AssertionError,
1037                '^{}$'.format(
1038                    re.escape(
1039                        'Error processing expected awaits.\n'
1040                        "Errors: [None, TypeError('too many positional "
1041                        "arguments')]\n"
1042                        'Expected: [call(), call(1, 2)]\n'
1043                        'Actual: [call(1)]'))) as cm:
1044            self.mock.assert_has_awaits([call(), call(1, 2)])
1045        self.assertIsInstance(cm.exception.__cause__, TypeError)
1046