1import sys
2from functools import wraps
3from types import coroutine
4import inspect
5from inspect import (
6    getcoroutinestate, CORO_CREATED, CORO_CLOSED, CORO_SUSPENDED
7)
8import collections.abc
9
10
class YieldWrapper:
    """Box marking a value as one yielded out of an async generator.

    The single attribute ``payload`` holds the boxed value.
    """

    # One box is created for every yielded value, so skip the per-instance
    # __dict__.
    __slots__ = ("payload",)

    def __init__(self, payload):
        self.payload = payload

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, self.payload)
14
15
def _wrap(value):
    """Return a new YieldWrapper box holding *value*."""
    box = YieldWrapper(value)
    return box
18
19
def _is_wrapped(box):
    """Return True when *box* is a YieldWrapper instance, else False."""
    wrapped = isinstance(box, YieldWrapper)
    return wrapped
22
23
24def _unwrap(box):
25    return box.payload
26
27
28# This is the magic code that lets you use yield_ and yield_from_ with native
29# generators.
30#
31# The old version worked great on Linux and MacOS, but not on Windows, because
32# it depended on _PyAsyncGenValueWrapperNew. The new version segfaults
33# everywhere, and I'm not sure why -- probably my lack of understanding
34# of ctypes and refcounts.
35#
36# There are also some commented out tests that should be re-enabled if this is
37# fixed:
38#
39# if sys.version_info >= (3, 6):
40#     # Use the same box type that the interpreter uses internally. This allows
41#     # yield_ and (more importantly!) yield_from_ to work in built-in
42#     # generators.
43#     import ctypes  # mua ha ha.
44#
45#     # We used to call _PyAsyncGenValueWrapperNew to create and set up new
46#     # wrapper objects, but that symbol isn't available on Windows:
47#     #
48#     #   https://github.com/python-trio/async_generator/issues/5
49#     #
50#     # Fortunately, the type object is available, but it means we have to do
51#     # this the hard way.
52#
53#     # We don't actually need to access this, but we need to make a ctypes
54#     # structure so we can call addressof.
55#     class _ctypes_PyTypeObject(ctypes.Structure):
56#         pass
57#     _PyAsyncGenWrappedValue_Type_ptr = ctypes.addressof(
58#         _ctypes_PyTypeObject.in_dll(
59#             ctypes.pythonapi, "_PyAsyncGenWrappedValue_Type"))
60#     _PyObject_GC_New = ctypes.pythonapi._PyObject_GC_New
61#     _PyObject_GC_New.restype = ctypes.py_object
62#     _PyObject_GC_New.argtypes = (ctypes.c_void_p,)
63#
64#     _Py_IncRef = ctypes.pythonapi.Py_IncRef
65#     _Py_IncRef.restype = None
66#     _Py_IncRef.argtypes = (ctypes.py_object,)
67#
68#     class _ctypes_PyAsyncGenWrappedValue(ctypes.Structure):
69#         _fields_ = [
70#             ('PyObject_HEAD', ctypes.c_byte * object().__sizeof__()),
71#             ('agw_val', ctypes.py_object),
72#         ]
73#     def _wrap(value):
74#         box = _PyObject_GC_New(_PyAsyncGenWrappedValue_Type_ptr)
75#         raw = ctypes.cast(ctypes.c_void_p(id(box)),
76#                           ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue))
77#         raw.contents.agw_val = value
78#         _Py_IncRef(value)
79#         return box
80#
81#     def _unwrap(box):
82#         assert _is_wrapped(box)
83#         raw = ctypes.cast(ctypes.c_void_p(id(box)),
84#                           ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue))
85#         value = raw.contents.agw_val
86#         _Py_IncRef(value)
87#         return value
88#
89#     _PyAsyncGenWrappedValue_Type = type(_wrap(1))
90#     def _is_wrapped(box):
91#         return isinstance(box, _PyAsyncGenWrappedValue_Type)
92
93
94# The magic @coroutine decorator is how you write the bottom level of
95# coroutine stacks -- 'async def' can only use 'await' = yield from; but
96# eventually we must bottom out in a @coroutine that calls plain 'yield'.
@coroutine
def _yield_(value):
    """Yield *value* in a wrapper box and return whatever is sent back in."""
    sent = yield _wrap(value)
    return sent
