"""pytest-asyncio implementation."""
import asyncio
import contextlib
import functools
import inspect
import socket

import pytest
try:
    from _pytest.python import transfer_markers
except ImportError:  # Pytest 4.1.0 removes the transfer_marker api (#104)
    def transfer_markers(*args, **kwargs):  # noqa
        """Noop when over pytest 4.1.0"""
        pass

try:
    from async_generator import isasyncgenfunction
except ImportError:
    from inspect import isasyncgenfunction


def _is_coroutine(obj):
    """Check to see if an object is really an asyncio coroutine.

    Plain generator functions are accepted as well, presumably to support
    legacy generator-based coroutines.
    """
    return asyncio.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)


def pytest_configure(config):
    """Inject documentation."""
    config.addinivalue_line("markers",
                            "asyncio: "
                            "mark the test as a coroutine, it will be "
                            "run using an asyncio event loop")


@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makeitem(collector, name, obj):
    """A pytest hook to collect asyncio coroutines.

    Only coroutine/generator functions that match the test-function name
    filter and carry the ``asyncio`` keyword are collected here; for
    anything else the hook returns ``None`` so pytest's default collection
    proceeds.
    """
    if collector.funcnamefilter(name) and _is_coroutine(obj):
        item = pytest.Function.from_parent(collector, name=name)

        # Due to how pytest test collection works, module-level pytestmarks
        # are applied after the collection step. Since this is the collection
        # step, we look ourselves.
        transfer_markers(obj, item.cls, item.module)
        item = pytest.Function.from_parent(collector, name=name)  # To reload keywords.

        if 'asyncio' in item.keywords:
            return list(collector._genfunctions(name, obj))


class FixtureStripper:
    """Temporarily add extra arguments to a fixture, then strip them again.

    Names added through :meth:`add` are appended to ``fixturedef.argnames``
    so pytest resolves them, and are removed from the keyword arguments
    before the wrapped fixture function is called.
    """
    REQUEST = "request"
    EVENT_LOOP = "event_loop"

    def __init__(self, fixturedef):
        self.fixturedef = fixturedef
        # Names we injected ourselves (and therefore must strip again).
        self.to_strip = set()

    def add(self, name):
        """Add ``name`` to the fixture's argnames.

        The name is recorded for stripping only if the fixture did not
        already request it itself.
        """
        if name in self.fixturedef.argnames:
            return
        self.fixturedef.argnames += (name, )
        self.to_strip.add(name)

    def get_and_strip_from(self, name, data_dict):
        """Return ``data_dict[name]``, deleting the key if we injected it."""
        result = data_dict[name]
        if name in self.to_strip:
            del data_dict[name]
        return result


@pytest.hookimpl(trylast=True)
def pytest_fixture_post_finalizer(fixturedef, request):
    """Called after fixture teardown"""
    if fixturedef.argname == "event_loop":
        # Set empty loop policy, so that subsequent get_event_loop() provides a new loop
        asyncio.set_event_loop_policy(None)


