1"""pytest-asyncio implementation."""
2import asyncio
3import contextlib
4import functools
5import inspect
6import socket
7
8import pytest
try:
    from _pytest.python import transfer_markers
except ImportError:  # pytest 4.1.0 removed the transfer_markers API (#104)
    def transfer_markers(*args, **kwargs):  # noqa
        """No-op fallback used when pytest (>= 4.1.0) no longer provides
        ``transfer_markers``; accepts and ignores any arguments."""
        pass
15
16try:
17    from async_generator import isasyncgenfunction
18except ImportError:
19    from inspect import isasyncgenfunction
20
21
22def _is_coroutine(obj):
23    """Check to see if an object is really an asyncio coroutine."""
24    return asyncio.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)
25
26
def pytest_configure(config):
    """Register the ``asyncio`` marker in pytest's ``markers`` ini value."""
    marker_description = ("asyncio: mark the test as a coroutine, it will be "
                          "run using an asyncio event loop")
    config.addinivalue_line("markers", marker_description)
33
34
@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makeitem(collector, name, obj):
    """A pytest hook to collect asyncio coroutines.

    For a collectable coroutine (see ``_is_coroutine``) carrying the
    ``asyncio`` keyword, return the items produced by the collector's own
    ``_genfunctions``. Otherwise return ``None``, so other hook
    implementations (and pytest's default collection) handle *obj*.

    Registered with ``pytest.hookimpl(tryfirst=True)``; configuring hooks
    through ``pytest.mark.tryfirst`` is deprecated since pytest 7.2.
    """
    if not (collector.funcnamefilter(name) and _is_coroutine(obj)):
        return None

    item = pytest.Function.from_parent(collector, name=name)

    # Due to how pytest test collection works, module-level pytestmarks
    # are applied after the collection step. Since this is the collection
    # step, we look ourselves (transfer_markers is a no-op on newer pytest).
    transfer_markers(obj, item.cls, item.module)
    # Rebuild the item so its keywords reflect any transferred markers.
    item = pytest.Function.from_parent(collector, name=name)

    if 'asyncio' in item.keywords:
        return list(collector._genfunctions(name, obj))
    return None
49
50
class FixtureStripper:
    """Add extra fixture names to a fixture definition, then strip them.

    Names injected through ``add`` are appended to ``fixturedef.argnames``
    and remembered in ``to_strip``. ``get_and_strip_from`` can then remove
    them from a keyword-argument dict before the original fixture function
    (which never declared them) is called.
    """
    REQUEST = "request"
    EVENT_LOOP = "event_loop"

    def __init__(self, fixturedef):
        self.fixturedef = fixturedef
        # Names injected by this helper, as opposed to ones the fixture
        # already declared itself; only these are stripped later.
        self.to_strip = set()

    def add(self, name):
        """Append *name* to ``fixturedef.argnames`` unless already present,
        remembering that it was injected by us."""
        if name not in self.fixturedef.argnames:
            self.fixturedef.argnames += (name, )
            self.to_strip.add(name)

    def get_and_strip_from(self, name, data_dict):
        """Return ``data_dict[name]``; delete the entry if we injected it."""
        if name in self.to_strip:
            return data_dict.pop(name)
        return data_dict[name]
74
@pytest.hookimpl(trylast=True)
def pytest_fixture_post_finalizer(fixturedef, request):
    """After the ``event_loop`` fixture is torn down, restore the default
    event loop policy.

    A fresh default policy no longer references the (now closed) loop that
    was installed during fixture setup, so a subsequent
    ``get_event_loop()`` provides a new loop.

    NOTE(review): passing ``None`` also discards any custom policy the user
    may have installed; confirm this is intended.
    """
    if fixturedef.argname != "event_loop":
        return
    asyncio.set_event_loop_policy(None)