100
101
# But we wrap the bare @coroutine version in an async def, because async def
# has the magic feature that users can get warning messages if they forget to
# use 'await'.
async def yield_(value=None):
    """Yield *value* (default None) and return the value sent back in.

    Must be awaited; it delegates to the ``_yield_`` coroutine.
    """
    sent = await _yield_(value)
    return sent
107
108
async def yield_from_(delegate):
    """Delegate to the async iterable *delegate*, like ``yield from``.

    Each item produced by *delegate* is re-yielded through ``yield_``. Values
    sent in are forwarded with ``asend`` (or ``__anext__`` when the value is
    None), thrown exceptions are forwarded with ``athrow`` when the delegate
    has one, and GeneratorExit closes the delegate via ``aclose`` when
    available before being re-raised.

    Returns the first argument of the StopAsyncIteration that ended the
    delegate, or None if it carried no arguments.
    """
    # Transcribed with adaptations from:
    #
    #   https://www.python.org/dev/peps/pep-0380/#formal-semantics
    #
    # This takes advantage of a sneaky trick: if an @async_generator-wrapped
    # function calls another async function (like yield_from_), and that
    # second async function calls yield_, then because of the hack we use to
    # implement yield_, the yield_ will actually propagate through yield_from_
    # back to the @async_generator wrapper. So even though we're a regular
    # function, we can directly yield values out of the calling async
    # generator.
    def unpack_StopAsyncIteration(e):
        return e.args[0] if e.args else None

    _i = type(delegate).__aiter__(delegate)
    # Some __aiter__ implementations return an awaitable rather than the
    # iterator itself; resolve it if so.
    if hasattr(_i, "__await__"):
        _i = await _i
    try:
        _y = await type(_i).__anext__(_i)
    except StopAsyncIteration as _e:
        _r = unpack_StopAsyncIteration(_e)
    else:
        while True:
            try:
                _s = await yield_(_y)
            except GeneratorExit as _e:
                try:
                    _m = _i.aclose
                except AttributeError:
                    pass
                else:
                    await _m()
                raise _e
            except BaseException as _e:
                try:
                    _m = _i.athrow
                except AttributeError:
                    raise _e
                else:
                    try:
                        # Pass the exception instance itself: it already
                        # carries its traceback, and the three-argument
                        # (type, value, traceback) throw signature is
                        # deprecated since Python 3.12.
                        _y = await _m(_e)
                    except StopAsyncIteration as _e:
                        _r = unpack_StopAsyncIteration(_e)
                        break
            else:
                try:
                    if _s is None:
                        _y = await type(_i).__anext__(_i)
                    else:
                        _y = await _i.asend(_s)
                except StopAsyncIteration as _e:
                    _r = unpack_StopAsyncIteration(_e)
                    break
    return _r
