"""Tests for the asyncio support in ``unittest.mock``.

Covers ``AsyncMock`` itself, ``patch``/``patch.object`` applied to async
targets, ``create_autospec`` and ``spec``/``spec_set`` with async callables,
and mocks used as async context managers and async iterators.
"""
import asyncio
import inspect
import re
import unittest

from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock,
                           create_autospec, sentinel, _CallList)


def tearDownModule():
    # Restore the default event loop policy once this module's tests are done.
    asyncio.set_event_loop_policy(None)


class AsyncClass:
    """Patch/spec target mixing coroutine methods with a regular method."""

    def __init__(self):
        pass

    async def async_method(self):
        pass

    def normal_method(self):
        pass

    @classmethod
    async def async_class_method(cls):
        pass

    @staticmethod
    async def async_static_method():
        pass


class AwaitableClass:
    """Awaitable via ``__await__`` while defining no ``async def`` methods."""

    def __await__(self):
        yield


async def async_func():
    pass


async def async_func_args(a, b, *, c):
    pass


def normal_func():
    pass


class NormalClass(object):
    """Patch/spec target with only a synchronous method."""

    def a(self):
        pass


# Dotted paths to the fixture classes, used as ``patch`` targets.
async_foo_name = f'{__name__}.AsyncClass'
normal_foo_name = f'{__name__}.NormalClass'