@pytest.hookimpl(hookwrapper=True)
def pytest_fixture_setup(fixturedef, request):
    """Adjust the event loop policy when an event loop is produced.

    Async generator and coroutine fixtures are replaced with synchronous
    wrappers that drive them on the ``event_loop`` fixture's loop.
    """
    if fixturedef.argname == "event_loop":
        outcome = yield
        loop = outcome.get_result()
        policy = asyncio.get_event_loop_policy()
        policy.set_event_loop(loop)
        return

    if isasyncgenfunction(fixturedef.func):
        # This is an async generator function. Wrap it accordingly.
        generator = fixturedef.func

        fixture_stripper = FixtureStripper(fixturedef)
        fixture_stripper.add(FixtureStripper.EVENT_LOOP)
        fixture_stripper.add(FixtureStripper.REQUEST)

        def wrapper(*args, **kwargs):
            loop = fixture_stripper.get_and_strip_from(FixtureStripper.EVENT_LOOP, kwargs)
            request = fixture_stripper.get_and_strip_from(FixtureStripper.REQUEST, kwargs)

            gen_obj = generator(*args, **kwargs)

            async def setup():
                res = await gen_obj.__anext__()
                return res

            def finalizer():
                """Yield again, to finalize."""
                async def async_finalizer():
                    try:
                        await gen_obj.__anext__()
                    except StopAsyncIteration:
                        pass
                    else:
                        # A second yield means the fixture is malformed.
                        msg = "Async generator fixture didn't stop. "
                        msg += "Yield only once."
                        raise ValueError(msg)
                loop.run_until_complete(async_finalizer())

            request.addfinalizer(finalizer)
            return loop.run_until_complete(setup())

        fixturedef.func = wrapper
    elif inspect.iscoroutinefunction(fixturedef.func):
        coro = fixturedef.func

        fixture_stripper = FixtureStripper(fixturedef)
        fixture_stripper.add(FixtureStripper.EVENT_LOOP)

        def wrapper(*args, **kwargs):
            loop = fixture_stripper.get_and_strip_from(FixtureStripper.EVENT_LOOP, kwargs)

            async def setup():
                res = await coro(*args, **kwargs)
                return res

            return loop.run_until_complete(setup())

        fixturedef.func = wrapper
    yield


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_pyfunc_call(pyfuncitem):
    """
    Run asyncio marked test functions in an event loop instead of a normal
    function call.
    """
    if 'asyncio' in pyfuncitem.keywords:
        if getattr(pyfuncitem.obj, 'is_hypothesis_test', False):
            pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync(
                pyfuncitem.obj.hypothesis.inner_test,
                _loop=pyfuncitem.funcargs['event_loop']
            )
        else:
            pyfuncitem.obj = wrap_in_sync(
                pyfuncitem.obj,
                _loop=pyfuncitem.funcargs['event_loop']
            )
    yield


def wrap_in_sync(func, _loop):
    """Return a sync wrapper around an async function executing it in the
    current event loop."""

    @functools.wraps(func)
    def inner(**kwargs):
        coro = func(**kwargs)
        if coro is not None:
            task = asyncio.ensure_future(coro, loop=_loop)
            try:
                _loop.run_until_complete(task)
            except BaseException:
                # run_until_complete doesn't get the result from exceptions
                # that are not subclasses of `Exception`. Consume all
                # exceptions to prevent asyncio's warning from logging.
                if task.done() and not task.cancelled():
                    task.exception()
                raise
    return inner


def pytest_runtest_setup(item):
    """Make ``event_loop`` the first fixture of every asyncio test.

    Also fail early for Hypothesis tests built with a Hypothesis version
    that does not expose the ``hypothesis`` attribute this plugin relies on.
    """
    if 'asyncio' in item.keywords:
        # inject an event loop fixture for all async tests; it goes first so
        # the loop exists before any async fixture that depends on it, and
        # pytest_pyfunc_call can always read funcargs['event_loop'].
        if 'event_loop' in item.fixturenames:
            item.fixturenames.remove('event_loop')
        item.fixturenames.insert(0, 'event_loop')
    if item.get_closest_marker("asyncio") is not None \
            and not getattr(item.obj, 'hypothesis', False) \
            and getattr(item.obj, 'is_hypothesis_test', False):
        pytest.fail(
            'test function `%r` is using Hypothesis, but pytest-asyncio '
            'only works with Hypothesis 3.64.0 or later.' % item
        )


@pytest.fixture
def event_loop(request):
    """Create an instance of the default event loop for each test case."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


def _unused_tcp_port():
    """Ask the OS for a currently unused localhost TCP port and return it.

    The socket is closed before returning, so another process could grab
    the port before the caller binds it.
    """
    with contextlib.closing(socket.socket()) as sock:
        # Port 0 lets the OS pick a free (ephemeral) port.
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]


@pytest.fixture
def unused_tcp_port():
    """Return a single unused localhost TCP port."""
    return _unused_tcp_port()


@pytest.fixture
def unused_tcp_port_factory():
    """A factory function, producing different unused TCP ports."""
    produced = set()

    def factory():
        """Return an unused port not previously returned by this factory."""
        port = _unused_tcp_port()

        while port in produced:
            port = _unused_tcp_port()

        produced.add(port)

        return port
    return factory