168
169
170# This is the awaitable / iterator that implements asynciter.__anext__() and
171# friends.
172#
173# Note: we can be sloppy about the distinction between
174#
175#   type(self._it).__next__(self._it)
176#
177# and
178#
179#   self._it.__next__()
180#
181# because we happen to know that self._it is not a general iterator object,
182# but specifically a coroutine iterator object where these are equivalent.
class ANextIter:
    """Awaitable iterator that performs one step of an async generator.

    ``it`` is the iterator being driven. The first ``__next__`` call runs
    ``first_fn(*first_args)``; every later ``__next__`` calls
    ``it.__next__``. ``send`` and ``throw`` are forwarded to ``it``.

    Results are translated by ``_invoke``: a wrapped value ends the await
    with that value, StopIteration from the underlying call becomes
    StopAsyncIteration, and anything else is returned unchanged.
    """

    def __init__(self, it, first_fn, *first_args):
        self._it = it
        self._first_fn = first_fn
        self._first_args = first_args

    def __await__(self):
        return self

    def __next__(self):
        if self._first_fn is not None:
            # Consume the one-shot starter so later steps use plain __next__.
            first_fn = self._first_fn
            first_args = self._first_args
            self._first_fn = self._first_args = None
            return self._invoke(first_fn, *first_args)
        else:
            return self._invoke(self._it.__next__)

    def send(self, value):
        return self._invoke(self._it.send, value)

    def throw(self, type, value=None, traceback=None):
        # Forward only the arguments that were actually supplied. Passing
        # more than one positional argument to throw() is deprecated since
        # Python 3.12 and warns even when the extra arguments are None.
        if traceback is not None:
            return self._invoke(self._it.throw, type, value, traceback)
        if value is not None:
            return self._invoke(self._it.throw, type, value)
        return self._invoke(self._it.throw, type)

    def _invoke(self, fn, *args):
        """Call ``fn(*args)`` and translate its outcome (see class docstring).

        Raises RuntimeError if the call itself raises StopAsyncIteration.
        """
        try:
            result = fn(*args)
        except StopIteration as e:
            # The underlying generator returned, so we should signal the end
            # of iteration.
            raise StopAsyncIteration(e.value)
        except StopAsyncIteration as e:
            # PEP 479 says: if a generator raises Stop(Async)Iteration, then
            # it should be wrapped into a RuntimeError. Python automatically
            # enforces this for StopIteration; for StopAsyncIteration we need
            # to do it ourselves.
            raise RuntimeError(
                "async_generator raise StopAsyncIteration"
            ) from e
        if _is_wrapped(result):
            raise StopIteration(_unwrap(result))
        else:
            return result
226
227
# Sentinel that distinguishes "argument not passed" from an explicit None.
UNSPECIFIED = object()
try:
    # Prefer the interpreter's own hook accessors when sys provides them.
    from sys import get_asyncgen_hooks, set_asyncgen_hooks

except ImportError:
    # Fallback: keep our own hooks, stored separately for each thread.
    import threading

    asyncgen_hooks = collections.namedtuple(
        "asyncgen_hooks", ("firstiter", "finalizer")
    )

    class _hooks_storage(threading.local):
        """Thread-local holder for the two hook callables (both start None)."""

        def __init__(self):
            self.firstiter = None
            self.finalizer = None

    _hooks = _hooks_storage()

    def get_asyncgen_hooks():
        """Return the current thread's hooks as an ``asyncgen_hooks`` tuple."""
        return asyncgen_hooks(_hooks.firstiter, _hooks.finalizer)

    def set_asyncgen_hooks(firstiter=UNSPECIFIED, finalizer=UNSPECIFIED):
        """Set the current thread's hooks.

        An omitted argument leaves that hook untouched; None clears it.
        Raises TypeError for a value that is neither None nor callable.
        """
        if firstiter is not UNSPECIFIED:
            if firstiter is not None and not callable(firstiter):
                raise TypeError(
                    "callable firstiter expected, got {}".format(
                        type(firstiter).__name__
                    )
                )
            _hooks.firstiter = firstiter

        if finalizer is not UNSPECIFIED:
            if finalizer is not None and not callable(finalizer):
                raise TypeError(
                    "callable finalizer expected, got {}".format(
                        type(finalizer).__name__
                    )
                )
            _hooks.finalizer = finalizer