class AsyncPatchDecoratorTest(unittest.TestCase):
    """``patch``/``patch.object`` used as decorators on async targets."""

    def test_is_coroutine_function_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            self.assertTrue(asyncio.iscoroutinefunction(mock_method))
        test_async()

    def test_is_async_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            m = mock_method()
            self.assertTrue(inspect.isawaitable(m))
            asyncio.run(m)

        @patch(f'{async_foo_name}.async_method')
        def test_no_parent_attribute(mock_method):
            m = mock_method()
            self.assertTrue(inspect.isawaitable(m))
            asyncio.run(m)

        test_async()
        test_no_parent_attribute()

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_is_AsyncMock_patch_staticmethod(self):
        @patch.object(AsyncClass, 'async_static_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_is_AsyncMock_patch_classmethod(self):
        @patch.object(AsyncClass, 'async_class_method')
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_async_def_patch(self):
        @patch(f"{__name__}.async_func", return_value=1)
        @patch(f"{__name__}.async_func_args", return_value=2)
        async def test_async(func_args_mock, func_mock):
            self.assertEqual(func_args_mock._mock_name, "async_func_args")
            self.assertEqual(func_mock._mock_name, "async_func")

            self.assertIsInstance(async_func, AsyncMock)
            self.assertIsInstance(async_func_args, AsyncMock)

            self.assertEqual(await async_func(), 1)
            self.assertEqual(await async_func_args(1, 2, c=3), 2)

        asyncio.run(test_async())
        # Once the decorated coroutine has finished the originals are back.
        self.assertTrue(inspect.iscoroutinefunction(async_func))


class AsyncPatchCMTest(unittest.TestCase):
    """``patch``/``patch.object`` used as context managers on async targets."""

    def test_is_async_function_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                self.assertTrue(asyncio.iscoroutinefunction(mock_method))

        test_async()

    def test_is_async_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                m = mock_method()
                self.assertTrue(inspect.isawaitable(m))
                asyncio.run(m)

        test_async()

    def test_is_AsyncMock_cm(self):
        def test_async():
            with patch.object(AsyncClass, 'async_method') as mock_method:
                self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_async_def_cm(self):
        async def test_async():
            with patch(f"{__name__}.async_func", AsyncMock()):
                self.assertIsInstance(async_func, AsyncMock)
            self.assertTrue(inspect.iscoroutinefunction(async_func))

        asyncio.run(test_async())


class AsyncMockTest(unittest.TestCase):
    """Coroutine-function introspection of ``AsyncMock`` instances."""

    def test_iscoroutinefunction_default(self):
        mock = AsyncMock()
        self.assertTrue(asyncio.iscoroutinefunction(mock))

    def test_iscoroutinefunction_function(self):
        async def foo(): pass
        mock = AsyncMock(foo)
        self.assertTrue(asyncio.iscoroutinefunction(mock))
        self.assertTrue(inspect.iscoroutinefunction(mock))

    def test_isawaitable(self):
        mock = AsyncMock()
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        asyncio.run(m)
        self.assertIn('assert_awaited', dir(mock))

    def test_iscoroutinefunction_normal_function(self):
        def foo(): pass
        mock = AsyncMock(foo)
        self.assertTrue(asyncio.iscoroutinefunction(mock))
        self.assertTrue(inspect.iscoroutinefunction(mock))

    def test_future_isfuture(self):
        # Create the future on a private loop rather than installing that
        # loop as the current event loop: the loop is closed right away, and
        # a closed loop left set globally would leak into later tests.
        loop = asyncio.new_event_loop()
        fut = loop.create_future()
        loop.stop()
        loop.close()
        mock = AsyncMock(fut)
        self.assertIsInstance(mock, asyncio.Future)


class AsyncAutospecTest(unittest.TestCase):
    """``create_autospec`` and ``patch(autospec=True)`` on async callables."""

    def test_is_AsyncMock_patch(self):
        @patch(async_foo_name, autospec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method.async_method, AsyncMock)
            self.assertIsInstance(mock_method, MagicMock)

        @patch(async_foo_name, autospec=True)
        def test_normal_method(mock_method):
            self.assertIsInstance(mock_method.normal_method, MagicMock)

        test_async()
        test_normal_method()

    def test_create_autospec_instance(self):
        with self.assertRaises(RuntimeError):
            create_autospec(async_func, instance=True)

    def test_create_autospec_awaitable_class(self):
        awaitable_mock = create_autospec(spec=AwaitableClass())
        self.assertIsInstance(create_autospec(awaitable_mock), AsyncMock)

    def test_create_autospec(self):
        spec = create_autospec(async_func_args)
        awaitable = spec(1, 2, c=3)
        async def main():
            await awaitable

        self.assertEqual(spec.await_count, 0)
        self.assertIsNone(spec.await_args)
        self.assertEqual(spec.await_args_list, [])
        spec.assert_not_awaited()

        asyncio.run(main())

        self.assertTrue(asyncio.iscoroutinefunction(spec))
        self.assertTrue(asyncio.iscoroutine(awaitable))
        self.assertEqual(spec.await_count, 1)
        self.assertEqual(spec.await_args, call(1, 2, c=3))
        self.assertEqual(spec.await_args_list, [call(1, 2, c=3)])
        spec.assert_awaited_once()
        spec.assert_awaited_once_with(1, 2, c=3)
        spec.assert_awaited_with(1, 2, c=3)
        spec.assert_awaited()

    def test_patch_with_autospec(self):

        async def test_async():
            with patch(f"{__name__}.async_func_args", autospec=True) as mock_method:
                awaitable = mock_method(1, 2, c=3)
                self.assertIsInstance(mock_method.mock, AsyncMock)

                self.assertTrue(asyncio.iscoroutinefunction(mock_method))
                self.assertTrue(asyncio.iscoroutine(awaitable))
                self.assertTrue(inspect.isawaitable(awaitable))

                # Verify the default values during mock setup
                self.assertEqual(mock_method.await_count, 0)
                self.assertEqual(mock_method.await_args_list, [])
                self.assertIsNone(mock_method.await_args)
                mock_method.assert_not_awaited()

                await awaitable

            self.assertEqual(mock_method.await_count, 1)
            self.assertEqual(mock_method.await_args, call(1, 2, c=3))
            self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)])
            mock_method.assert_awaited_once()
            mock_method.assert_awaited_once_with(1, 2, c=3)
            mock_method.assert_awaited_with(1, 2, c=3)
            mock_method.assert_awaited()

            mock_method.reset_mock()
            self.assertEqual(mock_method.await_count, 0)
            self.assertIsNone(mock_method.await_args)
            self.assertEqual(mock_method.await_args_list, [])

        asyncio.run(test_async())