81
82
83
def _wrap_asyncgen_fixture(fixturedef):
    """Replace an async-generator fixture's function with a sync wrapper.

    The wrapper advances the generator to its first ``yield`` on the loop
    from the ``event_loop`` fixture and registers, through the ``request``
    fixture, a finalizer that resumes it once more for teardown. Both
    fixtures are injected into ``fixturedef.argnames`` if the fixture did
    not request them, and stripped from the kwargs before the generator is
    called.
    """
    generator = fixturedef.func

    fixture_stripper = FixtureStripper(fixturedef)
    fixture_stripper.add(FixtureStripper.EVENT_LOOP)
    fixture_stripper.add(FixtureStripper.REQUEST)

    def wrapper(*args, **kwargs):
        loop = fixture_stripper.get_and_strip_from(
            FixtureStripper.EVENT_LOOP, kwargs)
        request = fixture_stripper.get_and_strip_from(
            FixtureStripper.REQUEST, kwargs)

        gen_obj = generator(*args, **kwargs)

        async def setup():
            return await gen_obj.__anext__()

        def finalizer():
            """Yield again, to finalize."""
            async def async_finalizer():
                try:
                    await gen_obj.__anext__()
                except StopAsyncIteration:
                    pass
                else:
                    msg = "Async generator fixture didn't stop. "
                    msg += "Yield only once."
                    raise ValueError(msg)
            loop.run_until_complete(async_finalizer())

        # Registered before setup runs: if setup raised, the generator is
        # already finished, so __anext__ raises StopAsyncIteration, which
        # the finalizer ignores.
        request.addfinalizer(finalizer)
        return loop.run_until_complete(setup())

    fixturedef.func = wrapper


def _wrap_coroutine_fixture(fixturedef):
    """Replace a coroutine fixture's function with a sync wrapper.

    The wrapper runs the coroutine to completion on the loop from the
    ``event_loop`` fixture. ``event_loop`` is injected into
    ``fixturedef.argnames`` if the fixture did not request it, and stripped
    from the kwargs before the coroutine function is called.
    """
    coro = fixturedef.func

    fixture_stripper = FixtureStripper(fixturedef)
    fixture_stripper.add(FixtureStripper.EVENT_LOOP)

    def wrapper(*args, **kwargs):
        loop = fixture_stripper.get_and_strip_from(
            FixtureStripper.EVENT_LOOP, kwargs)

        async def setup():
            return await coro(*args, **kwargs)

        return loop.run_until_complete(setup())

    fixturedef.func = wrapper


@pytest.hookimpl(hookwrapper=True)
def pytest_fixture_setup(fixturedef, request):
    """Adjust the event loop policy when an event loop is produced, and make
    async fixtures usable by wrapping them in sync functions.

    - ``event_loop``: after the fixture produces its loop, install that loop
      as the current loop of the active policy.
    - async generator fixtures: see ``_wrap_asyncgen_fixture``.
    - coroutine fixtures: see ``_wrap_coroutine_fixture``.

    ``fixturedef.func`` is replaced in place. The replacement is a plain
    function, so neither check matches on later setups, and the existing
    wrapper is reused.
    """
    if fixturedef.argname == "event_loop":
        outcome = yield
        loop = outcome.get_result()
        policy = asyncio.get_event_loop_policy()
        policy.set_event_loop(loop)
        return

    if isasyncgenfunction(fixturedef.func):
        _wrap_asyncgen_fixture(fixturedef)
    elif inspect.iscoroutinefunction(fixturedef.func):
        _wrap_coroutine_fixture(fixturedef)
    yield
147
148
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_pyfunc_call(pyfuncitem):
    """
    Run asyncio marked test functions in an event loop instead of a normal
    function call.

    The test object is replaced by a sync wrapper bound to the item's
    ``event_loop`` fixture value before pytest performs the actual call.
    """
    if 'asyncio' in pyfuncitem.keywords:
        test_obj = pyfuncitem.obj
        if getattr(test_obj, 'is_hypothesis_test', False):
            # For Hypothesis tests the coroutine lives on
            # ``obj.hypothesis.inner_test``; wrap that, not the outer object.
            hypothesis_handle = test_obj.hypothesis
            hypothesis_handle.inner_test = wrap_in_sync(
                hypothesis_handle.inner_test,
                _loop=pyfuncitem.funcargs['event_loop'],
            )
        else:
            pyfuncitem.obj = wrap_in_sync(
                test_obj,
                _loop=pyfuncitem.funcargs['event_loop'],
            )
    yield