271
272
class AsyncGenerator:
    """Async-generator object backed by a coroutine.

    Wraps ``coroutine`` and exposes ``__aiter__``, ``__anext__``, ``asend``,
    ``athrow`` and ``aclose``, plus the ``ag_code``, ``ag_frame`` and
    ``ag_running`` introspection attributes. The hooks returned by
    ``get_asyncgen_hooks()`` are read on the first iteration call.
    """

    # https://bitbucket.org/pypy/pypy/issues/2786:
    # PyPy implements 'await' in a way that requires the frame object
    # used to execute a coroutine to keep a weakref to that coroutine.
    # During a GC pass, weakrefs to all doomed objects are broken
    # before any of the doomed objects' finalizers are invoked.
    # If an AsyncGenerator is unreachable, its _coroutine probably
    # is too, and the weakref from ag._coroutine.cr_frame to
    # ag._coroutine will be broken before ag.__del__ can do its
    # one-turn close attempt or can schedule a full aclose() using
    # the registered finalization hook. It doesn't look like the
    # underlying issue is likely to be fully fixed anytime soon,
    # so we work around it by preventing an AsyncGenerator and
    # its _coroutine from being considered newly unreachable at
    # the same time if the AsyncGenerator's finalizer might want
    # to iterate the coroutine some more.
    _pypy_issue2786_workaround = set()

    def __init__(self, coroutine):
        self._coroutine = coroutine
        self._it = coroutine.__await__()
        self.ag_running = False
        # Finalizer hook captured on first iteration (see _do_it).
        self._finalizer = None
        # Set once aclose() has begun closing the generator.
        self._closed = False
        self._hooks_inited = False

    # On python 3.5.0 and 3.5.1, __aiter__ must be awaitable.
    # Starting in 3.5.2, it should not be awaitable, and if it is, then it
    #   raises a PendingDeprecationWarning.
    # See:
    #   https://www.python.org/dev/peps/pep-0492/#api-design-and-implementation-revisions
    #   https://docs.python.org/3/reference/datamodel.html#async-iterators
    #   https://bugs.python.org/issue27243
    if sys.version_info < (3, 5, 2):

        async def __aiter__(self):
            return self
    else:

        def __aiter__(self):
            return self

    ################################################################
    # Introspection attributes
    ################################################################

    @property
    def ag_code(self):
        return self._coroutine.cr_code

    @property
    def ag_frame(self):
        return self._coroutine.cr_frame

    ################################################################
    # Core functionality
    ################################################################

    # These need to return awaitables, rather than being async functions,
    # to match the native behavior where the firstiter hook is called
    # immediately on asend()/etc, even if the coroutine that asend()
    # produces isn't awaited for a bit.

    def __anext__(self):
        return self._do_it(self._it.__next__)

    def asend(self, value):
        return self._do_it(self._it.send, value)

    def athrow(self, type, value=None, traceback=None):
        # Forward only the arguments that were actually supplied. Passing
        # more than one positional argument to throw() is deprecated since
        # Python 3.12 and warns even when the extra arguments are None.
        if traceback is not None:
            throw_args = (type, value, traceback)
        elif value is not None:
            throw_args = (type, value)
        else:
            throw_args = (type,)
        return self._do_it(self._it.throw, *throw_args)

    def _do_it(self, start_fn, *args):
        """Return an awaitable that runs one step begun by ``start_fn(*args)``.

        Runs the firstiter hook on the first call. Raises StopAsyncIteration
        immediately if the wrapped coroutine is already closed.
        """
        if not self._hooks_inited:
            self._hooks_inited = True
            (firstiter, self._finalizer) = get_asyncgen_hooks()
            if firstiter is not None:
                firstiter(self)
            if sys.implementation.name == "pypy":
                self._pypy_issue2786_workaround.add(self._coroutine)

        # On CPython 3.5.2 (but not 3.5.0), coroutines get cranky if you try
        # to iterate them after they're exhausted. Generators OTOH just raise
        # StopIteration. We want to convert the one into the other, so we need
        # to avoid iterating stopped coroutines.
        if getcoroutinestate(self._coroutine) is CORO_CLOSED:
            raise StopAsyncIteration()

        async def step():
            if self.ag_running:
                raise ValueError("async generator already executing")
            try:
                self.ag_running = True
                return await ANextIter(self._it, start_fn, *args)
            except StopAsyncIteration:
                self._pypy_issue2786_workaround.discard(self._coroutine)
                raise
            finally:
                self.ag_running = False

        return step()

    ################################################################
    # Cleanup
    ################################################################

    async def aclose(self):
        """Close the generator by throwing GeneratorExit into it.

        Does nothing if the generator is already closed. Raises RuntimeError
        if the generator keeps running instead of exiting.
        """
        state = getcoroutinestate(self._coroutine)
        if state is CORO_CLOSED or self._closed:
            return
        # Make sure that even if we raise "async_generator ignored
        # GeneratorExit", and thus fail to exhaust the coroutine,
        # __del__ doesn't complain again.
        self._closed = True
        if state is CORO_CREATED:
            # Make sure that aclose() on an unstarted generator returns
            # successfully and prevents future iteration.
            self._it.close()
            return
        try:
            await self.athrow(GeneratorExit)
        except (GeneratorExit, StopAsyncIteration):
            self._pypy_issue2786_workaround.discard(self._coroutine)
        else:
            raise RuntimeError("async_generator ignored GeneratorExit")

    def __del__(self):
        self._pypy_issue2786_workaround.discard(self._coroutine)
        if getcoroutinestate(self._coroutine) is CORO_CREATED:
            # Never started, nothing to clean up, just suppress the "coroutine
            # never awaited" message.
            self._coroutine.close()
        if getcoroutinestate(self._coroutine
                             ) is CORO_SUSPENDED and not self._closed:
            if self._finalizer is not None:
                self._finalizer(self)
            else:
                # Mimic the behavior of native generators on GC with no finalizer:
                # throw in GeneratorExit, run for one turn, and complain if it didn't
                # finish.
                thrower = self.athrow(GeneratorExit)
                try:
                    thrower.send(None)
                except (GeneratorExit, StopAsyncIteration):
                    pass
                except StopIteration:
                    raise RuntimeError("async_generator ignored GeneratorExit")
                else:
                    raise RuntimeError(
                        "async_generator {!r} awaited during finalization; install "
                        "a finalization hook to support this, or wrap it in "
                        "'async with aclosing(...):'"
                        .format(self.ag_code.co_name)
                    )
                finally:
                    thrower.close()