class AsyncSpecTest(unittest.TestCase):
    """``spec`` on Mock, MagicMock and AsyncMock with async and sync specs."""

    def test_spec_normal_methods_on_class(self):
        def inner_test(mock_type):
            mock = mock_type(AsyncClass)
            self.assertIsInstance(mock.async_method, AsyncMock)
            self.assertIsInstance(mock.normal_method, MagicMock)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test method types with {mock_type}"):
                inner_test(mock_type)

    def test_spec_normal_methods_on_class_with_mock(self):
        mock = Mock(AsyncClass)
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, Mock)

    def test_spec_mock_type_kw(self):
        def inner_test(mock_type):
            async_mock = mock_type(spec=async_func)
            self.assertIsInstance(async_mock, mock_type)
            with self.assertWarns(RuntimeWarning):
                # Will raise a warning because never awaited
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(spec=normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec kwarg with {mock_type}"):
                inner_test(mock_type)

    def test_spec_mock_type_positional(self):
        def inner_test(mock_type):
            async_mock = mock_type(async_func)
            self.assertIsInstance(async_mock, mock_type)
            with self.assertWarns(RuntimeWarning):
                # Will raise a warning because never awaited
                self.assertTrue(inspect.isawaitable(async_mock()))

            sync_mock = mock_type(normal_func)
            self.assertIsInstance(sync_mock, mock_type)

        for mock_type in [AsyncMock, MagicMock, Mock]:
            with self.subTest(f"test spec positional with {mock_type}"):
                inner_test(mock_type)

    def test_spec_as_normal_kw_AsyncMock(self):
        mock = AsyncMock(spec=normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        asyncio.run(m)

    def test_spec_as_normal_positional_AsyncMock(self):
        mock = AsyncMock(normal_func)
        self.assertIsInstance(mock, AsyncMock)
        m = mock()
        self.assertTrue(inspect.isawaitable(m))
        asyncio.run(m)

    def test_spec_async_mock(self):
        @patch.object(AsyncClass, 'async_method', spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, AsyncMock)

        test_async()

    def test_spec_parent_not_async_attribute_is(self):
        @patch(async_foo_name, spec=True)
        def test_async(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertIsInstance(mock_method.async_method, AsyncMock)

        test_async()

    def test_target_async_spec_not(self):
        @patch.object(AsyncClass, 'async_method', spec=NormalClass.a)
        def test_async_attribute(mock_method):
            self.assertIsInstance(mock_method, MagicMock)
            self.assertFalse(inspect.iscoroutine(mock_method))
            self.assertFalse(inspect.isawaitable(mock_method))

        test_async_attribute()

    def test_target_not_async_spec_is(self):
        @patch.object(NormalClass, 'a', spec=async_func)
        def test_attribute_not_async_spec_is(mock_async_func):
            self.assertIsInstance(mock_async_func, AsyncMock)
        test_attribute_not_async_spec_is()

    def test_spec_async_attributes(self):
        @patch(normal_foo_name, spec=AsyncClass)
        def test_async_attributes_coroutines(MockNormalClass):
            self.assertIsInstance(MockNormalClass.async_method, AsyncMock)
            self.assertIsInstance(MockNormalClass, MagicMock)

        test_async_attributes_coroutines()


class AsyncSpecSetTest(unittest.TestCase):
    """``spec_set`` with async specs."""

    def test_is_AsyncMock_patch(self):
        @patch.object(AsyncClass, 'async_method', spec_set=True)
        def test_async(async_method):
            self.assertIsInstance(async_method, AsyncMock)

        # The decorated helper must actually be invoked; otherwise the patch
        # and the assertion above never run and the test passes vacuously.
        test_async()

    def test_is_async_AsyncMock(self):
        mock = AsyncMock(spec_set=AsyncClass.async_method)
        self.assertTrue(asyncio.iscoroutinefunction(mock))
        self.assertIsInstance(mock, AsyncMock)

    def test_is_child_AsyncMock(self):
        mock = MagicMock(spec_set=AsyncClass)
        self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
        self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method))
        self.assertIsInstance(mock.async_method, AsyncMock)
        self.assertIsInstance(mock.normal_method, MagicMock)
        self.assertIsInstance(mock, MagicMock)

    def test_magicmock_lambda_spec(self):
        mock_obj = MagicMock()
        mock_obj.mock_func = MagicMock(spec=lambda x: x)

        with patch.object(mock_obj, "mock_func") as cm:
            self.assertIsInstance(cm, MagicMock)


class AsyncArguments(unittest.IsolatedAsyncioTestCase):
    """return_value, side_effect and wraps behaviour of awaited AsyncMocks."""

    async def test_add_return_value(self):
        async def addition(self, var):
            return var + 1

        mock = AsyncMock(addition, return_value=10)
        output = await mock(5)

        self.assertEqual(output, 10)

    async def test_add_side_effect_exception(self):
        async def addition(var):
            return var + 1
        mock = AsyncMock(addition, side_effect=Exception('err'))
        with self.assertRaises(Exception):
            await mock(5)

    async def test_add_side_effect_coroutine(self):
        async def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_normal_function(self):
        def addition(var):
            return var + 1
        mock = AsyncMock(side_effect=addition)
        result = await mock(5)
        self.assertEqual(result, 6)

    async def test_add_side_effect_iterable(self):
        vals = [1, 2, 3]
        mock = AsyncMock(side_effect=vals)
        for item in vals:
            self.assertEqual(await mock(), item)

        # An exhausted side_effect iterable surfaces as StopAsyncIteration.
        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_add_side_effect_exception_iterable(self):
        class SampleException(Exception):
            pass

        vals = [1, SampleException("foo")]
        mock = AsyncMock(side_effect=vals)
        self.assertEqual(await mock(), 1)

        with self.assertRaises(SampleException):
            await mock()

    async def test_return_value_AsyncMock(self):
        value = AsyncMock(return_value=10)
        mock = AsyncMock(return_value=value)
        result = await mock()
        self.assertIs(result, value)

    async def test_return_value_awaitable(self):
        fut = asyncio.Future()
        fut.set_result(None)
        mock = AsyncMock(return_value=fut)
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

    async def test_side_effect_awaitable_values(self):
        fut = asyncio.Future()
        fut.set_result(None)

        mock = AsyncMock(side_effect=[fut])
        result = await mock()
        self.assertIsInstance(result, asyncio.Future)

        with self.assertRaises(StopAsyncIteration):
            await mock()

    async def test_side_effect_is_AsyncMock(self):
        effect = AsyncMock(return_value=10)
        mock = AsyncMock(side_effect=effect)

        result = await mock()
        self.assertEqual(result, 10)

    async def test_wraps_coroutine(self):
        value = asyncio.Future()

        ran = False
        async def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_wraps_normal_function(self):
        value = 1

        ran = False
        def inner():
            nonlocal ran
            ran = True
            return value

        mock = AsyncMock(wraps=inner)
        result = await mock()
        self.assertEqual(result, value)
        mock.assert_awaited()
        self.assertTrue(ran)

    async def test_await_args_list_order(self):
        async_mock = AsyncMock()
        mock2 = async_mock(2)
        mock1 = async_mock(1)
        await mock1
        await mock2
        # Awaits are recorded in await order, calls in call order.
        async_mock.assert_has_awaits([call(1), call(2)])
        self.assertEqual(async_mock.await_args_list, [call(1), call(2)])
        self.assertEqual(async_mock.call_args_list, [call(2), call(1)])


class AsyncMagicMethods(unittest.TestCase):
    """Which magic-method attributes are AsyncMock and which are MagicMock."""

    def test_async_magic_methods_return_async_mocks(self):
        m_mock = MagicMock()
        self.assertIsInstance(m_mock.__aenter__, AsyncMock)
        self.assertIsInstance(m_mock.__aexit__, AsyncMock)
        self.assertIsInstance(m_mock.__anext__, AsyncMock)
        # __aiter__ is actually a synchronous object
        # so should return a MagicMock
        self.assertIsInstance(m_mock.__aiter__, MagicMock)

    def test_sync_magic_methods_return_magic_mocks(self):
        a_mock = AsyncMock()
        self.assertIsInstance(a_mock.__enter__, MagicMock)
        self.assertIsInstance(a_mock.__exit__, MagicMock)
        self.assertIsInstance(a_mock.__next__, MagicMock)
        self.assertIsInstance(a_mock.__len__, MagicMock)

    def test_magicmock_has_async_magic_methods(self):
        m_mock = MagicMock()
        self.assertTrue(hasattr(m_mock, "__aenter__"))
        self.assertTrue(hasattr(m_mock, "__aexit__"))
        self.assertTrue(hasattr(m_mock, "__anext__"))

    def test_asyncmock_has_sync_magic_methods(self):
        a_mock = AsyncMock()
        self.assertTrue(hasattr(a_mock, "__enter__"))
        self.assertTrue(hasattr(a_mock, "__exit__"))
        self.assertTrue(hasattr(a_mock, "__next__"))
        self.assertTrue(hasattr(a_mock, "__len__"))

    def test_magic_methods_are_async_functions(self):
        m_mock = MagicMock()
        self.assertIsInstance(m_mock.__aenter__, AsyncMock)
        self.assertIsInstance(m_mock.__aexit__, AsyncMock)
        # AsyncMocks are also coroutine functions
        self.assertTrue(asyncio.iscoroutinefunction(m_mock.__aenter__))
        self.assertTrue(asyncio.iscoroutinefunction(m_mock.__aexit__))


class AsyncContextManagerTest(unittest.TestCase):
    """Mocks used as targets of ``async with``."""

    class WithAsyncContextManager:
        async def __aenter__(self, *args, **kwargs):
            self.entered = True
            return self

        async def __aexit__(self, *args, **kwargs):
            self.exited = True

    class WithSyncContextManager:
        def __enter__(self, *args, **kwargs):
            return self

        def __exit__(self, *args, **kwargs):
            pass

    class ProductionCode:
        # Example real-world(ish) code
        def __init__(self):
            self.session = None

        async def main(self):
            async with self.session.post('https://python.org') as response:
                val = await response.json()
                return val

    def test_set_return_value_of_aenter(self):
        def inner_test(mock_type):
            pc = self.ProductionCode()
            pc.session = MagicMock(name='sessionmock')
            cm = mock_type(name='magic_cm')
            response = AsyncMock(name='response')
            response.json = AsyncMock(return_value={'json': 123})
            cm.__aenter__.return_value = response
            pc.session.post.return_value = cm
            result = asyncio.run(pc.main())
            self.assertEqual(result, {'json': 123})

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test set return value of aenter with {mock_type}"):
                inner_test(mock_type)

    def test_mock_supports_async_context_manager(self):
        def inner_test(mock_type):
            called = False
            cm = self.WithAsyncContextManager()
            cm_mock = mock_type(cm)

            async def use_context_manager():
                nonlocal called
                async with cm_mock as result:
                    called = True
                return result

            cm_result = asyncio.run(use_context_manager())
            self.assertTrue(called)
            self.assertTrue(cm_mock.__aenter__.called)
            self.assertTrue(cm_mock.__aexit__.called)
            cm_mock.__aenter__.assert_awaited()
            cm_mock.__aexit__.assert_awaited()
            # We mock __aenter__ so it does not return self
            self.assertIsNot(cm_mock, cm_result)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test context manager magics with {mock_type}"):
                inner_test(mock_type)

    def test_mock_customize_async_context_manager(self):
        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)

        expected_result = object()
        mock_instance.__aenter__.return_value = expected_result

        async def use_context_manager():
            async with mock_instance as result:
                return result

        self.assertIs(asyncio.run(use_context_manager()), expected_result)

    def test_mock_customize_async_context_manager_with_coroutine(self):
        enter_called = False
        exit_called = False

        async def enter_coroutine(*args):
            nonlocal enter_called
            enter_called = True

        async def exit_coroutine(*args):
            nonlocal exit_called
            exit_called = True

        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)

        mock_instance.__aenter__ = enter_coroutine
        mock_instance.__aexit__ = exit_coroutine

        async def use_context_manager():
            async with mock_instance:
                pass

        asyncio.run(use_context_manager())
        self.assertTrue(enter_called)
        self.assertTrue(exit_called)

    def test_context_manager_raise_exception_by_default(self):
        async def raise_in(context_manager):
            async with context_manager:
                raise TypeError()

        instance = self.WithAsyncContextManager()
        mock_instance = MagicMock(instance)
        with self.assertRaises(TypeError):
            asyncio.run(raise_in(mock_instance))


