1import asyncio 2from contextlib import ( 3 asynccontextmanager, AbstractAsyncContextManager, 4 AsyncExitStack, nullcontext, aclosing, contextmanager) 5import functools 6from test import support 7import unittest 8 9from test.test_contextlib import TestBaseExitStack 10 11 12def _async_test(func): 13 """Decorator to turn an async function into a test case.""" 14 @functools.wraps(func) 15 def wrapper(*args, **kwargs): 16 coro = func(*args, **kwargs) 17 loop = asyncio.new_event_loop() 18 asyncio.set_event_loop(loop) 19 try: 20 return loop.run_until_complete(coro) 21 finally: 22 loop.close() 23 asyncio.set_event_loop_policy(None) 24 return wrapper 25 26 27class TestAbstractAsyncContextManager(unittest.TestCase): 28 29 @_async_test 30 async def test_enter(self): 31 class DefaultEnter(AbstractAsyncContextManager): 32 async def __aexit__(self, *args): 33 await super().__aexit__(*args) 34 35 manager = DefaultEnter() 36 self.assertIs(await manager.__aenter__(), manager) 37 38 async with manager as context: 39 self.assertIs(manager, context) 40 41 @_async_test 42 async def test_async_gen_propagates_generator_exit(self): 43 # A regression test for https://bugs.python.org/issue33786. 44 45 @asynccontextmanager 46 async def ctx(): 47 yield 48 49 async def gen(): 50 async with ctx(): 51 yield 11 52 53 ret = [] 54 exc = ValueError(22) 55 with self.assertRaises(ValueError): 56 async with ctx(): 57 async for val in gen(): 58 ret.append(val) 59 raise exc 60 61 self.assertEqual(ret, [11]) 62 63 def test_exit_is_abstract(self): 64 class MissingAexit(AbstractAsyncContextManager): 65 pass 66 67 with self.assertRaises(TypeError): 68 MissingAexit() 69 70 def test_structural_subclassing(self): 71 class ManagerFromScratch: 72 async def __aenter__(self): 73 return self 74 async def __aexit__(self, exc_type, exc_value, traceback): 75 return None 76 77 self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager)) 78 79 class DefaultEnter(AbstractAsyncContextManager): 80 async def __aexit__(self, *args): 81 await super().__aexit__(*args) 82 83 self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager)) 84 85 class NoneAenter(ManagerFromScratch): 86 __aenter__ = None 87 88 self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager)) 89 90 class NoneAexit(ManagerFromScratch): 91 __aexit__ = None 92 93 self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager)) 94 95 96class AsyncContextManagerTestCase(unittest.TestCase): 97 98 @_async_test 99 async def test_contextmanager_plain(self): 100 state = [] 101 @asynccontextmanager 102 async def woohoo(): 103 state.append(1) 104 yield 42 105 state.append(999) 106 async with woohoo() as x: 107 self.assertEqual(state, [1]) 108 self.assertEqual(x, 42) 109 state.append(x) 110 self.assertEqual(state, [1, 42, 999]) 111 112 @_async_test 113 async def test_contextmanager_finally(self): 114 state = [] 115 @asynccontextmanager 116 async def woohoo(): 117 state.append(1) 118 try: 119 yield 42 120 finally: 121 state.append(999) 122 with self.assertRaises(ZeroDivisionError): 123 async with woohoo() as x: 124 self.assertEqual(state, [1]) 125 self.assertEqual(x, 42) 126 state.append(x) 127 raise ZeroDivisionError() 128 self.assertEqual(state, [1, 42, 999]) 129 130 @_async_test 131 async def test_contextmanager_no_reraise(self): 132 @asynccontextmanager 133 async def whee(): 134 yield 135 ctx = whee() 136 await ctx.__aenter__() 137 # Calling __aexit__ should not result in an exception 138 self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None)) 139 140 @_async_test 141 async def test_contextmanager_trap_yield_after_throw(self): 142 @asynccontextmanager 143 async def whoo(): 144 try: 145 yield 146 except: 147 yield 148 ctx = whoo() 149 await ctx.__aenter__() 150 with self.assertRaises(RuntimeError): 151 await ctx.__aexit__(TypeError, TypeError('foo'), None) 152 153 @_async_test 154 async def test_contextmanager_trap_no_yield(self): 155 @asynccontextmanager 156 async def whoo(): 157 if False: 158 yield 159 ctx = whoo() 160 with self.assertRaises(RuntimeError): 161 await ctx.__aenter__() 162 163 @_async_test 164 async def test_contextmanager_trap_second_yield(self): 165 @asynccontextmanager 166 async def whoo(): 167 yield 168 yield 169 ctx = whoo() 170 await ctx.__aenter__() 171 with self.assertRaises(RuntimeError): 172 await ctx.__aexit__(None, None, None) 173 174 @_async_test 175 async def test_contextmanager_non_normalised(self): 176 @asynccontextmanager 177 async def whoo(): 178 try: 179 yield 180 except RuntimeError: 181 raise SyntaxError 182 183 ctx = whoo() 184 await ctx.__aenter__() 185 with self.assertRaises(SyntaxError): 186 await ctx.__aexit__(RuntimeError, None, None) 187 188 @_async_test 189 async def test_contextmanager_except(self): 190 state = [] 191 @asynccontextmanager 192 async def woohoo(): 193 state.append(1) 194 try: 195 yield 42 196 except ZeroDivisionError as e: 197 state.append(e.args[0]) 198 self.assertEqual(state, [1, 42, 999]) 199 async with woohoo() as x: 200 self.assertEqual(state, [1]) 201 self.assertEqual(x, 42) 202 state.append(x) 203 raise ZeroDivisionError(999) 204 self.assertEqual(state, [1, 42, 999]) 205 206 @_async_test 207 async def test_contextmanager_except_stopiter(self): 208 @asynccontextmanager 209 async def woohoo(): 210 yield 211 212 class StopIterationSubclass(StopIteration): 213 pass 214 215 class StopAsyncIterationSubclass(StopAsyncIteration): 216 pass 217 218 for stop_exc in ( 219 StopIteration('spam'), 220 StopAsyncIteration('ham'), 221 StopIterationSubclass('spam'), 222 StopAsyncIterationSubclass('spam') 223 ): 224 with self.subTest(type=type(stop_exc)): 225 try: 226 async with woohoo(): 227 raise stop_exc 228 except Exception as ex: 229 self.assertIs(ex, stop_exc) 230 else: 231 self.fail(f'{stop_exc} was suppressed') 232 233 @_async_test 234 async def test_contextmanager_wrap_runtimeerror(self): 235 @asynccontextmanager 236 async def woohoo(): 237 try: 238 yield 239 except Exception as exc: 240 raise RuntimeError(f'caught {exc}') from exc 241 242 with self.assertRaises(RuntimeError): 243 async with woohoo(): 244 1 / 0 245 246 # If the context manager wrapped StopAsyncIteration in a RuntimeError, 247 # we also unwrap it, because we can't tell whether the wrapping was 248 # done by the generator machinery or by the generator itself. 249 with self.assertRaises(StopAsyncIteration): 250 async with woohoo(): 251 raise StopAsyncIteration 252 253 def _create_contextmanager_attribs(self): 254 def attribs(**kw): 255 def decorate(func): 256 for k,v in kw.items(): 257 setattr(func,k,v) 258 return func 259 return decorate 260 @asynccontextmanager 261 @attribs(foo='bar') 262 async def baz(spam): 263 """Whee!""" 264 yield 265 return baz 266 267 def test_contextmanager_attribs(self): 268 baz = self._create_contextmanager_attribs() 269 self.assertEqual(baz.__name__,'baz') 270 self.assertEqual(baz.foo, 'bar') 271 272 @support.requires_docstrings 273 def test_contextmanager_doc_attrib(self): 274 baz = self._create_contextmanager_attribs() 275 self.assertEqual(baz.__doc__, "Whee!") 276 277 @support.requires_docstrings 278 @_async_test 279 async def test_instance_docstring_given_cm_docstring(self): 280 baz = self._create_contextmanager_attribs()(None) 281 self.assertEqual(baz.__doc__, "Whee!") 282 async with baz: 283 pass # suppress warning 284 285 @_async_test 286 async def test_keywords(self): 287 # Ensure no keyword arguments are inhibited 288 @asynccontextmanager 289 async def woohoo(self, func, args, kwds): 290 yield (self, func, args, kwds) 291 async with woohoo(self=11, func=22, args=33, kwds=44) as target: 292 self.assertEqual(target, (11, 22, 33, 44)) 293 294 @_async_test 295 async def test_recursive(self): 296 depth = 0 297 ncols = 0 298 299 @asynccontextmanager 300 async def woohoo(): 301 nonlocal ncols 302 ncols += 1 303 304 nonlocal depth 305 before = depth 306 depth += 1 307 yield 308 depth -= 1 309 self.assertEqual(depth, before) 310 311 @woohoo() 312 async def recursive(): 313 if depth < 10: 314 await recursive() 315 316 await recursive() 317 318 self.assertEqual(ncols, 10) 319 self.assertEqual(depth, 0) 320 321 @_async_test 322 async def test_decorator(self): 323 entered = False 324 325 @asynccontextmanager 326 async def context(): 327 nonlocal entered 328 entered = True 329 yield 330 entered = False 331 332 @context() 333 async def test(): 334 self.assertTrue(entered) 335 336 self.assertFalse(entered) 337 await test() 338 self.assertFalse(entered) 339 340 @_async_test 341 async def test_decorator_with_exception(self): 342 entered = False 343 344 @asynccontextmanager 345 async def context(): 346 nonlocal entered 347 try: 348 entered = True 349 yield 350 finally: 351 entered = False 352 353 @context() 354 async def test(): 355 self.assertTrue(entered) 356 raise NameError('foo') 357 358 self.assertFalse(entered) 359 with self.assertRaisesRegex(NameError, 'foo'): 360 await test() 361 self.assertFalse(entered) 362 363 @_async_test 364 async def test_decorating_method(self): 365 366 @asynccontextmanager 367 async def context(): 368 yield 369 370 371 class Test(object): 372 373 @context() 374 async def method(self, a, b, c=None): 375 self.a = a 376 self.b = b 377 self.c = c 378 379 # these tests are for argument passing when used as a decorator 380 test = Test() 381 await test.method(1, 2) 382 self.assertEqual(test.a, 1) 383 self.assertEqual(test.b, 2) 384 self.assertEqual(test.c, None) 385 386 test = Test() 387 await test.method('a', 'b', 'c') 388 self.assertEqual(test.a, 'a') 389 self.assertEqual(test.b, 'b') 390 self.assertEqual(test.c, 'c') 391 392 test = Test() 393 await test.method(a=1, b=2) 394 self.assertEqual(test.a, 1) 395 self.assertEqual(test.b, 2) 396 397 398class AclosingTestCase(unittest.TestCase): 399 400 @support.requires_docstrings 401 def test_instance_docs(self): 402 cm_docstring = aclosing.__doc__ 403 obj = aclosing(None) 404 self.assertEqual(obj.__doc__, cm_docstring) 405 406 @_async_test 407 async def test_aclosing(self): 408 state = [] 409 class C: 410 async def aclose(self): 411 state.append(1) 412 x = C() 413 self.assertEqual(state, []) 414 async with aclosing(x) as y: 415 self.assertEqual(x, y) 416 self.assertEqual(state, [1]) 417 418 @_async_test 419 async def test_aclosing_error(self): 420 state = [] 421 class C: 422 async def aclose(self): 423 state.append(1) 424 x = C() 425 self.assertEqual(state, []) 426 with self.assertRaises(ZeroDivisionError): 427 async with aclosing(x) as y: 428 self.assertEqual(x, y) 429 1 / 0 430 self.assertEqual(state, [1]) 431 432 @_async_test 433 async def test_aclosing_bpo41229(self): 434 state = [] 435 436 @contextmanager 437 def sync_resource(): 438 try: 439 yield 440 finally: 441 state.append(1) 442 443 async def agenfunc(): 444 with sync_resource(): 445 yield -1 446 yield -2 447 448 x = agenfunc() 449 self.assertEqual(state, []) 450 with self.assertRaises(ZeroDivisionError): 451 async with aclosing(x) as y: 452 self.assertEqual(x, y) 453 self.assertEqual(-1, await x.__anext__()) 454 1 / 0 455 self.assertEqual(state, [1]) 456 457 458class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase): 459 class SyncAsyncExitStack(AsyncExitStack): 460 @staticmethod 461 def run_coroutine(coro): 462 loop = asyncio.get_event_loop_policy().get_event_loop() 463 t = loop.create_task(coro) 464 t.add_done_callback(lambda f: loop.stop()) 465 loop.run_forever() 466 467 exc = t.exception() 468 if not exc: 469 return t.result() 470 else: 471 context = exc.__context__ 472 473 try: 474 raise exc 475 except: 476 exc.__context__ = context 477 raise exc 478 479 def close(self): 480 return self.run_coroutine(self.aclose()) 481 482 def __enter__(self): 483 return self.run_coroutine(self.__aenter__()) 484 485 def __exit__(self, *exc_details): 486 return self.run_coroutine(self.__aexit__(*exc_details)) 487 488 exit_stack = SyncAsyncExitStack 489 490 def setUp(self): 491 self.loop = asyncio.new_event_loop() 492 asyncio.set_event_loop(self.loop) 493 self.addCleanup(self.loop.close) 494 self.addCleanup(asyncio.set_event_loop_policy, None) 495 496 @_async_test 497 async def test_async_callback(self): 498 expected = [ 499 ((), {}), 500 ((1,), {}), 501 ((1,2), {}), 502 ((), dict(example=1)), 503 ((1,), dict(example=1)), 504 ((1,2), dict(example=1)), 505 ] 506 result = [] 507 async def _exit(*args, **kwds): 508 """Test metadata propagation""" 509 result.append((args, kwds)) 510 511 async with AsyncExitStack() as stack: 512 for args, kwds in reversed(expected): 513 if args and kwds: 514 f = stack.push_async_callback(_exit, *args, **kwds) 515 elif args: 516 f = stack.push_async_callback(_exit, *args) 517 elif kwds: 518 f = stack.push_async_callback(_exit, **kwds) 519 else: 520 f = stack.push_async_callback(_exit) 521 self.assertIs(f, _exit) 522 for wrapper in stack._exit_callbacks: 523 self.assertIs(wrapper[1].__wrapped__, _exit) 524 self.assertNotEqual(wrapper[1].__name__, _exit.__name__) 525 self.assertIsNone(wrapper[1].__doc__, _exit.__doc__) 526 527 self.assertEqual(result, expected) 528 529 result = [] 530 async with AsyncExitStack() as stack: 531 with self.assertRaises(TypeError): 532 stack.push_async_callback(arg=1) 533 with self.assertRaises(TypeError): 534 self.exit_stack.push_async_callback(arg=2) 535 with self.assertRaises(TypeError): 536 stack.push_async_callback(callback=_exit, arg=3) 537 self.assertEqual(result, []) 538 539 @_async_test 540 async def test_async_push(self): 541 exc_raised = ZeroDivisionError 542 async def _expect_exc(exc_type, exc, exc_tb): 543 self.assertIs(exc_type, exc_raised) 544 async def _suppress_exc(*exc_details): 545 return True 546 async def _expect_ok(exc_type, exc, exc_tb): 547 self.assertIsNone(exc_type) 548 self.assertIsNone(exc) 549 self.assertIsNone(exc_tb) 550 class ExitCM(object): 551 def __init__(self, check_exc): 552 self.check_exc = check_exc 553 async def __aenter__(self): 554 self.fail("Should not be called!") 555 async def __aexit__(self, *exc_details): 556 await self.check_exc(*exc_details) 557 558 async with self.exit_stack() as stack: 559 stack.push_async_exit(_expect_ok) 560 self.assertIs(stack._exit_callbacks[-1][1], _expect_ok) 561 cm = ExitCM(_expect_ok) 562 stack.push_async_exit(cm) 563 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) 564 stack.push_async_exit(_suppress_exc) 565 self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc) 566 cm = ExitCM(_expect_exc) 567 stack.push_async_exit(cm) 568 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) 569 stack.push_async_exit(_expect_exc) 570 self.assertIs(stack._exit_callbacks[-1][1], _expect_exc) 571 stack.push_async_exit(_expect_exc) 572 self.assertIs(stack._exit_callbacks[-1][1], _expect_exc) 573 1/0 574 575 @_async_test 576 async def test_enter_async_context(self): 577 class TestCM(object): 578 async def __aenter__(self): 579 result.append(1) 580 async def __aexit__(self, *exc_details): 581 result.append(3) 582 583 result = [] 584 cm = TestCM() 585 586 async with AsyncExitStack() as stack: 587 @stack.push_async_callback # Registered first => cleaned up last 588 async def _exit(): 589 result.append(4) 590 self.assertIsNotNone(_exit) 591 await stack.enter_async_context(cm) 592 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) 593 result.append(2) 594 595 self.assertEqual(result, [1, 2, 3, 4]) 596 597 @_async_test 598 async def test_enter_async_context_errors(self): 599 class LacksEnterAndExit: 600 pass 601 class LacksEnter: 602 async def __aexit__(self, *exc_info): 603 pass 604 class LacksExit: 605 async def __aenter__(self): 606 pass 607 608 async with self.exit_stack() as stack: 609 with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): 610 await stack.enter_async_context(LacksEnterAndExit()) 611 with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): 612 await stack.enter_async_context(LacksEnter()) 613 with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): 614 await stack.enter_async_context(LacksExit()) 615 self.assertFalse(stack._exit_callbacks) 616 617 @_async_test 618 async def test_async_exit_exception_chaining(self): 619 # Ensure exception chaining matches the reference behaviour 620 async def raise_exc(exc): 621 raise exc 622 623 saved_details = None 624 async def suppress_exc(*exc_details): 625 nonlocal saved_details 626 saved_details = exc_details 627 return True 628 629 try: 630 async with self.exit_stack() as stack: 631 stack.push_async_callback(raise_exc, IndexError) 632 stack.push_async_callback(raise_exc, KeyError) 633 stack.push_async_callback(raise_exc, AttributeError) 634 stack.push_async_exit(suppress_exc) 635 stack.push_async_callback(raise_exc, ValueError) 636 1 / 0 637 except IndexError as exc: 638 self.assertIsInstance(exc.__context__, KeyError) 639 self.assertIsInstance(exc.__context__.__context__, AttributeError) 640 # Inner exceptions were suppressed 641 self.assertIsNone(exc.__context__.__context__.__context__) 642 else: 643 self.fail("Expected IndexError, but no exception was raised") 644 # Check the inner exceptions 645 inner_exc = saved_details[1] 646 self.assertIsInstance(inner_exc, ValueError) 647 self.assertIsInstance(inner_exc.__context__, ZeroDivisionError) 648 649 @_async_test 650 async def test_async_exit_exception_explicit_none_context(self): 651 # Ensure AsyncExitStack chaining matches actual nested `with` statements 652 # regarding explicit __context__ = None. 653 654 class MyException(Exception): 655 pass 656 657 @asynccontextmanager 658 async def my_cm(): 659 try: 660 yield 661 except BaseException: 662 exc = MyException() 663 try: 664 raise exc 665 finally: 666 exc.__context__ = None 667 668 @asynccontextmanager 669 async def my_cm_with_exit_stack(): 670 async with self.exit_stack() as stack: 671 await stack.enter_async_context(my_cm()) 672 yield stack 673 674 for cm in (my_cm, my_cm_with_exit_stack): 675 with self.subTest(): 676 try: 677 async with cm(): 678 raise IndexError() 679 except MyException as exc: 680 self.assertIsNone(exc.__context__) 681 else: 682 self.fail("Expected IndexError, but no exception was raised") 683 684 @_async_test 685 async def test_instance_bypass_async(self): 686 class Example(object): pass 687 cm = Example() 688 cm.__aenter__ = object() 689 cm.__aexit__ = object() 690 stack = self.exit_stack() 691 with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): 692 await stack.enter_async_context(cm) 693 stack.push_async_exit(cm) 694 self.assertIs(stack._exit_callbacks[-1][1], cm) 695 696 697class TestAsyncNullcontext(unittest.TestCase): 698 @_async_test 699 async def test_async_nullcontext(self): 700 class C: 701 pass 702 c = C() 703 async with nullcontext(c) as c_in: 704 self.assertIs(c_in, c) 705 706 707if __name__ == '__main__': 708 unittest.main() 709