167
168
def wrap_in_sync(func, _loop):
    """Return a sync wrapper around an async function executing it in the
    given event loop.

    If *func* is itself a wrapper previously returned by this function, it
    is unwrapped first, and the new wrapper calls the original function on
    *_loop*. Otherwise the wrappers would nest and the inner one would keep
    using its earlier (possibly closed) loop. This can happen when the same
    Hypothesis ``inner_test`` object is wrapped for several test items.

    :param func: callable taking keyword arguments and returning an
        awaitable, or ``None`` (in which case nothing is run).
    :param _loop: event loop used to run the awaitable to completion.
    :returns: a function accepting keyword arguments only.
    """
    # Unwrap via our own marker rather than ``__wrapped__``: the original
    # test function may itself be wrapped by an unrelated decorator.
    raw_func = getattr(func, '_raw_test_func', func)

    @functools.wraps(raw_func)
    def inner(**kwargs):
        coro = raw_func(**kwargs)
        if coro is not None:
            task = asyncio.ensure_future(coro, loop=_loop)
            try:
                _loop.run_until_complete(task)
            except BaseException:
                # run_until_complete doesn't get the result from exceptions
                # that are not subclasses of `Exception`. Consume all
                # exceptions to prevent asyncio's warning from logging.
                if task.done() and not task.cancelled():
                    task.exception()
                raise

    inner._raw_test_func = raw_func
    return inner
188
189
def pytest_runtest_setup(item):
    """Prepare ``asyncio`` tests before they run.

    Moves (or adds) ``event_loop`` to the front of the item's fixture names.
    Fails the test when it carries the ``asyncio`` marker and is a Hypothesis
    test without the ``hypothesis`` attribute, which indicates a Hypothesis
    release older than 3.64.0.
    """
    if 'asyncio' in item.keywords:
        # inject an event loop fixture for all async tests
        names = item.fixturenames
        if 'event_loop' in names:
            names.remove('event_loop')
        names.insert(0, 'event_loop')

    if item.get_closest_marker("asyncio") is None:
        return
    test_func = item.obj
    if getattr(test_func, 'hypothesis', False):
        return
    if getattr(test_func, 'is_hypothesis_test', False):
        pytest.fail(
            'test function `%r` is using Hypothesis, but pytest-asyncio '
            'only works with Hypothesis 3.64.0 or later.' % item
        )
203
204
@pytest.fixture
def event_loop(request):
    """Yield a new event loop from the current policy for each test case,
    closing it on teardown."""
    current_policy = asyncio.get_event_loop_policy()
    fresh_loop = current_policy.new_event_loop()
    yield fresh_loop
    fresh_loop.close()
211
212
213def _unused_tcp_port():
214    """Find an unused localhost TCP port from 1024-65535 and return it."""
215    with contextlib.closing(socket.socket()) as sock:
216        sock.bind(('127.0.0.1', 0))
217        return sock.getsockname()[1]
218
219
@pytest.fixture
def unused_tcp_port():
    """Provide one TCP port number on 127.0.0.1 that is unused at setup."""
    free_port = _unused_tcp_port()
    return free_port
223
224
@pytest.fixture
def unused_tcp_port_factory():
    """Provide a callable returning a different unused TCP port on each call."""
    handed_out = set()

    def factory():
        """Return an unused port not previously returned by this factory."""
        while True:
            candidate = _unused_tcp_port()
            if candidate not in handed_out:
                break
        handed_out.add(candidate)
        return candidate

    return factory
241