1import asyncio 2from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack 3import functools 4from test import support 5import unittest 6 7from test.test_contextlib import TestBaseExitStack 8 9 10def _async_test(func): 11 """Decorator to turn an async function into a test case.""" 12 @functools.wraps(func) 13 def wrapper(*args, **kwargs): 14 coro = func(*args, **kwargs) 15 loop = asyncio.new_event_loop() 16 asyncio.set_event_loop(loop) 17 try: 18 return loop.run_until_complete(coro) 19 finally: 20 loop.close() 21 asyncio.set_event_loop(None) 22 return wrapper 23 24 25class TestAbstractAsyncContextManager(unittest.TestCase): 26 27 @_async_test 28 async def test_enter(self): 29 class DefaultEnter(AbstractAsyncContextManager): 30 async def __aexit__(self, *args): 31 await super().__aexit__(*args) 32 33 manager = DefaultEnter() 34 self.assertIs(await manager.__aenter__(), manager) 35 36 async with manager as context: 37 self.assertIs(manager, context) 38 39 @_async_test 40 async def test_async_gen_propagates_generator_exit(self): 41 # A regression test for https://bugs.python.org/issue33786. 42 43 @asynccontextmanager 44 async def ctx(): 45 yield 46 47 async def gen(): 48 async with ctx(): 49 yield 11 50 51 ret = [] 52 exc = ValueError(22) 53 with self.assertRaises(ValueError): 54 async with ctx(): 55 async for val in gen(): 56 ret.append(val) 57 raise exc 58 59 self.assertEqual(ret, [11]) 60 61 def test_exit_is_abstract(self): 62 class MissingAexit(AbstractAsyncContextManager): 63 pass 64 65 with self.assertRaises(TypeError): 66 MissingAexit() 67 68 def test_structural_subclassing(self): 69 class ManagerFromScratch: 70 async def __aenter__(self): 71 return self 72 async def __aexit__(self, exc_type, exc_value, traceback): 73 return None 74 75 self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager)) 76 77 class DefaultEnter(AbstractAsyncContextManager): 78 async def __aexit__(self, *args): 79 await super().__aexit__(*args) 80 81 self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager)) 82 83 class NoneAenter(ManagerFromScratch): 84 __aenter__ = None 85 86 self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager)) 87 88 class NoneAexit(ManagerFromScratch): 89 __aexit__ = None 90 91 self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager)) 92 93 94class AsyncContextManagerTestCase(unittest.TestCase): 95 96 @_async_test 97 async def test_contextmanager_plain(self): 98 state = [] 99 @asynccontextmanager 100 async def woohoo(): 101 state.append(1) 102 yield 42 103 state.append(999) 104 async with woohoo() as x: 105 self.assertEqual(state, [1]) 106 self.assertEqual(x, 42) 107 state.append(x) 108 self.assertEqual(state, [1, 42, 999]) 109 110 @_async_test 111 async def test_contextmanager_finally(self): 112 state = [] 113 @asynccontextmanager 114 async def woohoo(): 115 state.append(1) 116 try: 117 yield 42 118 finally: 119 state.append(999) 120 with self.assertRaises(ZeroDivisionError): 121 async with woohoo() as x: 122 self.assertEqual(state, [1]) 123 self.assertEqual(x, 42) 124 state.append(x) 125 raise ZeroDivisionError() 126 self.assertEqual(state, [1, 42, 999]) 127 128 @_async_test 129 async def test_contextmanager_no_reraise(self): 130 @asynccontextmanager 131 async def whee(): 132 yield 133 ctx = whee() 134 await ctx.__aenter__() 135 # Calling __aexit__ should not result in an exception 136 self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None)) 137 138 @_async_test 139 async def test_contextmanager_trap_yield_after_throw(self): 140 @asynccontextmanager 141 async def whoo(): 142 try: 143 yield 144 except: 145 yield 146 ctx = whoo() 147 await ctx.__aenter__() 148 with self.assertRaises(RuntimeError): 149 await ctx.__aexit__(TypeError, TypeError('foo'), None) 150 151 @_async_test 152 async def test_contextmanager_trap_no_yield(self): 153 @asynccontextmanager 154 async def whoo(): 155 if False: 156 yield 157 ctx = whoo() 158 with self.assertRaises(RuntimeError): 159 await ctx.__aenter__() 160 161 @_async_test 162 async def test_contextmanager_trap_second_yield(self): 163 @asynccontextmanager 164 async def whoo(): 165 yield 166 yield 167 ctx = whoo() 168 await ctx.__aenter__() 169 with self.assertRaises(RuntimeError): 170 await ctx.__aexit__(None, None, None) 171 172 @_async_test 173 async def test_contextmanager_non_normalised(self): 174 @asynccontextmanager 175 async def whoo(): 176 try: 177 yield 178 except RuntimeError: 179 raise SyntaxError 180 181 ctx = whoo() 182 await ctx.__aenter__() 183 with self.assertRaises(SyntaxError): 184 await ctx.__aexit__(RuntimeError, None, None) 185 186 @_async_test 187 async def test_contextmanager_except(self): 188 state = [] 189 @asynccontextmanager 190 async def woohoo(): 191 state.append(1) 192 try: 193 yield 42 194 except ZeroDivisionError as e: 195 state.append(e.args[0]) 196 self.assertEqual(state, [1, 42, 999]) 197 async with woohoo() as x: 198 self.assertEqual(state, [1]) 199 self.assertEqual(x, 42) 200 state.append(x) 201 raise ZeroDivisionError(999) 202 self.assertEqual(state, [1, 42, 999]) 203 204 @_async_test 205 async def test_contextmanager_except_stopiter(self): 206 @asynccontextmanager 207 async def woohoo(): 208 yield 209 210 for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')): 211 with self.subTest(type=type(stop_exc)): 212 try: 213 async with woohoo(): 214 raise stop_exc 215 except Exception as ex: 216 self.assertIs(ex, stop_exc) 217 else: 218 self.fail(f'{stop_exc} was suppressed') 219 220 @_async_test 221 async def test_contextmanager_wrap_runtimeerror(self): 222 @asynccontextmanager 223 async def woohoo(): 224 try: 225 yield 226 except Exception as exc: 227 raise RuntimeError(f'caught {exc}') from exc 228 229 with self.assertRaises(RuntimeError): 230 async with woohoo(): 231 1 / 0 232 233 # If the context manager wrapped StopAsyncIteration in a RuntimeError, 234 # we also unwrap it, because we can't tell whether the wrapping was 235 # done by the generator machinery or by the generator itself. 236 with self.assertRaises(StopAsyncIteration): 237 async with woohoo(): 238 raise StopAsyncIteration 239 240 def _create_contextmanager_attribs(self): 241 def attribs(**kw): 242 def decorate(func): 243 for k,v in kw.items(): 244 setattr(func,k,v) 245 return func 246 return decorate 247 @asynccontextmanager 248 @attribs(foo='bar') 249 async def baz(spam): 250 """Whee!""" 251 yield 252 return baz 253 254 def test_contextmanager_attribs(self): 255 baz = self._create_contextmanager_attribs() 256 self.assertEqual(baz.__name__,'baz') 257 self.assertEqual(baz.foo, 'bar') 258 259 @support.requires_docstrings 260 def test_contextmanager_doc_attrib(self): 261 baz = self._create_contextmanager_attribs() 262 self.assertEqual(baz.__doc__, "Whee!") 263 264 @support.requires_docstrings 265 @_async_test 266 async def test_instance_docstring_given_cm_docstring(self): 267 baz = self._create_contextmanager_attribs()(None) 268 self.assertEqual(baz.__doc__, "Whee!") 269 async with baz: 270 pass # suppress warning 271 272 @_async_test 273 async def test_keywords(self): 274 # Ensure no keyword arguments are inhibited 275 @asynccontextmanager 276 async def woohoo(self, func, args, kwds): 277 yield (self, func, args, kwds) 278 async with woohoo(self=11, func=22, args=33, kwds=44) as target: 279 self.assertEqual(target, (11, 22, 33, 44)) 280 281 282class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase): 283 class SyncAsyncExitStack(AsyncExitStack): 284 @staticmethod 285 def run_coroutine(coro): 286 loop = asyncio.get_event_loop() 287 288 f = asyncio.ensure_future(coro) 289 f.add_done_callback(lambda f: loop.stop()) 290 loop.run_forever() 291 292 exc = f.exception() 293 294 if not exc: 295 return f.result() 296 else: 297 context = exc.__context__ 298 299 try: 300 raise exc 301 except: 302 exc.__context__ = context 303 raise exc 304 305 def close(self): 306 return self.run_coroutine(self.aclose()) 307 308 def __enter__(self): 309 return self.run_coroutine(self.__aenter__()) 310 311 def __exit__(self, *exc_details): 312 return self.run_coroutine(self.__aexit__(*exc_details)) 313 314 exit_stack = SyncAsyncExitStack 315 316 def setUp(self): 317 self.loop = asyncio.new_event_loop() 318 asyncio.set_event_loop(self.loop) 319 self.addCleanup(self.loop.close) 320 321 @_async_test 322 async def test_async_callback(self): 323 expected = [ 324 ((), {}), 325 ((1,), {}), 326 ((1,2), {}), 327 ((), dict(example=1)), 328 ((1,), dict(example=1)), 329 ((1,2), dict(example=1)), 330 ] 331 result = [] 332 async def _exit(*args, **kwds): 333 """Test metadata propagation""" 334 result.append((args, kwds)) 335 336 async with AsyncExitStack() as stack: 337 for args, kwds in reversed(expected): 338 if args and kwds: 339 f = stack.push_async_callback(_exit, *args, **kwds) 340 elif args: 341 f = stack.push_async_callback(_exit, *args) 342 elif kwds: 343 f = stack.push_async_callback(_exit, **kwds) 344 else: 345 f = stack.push_async_callback(_exit) 346 self.assertIs(f, _exit) 347 for wrapper in stack._exit_callbacks: 348 self.assertIs(wrapper[1].__wrapped__, _exit) 349 self.assertNotEqual(wrapper[1].__name__, _exit.__name__) 350 self.assertIsNone(wrapper[1].__doc__, _exit.__doc__) 351 352 self.assertEqual(result, expected) 353 354 result = [] 355 async with AsyncExitStack() as stack: 356 with self.assertRaises(TypeError): 357 stack.push_async_callback(arg=1) 358 with self.assertRaises(TypeError): 359 self.exit_stack.push_async_callback(arg=2) 360 stack.push_async_callback(callback=_exit, arg=3) 361 self.assertEqual(result, [((), {'arg': 3})]) 362 363 @_async_test 364 async def test_async_push(self): 365 exc_raised = ZeroDivisionError 366 async def _expect_exc(exc_type, exc, exc_tb): 367 self.assertIs(exc_type, exc_raised) 368 async def _suppress_exc(*exc_details): 369 return True 370 async def _expect_ok(exc_type, exc, exc_tb): 371 self.assertIsNone(exc_type) 372 self.assertIsNone(exc) 373 self.assertIsNone(exc_tb) 374 class ExitCM(object): 375 def __init__(self, check_exc): 376 self.check_exc = check_exc 377 async def __aenter__(self): 378 self.fail("Should not be called!") 379 async def __aexit__(self, *exc_details): 380 await self.check_exc(*exc_details) 381 382 async with self.exit_stack() as stack: 383 stack.push_async_exit(_expect_ok) 384 self.assertIs(stack._exit_callbacks[-1][1], _expect_ok) 385 cm = ExitCM(_expect_ok) 386 stack.push_async_exit(cm) 387 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) 388 stack.push_async_exit(_suppress_exc) 389 self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc) 390 cm = ExitCM(_expect_exc) 391 stack.push_async_exit(cm) 392 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) 393 stack.push_async_exit(_expect_exc) 394 self.assertIs(stack._exit_callbacks[-1][1], _expect_exc) 395 stack.push_async_exit(_expect_exc) 396 self.assertIs(stack._exit_callbacks[-1][1], _expect_exc) 397 1/0 398 399 @_async_test 400 async def test_async_enter_context(self): 401 class TestCM(object): 402 async def __aenter__(self): 403 result.append(1) 404 async def __aexit__(self, *exc_details): 405 result.append(3) 406 407 result = [] 408 cm = TestCM() 409 410 async with AsyncExitStack() as stack: 411 @stack.push_async_callback # Registered first => cleaned up last 412 async def _exit(): 413 result.append(4) 414 self.assertIsNotNone(_exit) 415 await stack.enter_async_context(cm) 416 self.assertIs(stack._exit_callbacks[-1][1].__self__, cm) 417 result.append(2) 418 419 self.assertEqual(result, [1, 2, 3, 4]) 420 421 @_async_test 422 async def test_async_exit_exception_chaining(self): 423 # Ensure exception chaining matches the reference behaviour 424 async def raise_exc(exc): 425 raise exc 426 427 saved_details = None 428 async def suppress_exc(*exc_details): 429 nonlocal saved_details 430 saved_details = exc_details 431 return True 432 433 try: 434 async with self.exit_stack() as stack: 435 stack.push_async_callback(raise_exc, IndexError) 436 stack.push_async_callback(raise_exc, KeyError) 437 stack.push_async_callback(raise_exc, AttributeError) 438 stack.push_async_exit(suppress_exc) 439 stack.push_async_callback(raise_exc, ValueError) 440 1 / 0 441 except IndexError as exc: 442 self.assertIsInstance(exc.__context__, KeyError) 443 self.assertIsInstance(exc.__context__.__context__, AttributeError) 444 # Inner exceptions were suppressed 445 self.assertIsNone(exc.__context__.__context__.__context__) 446 else: 447 self.fail("Expected IndexError, but no exception was raised") 448 # Check the inner exceptions 449 inner_exc = saved_details[1] 450 self.assertIsInstance(inner_exc, ValueError) 451 self.assertIsInstance(inner_exc.__context__, ZeroDivisionError) 452 453 454if __name__ == '__main__': 455 unittest.main() 456