class AsyncIteratorTest(unittest.TestCase):
    """Mocks used as targets of ``async for``."""

    class WithAsyncIterator(object):
        def __init__(self):
            self.items = ["foo", "NormalFoo", "baz"]

        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                return self.items.pop()
            except IndexError:
                pass

            raise StopAsyncIteration

    def test_aiter_set_return_value(self):
        mock_iter = AsyncMock(name="tester")
        mock_iter.__aiter__.return_value = [1, 2, 3]
        async def main():
            return [i async for i in mock_iter]
        result = asyncio.run(main())
        self.assertEqual(result, [1, 2, 3])

    def test_mock_aiter_and_anext_asyncmock(self):
        def inner_test(mock_type):
            instance = self.WithAsyncIterator()
            mock_instance = mock_type(instance)
            # Check that the mock and the real thing behave the same
            # __aiter__ is not actually async, so not a coroutinefunction
            self.assertFalse(asyncio.iscoroutinefunction(instance.__aiter__))
            self.assertFalse(asyncio.iscoroutinefunction(mock_instance.__aiter__))
            # __anext__ is async
            self.assertTrue(asyncio.iscoroutinefunction(instance.__anext__))
            self.assertTrue(asyncio.iscoroutinefunction(mock_instance.__anext__))

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"test aiter and anext coroutine with {mock_type}"):
                inner_test(mock_type)

    def test_mock_async_for(self):
        async def iterate(iterator):
            accumulator = []
            async for item in iterator:
                accumulator.append(item)

            return accumulator

        expected = ["FOO", "BAR", "BAZ"]
        def test_default(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            self.assertEqual(asyncio.run(iterate(mock_instance)), [])

        def test_set_return_value(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = expected[:]
            self.assertEqual(asyncio.run(iterate(mock_instance)), expected)

        def test_set_return_value_iter(mock_type):
            mock_instance = mock_type(self.WithAsyncIterator())
            mock_instance.__aiter__.return_value = iter(expected[:])
            self.assertEqual(asyncio.run(iterate(mock_instance)), expected)

        for mock_type in [AsyncMock, MagicMock]:
            with self.subTest(f"default value with {mock_type}"):
                test_default(mock_type)

            with self.subTest(f"set return_value with {mock_type}"):
                test_set_return_value(mock_type)

            with self.subTest(f"set return_value iterator with {mock_type}"):
                test_set_return_value_iter(mock_type)


class AsyncMockAssert(unittest.TestCase):
    """The await-related assertion helpers of ``AsyncMock``."""

    def setUp(self):
        self.mock = AsyncMock()

    async def _runnable_test(self, *args, **kwargs):
        # Call and await self.mock once with the given arguments.
        await self.mock(*args, **kwargs)

    async def _await_coroutine(self, coroutine):
        # Await an already-created coroutine inside asyncio.run().
        return await coroutine

    def test_assert_called_but_not_awaited(self):
        mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            mock.async_method()
        self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.assert_awaited()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

    def test_assert_called_then_awaited(self):
        mock = AsyncMock(AsyncClass)
        mock_coroutine = mock.async_method()
        mock.async_method.assert_called()
        mock.async_method.assert_called_once()
        mock.async_method.assert_called_once_with()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()

        asyncio.run(self._await_coroutine(mock_coroutine))
        # Assert we haven't re-called the function
        mock.async_method.assert_called_once()
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()
        mock.async_method.assert_awaited_once_with()

    def test_assert_called_and_awaited_at_same_time(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        with self.assertRaises(AssertionError):
            self.mock.assert_called()

        asyncio.run(self._runnable_test())
        self.mock.assert_called_once()
        self.mock.assert_awaited_once()

    def test_assert_called_twice_and_awaited_once(self):
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        with self.assertWarns(RuntimeWarning):
            # The first call will be awaited so no warning there
            # But this call will never get awaited, so it will warn here
            mock.async_method()
        with self.assertRaises(AssertionError):
            mock.async_method.assert_awaited()
        mock.async_method.assert_called()
        asyncio.run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()
        mock.async_method.assert_awaited_once()

    def test_assert_called_once_and_awaited_twice(self):
        mock = AsyncMock(AsyncClass)
        coroutine = mock.async_method()
        mock.async_method.assert_called_once()
        asyncio.run(self._await_coroutine(coroutine))
        with self.assertRaises(RuntimeError):
            # Cannot reuse already awaited coroutine
            asyncio.run(self._await_coroutine(coroutine))
        mock.async_method.assert_awaited()

    def test_assert_awaited_but_not_called(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()
        with self.assertRaises(TypeError):
            # You cannot await an AsyncMock, it must be a coroutine
            asyncio.run(self._await_coroutine(self.mock))

        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()
        with self.assertRaises(AssertionError):
            self.mock.assert_called()

    def test_assert_has_calls_not_awaits(self):
        kalls = [call('foo')]
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock('foo')
        self.mock.assert_has_calls(kalls)
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(kalls)

    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock()
        kalls_empty = [('', (), {})]
        self.assertEqual(self.mock.mock_calls, kalls_empty)

        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            self.mock('foo')
            self.mock('baz')
        mock_kalls = [call(), call('foo'), call('baz')]
        self.assertEqual(self.mock.mock_calls, mock_kalls)

    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
        a_class_mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            a_class_mock.async_method()
        kalls_empty = [('', (), {})]
        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])

        with self.assertWarns(RuntimeWarning):
            # Will raise a warning because never awaited
            a_class_mock.async_method(1, 2, 3, a=4, b=5)
        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
        mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
        self.assertEqual(a_class_mock.mock_calls, mock_kalls)

    def test_async_method_calls_recorded(self):
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            self.mock.something(3, fish=None)
            self.mock.something_else.something(6, cake=sentinel.Cake)

        self.assertEqual(self.mock.method_calls, [
            ("something", (3,), {'fish': None}),
            ("something_else.something", (6,), {'cake': sentinel.Cake})
        ],
            "method calls not recorded correctly")
        self.assertEqual(self.mock.something_else.method_calls,
                         [("something", (6,), {'cake': sentinel.Cake})],
                         "method calls not recorded correctly")

    def test_async_arg_lists(self):
        def assert_attrs(mock):
            names = ('call_args_list', 'method_calls', 'mock_calls')
            for name in names:
                attr = getattr(mock, name)
                self.assertIsInstance(attr, _CallList)
                self.assertIsInstance(attr, list)
                self.assertEqual(attr, [])

        assert_attrs(self.mock)
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            self.mock()
            self.mock(1, 2)
            self.mock(a=3)

        self.mock.reset_mock()
        assert_attrs(self.mock)

        a_mock = AsyncMock(AsyncClass)
        with self.assertWarns(RuntimeWarning):
            # Will raise warnings because never awaited
            a_mock.async_method()
            a_mock.async_method(1, a=3)

        a_mock.reset_mock()
        assert_attrs(a_mock)

    def test_assert_awaited(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited()

        asyncio.run(self._runnable_test())
        self.mock.assert_awaited()

    def test_assert_awaited_once(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

        asyncio.run(self._runnable_test())
        self.mock.assert_awaited_once()

        asyncio.run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once()

    def test_assert_awaited_with(self):
        msg = 'Not awaited'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        asyncio.run(self._runnable_test())
        msg = 'expected await not found'
        with self.assertRaisesRegex(AssertionError, msg):
            self.mock.assert_awaited_with('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_awaited_with('foo')

        asyncio.run(self._runnable_test('SomethingElse'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_with('foo')

    def test_assert_awaited_once_with(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_awaited_once_with('foo')

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_awaited_once_with('foo')

    def test_assert_any_wait(self):
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        asyncio.run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_any_await('foo')

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_any_await('foo')

        asyncio.run(self._runnable_test('SomethingElse'))
        self.mock.assert_any_await('foo')

    def test_assert_has_awaits_no_order(self):
        # NOTE: despite the name, this exercises the default (ordered)
        # matching of assert_has_awaits: the expected awaits must appear
        # consecutively and in order.
        calls = [call('foo'), call('baz')]

        with self.assertRaises(AssertionError) as cm:
            self.mock.assert_has_awaits(calls)
        self.assertEqual(len(cm.exception.args), 1)

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('foo'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('baz'))
        self.mock.assert_has_awaits(calls)

        asyncio.run(self._runnable_test('SomethingElse'))
        self.mock.assert_has_awaits(calls)

    def test_assert_has_awaits_ordered(self):
        # NOTE: despite the name, this exercises any_order=True: the expected
        # awaits only need to be present, in any position.
        calls = [call('foo'), call('baz')]
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('baz'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('bamf'))
        with self.assertRaises(AssertionError):
            self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('foo'))
        self.mock.assert_has_awaits(calls, any_order=True)

        asyncio.run(self._runnable_test('qux'))
        self.mock.assert_has_awaits(calls, any_order=True)

    def test_assert_not_awaited(self):
        self.mock.assert_not_awaited()

        asyncio.run(self._runnable_test())
        with self.assertRaises(AssertionError):
            self.mock.assert_not_awaited()

    def test_assert_has_awaits_not_matching_spec_error(self):
        async def f(x=None): pass

        self.mock = AsyncMock(spec=f)
        asyncio.run(self._runnable_test(1))

        with self.assertRaisesRegex(
                AssertionError,
                '^{}$'.format(
                    re.escape('Awaits not found.\n'
                              'Expected: [call()]\n'
                              'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call()])
        self.assertIsNone(cm.exception.__cause__)

        # call(1, 2) does not bind to f's signature, so the TypeError raised
        # while matching is reported and chained as the assertion's cause.
        with self.assertRaisesRegex(
                AssertionError,
                '^{}$'.format(
                    re.escape(
                        'Error processing expected awaits.\n'
                        "Errors: [None, TypeError('too many positional "
                        "arguments')]\n"
                        'Expected: [call(), call(1, 2)]\n'
                        'Actual: [call(1)]'))) as cm:
            self.mock.assert_has_awaits([call(), call(1, 2)])
        self.assertIsInstance(cm.exception.__cause__, TypeError)