429
430
# Register AsyncGenerator as a virtual subclass of the standard ABC so that
# isinstance/issubclass checks against collections.abc.AsyncGenerator accept
# it. The hasattr guard skips this when the ABC does not exist (presumably on
# older Python versions -- confirm against the supported range).
if hasattr(collections.abc, "AsyncGenerator"):
    collections.abc.AsyncGenerator.register(AsyncGenerator)
433
434
def async_generator(coroutine_maker):
    """Decorator turning a coroutine function into an async-generator factory.

    The returned callable passes its arguments to *coroutine_maker* and wraps
    the result in an AsyncGenerator. It carries a ``_async_gen_function``
    attribute equal to its own ``id``.
    """

    @wraps(coroutine_maker)
    def async_generator_maker(*args, **kwargs):
        coro = coroutine_maker(*args, **kwargs)
        return AsyncGenerator(coro)

    # Self-identifying marker: the stored value only equals id() of this exact
    # function object.
    marker = id(async_generator_maker)
    async_generator_maker._async_gen_function = marker
    return async_generator_maker
442
443
def isasyncgen(obj):
    """Return True if *obj* is a native async generator or an AsyncGenerator.

    The native check is used only when ``inspect.isasyncgen`` exists.
    """
    native_check = hasattr(inspect, "isasyncgen")
    if native_check and inspect.isasyncgen(obj):
        return True
    return isinstance(obj, AsyncGenerator)
449
450
def isasyncgenfunction(obj):
    """Return True if *obj* is an async generator function.

    Accepts native async generator functions (when
    ``inspect.isasyncgenfunction`` exists) and any object whose
    ``_async_gen_function`` attribute equals its own ``id``.
    """
    native_check = hasattr(inspect, "isasyncgenfunction")
    if native_check and inspect.isasyncgenfunction(obj):
        return True
    # -1 can never equal id(obj), so objects without the marker give False.
    marker = getattr(obj, "_async_gen_function", -1)
    return marker == id(obj)
456