1# mode: run
2# tag: pep492, asyncfor, await
3
4import sys
5
# From Python 3.5.0 beta onward (the tuple comparison also admits 3.5.0
# beta/rc releases), the module docstring carries an extra doctest.  Its
# coroutine is defined inside the doctest, so it is compiled by the Python
# interpreter rather than by Cython.  It iterates over the Cython-implemented
# iterator class.  Interpreters before 3.5.2 get AsyncIterOld, whose __aiter__
# is an async def, instead of AsyncIter.
if sys.version_info >= (3, 5, 0, 'beta'):
    # pass Cython implemented AsyncIter() into a Python async-for loop
    __doc__ = u"""
>>> def test_py35(AsyncIterClass):
...     buffer = []
...     async def coro():
...         async for i1, i2 in AsyncIterClass(1):
...             buffer.append(i1 + i2)
...     return coro, buffer

>>> testfunc, buffer = test_py35(AsyncIterOld if sys.version_info < (3, 5, 2) else AsyncIter)
>>> buffer
[]

>>> yielded, _ = run_async(testfunc(), check_type=False)
>>> yielded == [i * 100 for i in range(1, 11)] or yielded
True
>>> buffer == [i*2 for i in range(1, 101)] or buffer
True
"""
26
27
cdef class AsyncYieldFrom:
    """Awaitable that delegates its entire await to the iterable ``obj``."""
    cdef object obj

    def __init__(self, obj):
        self.obj = obj

    def __await__(self):
        # Every item produced by obj is forwarded to whoever drives the
        # awaiting coroutine; obj's final return value is discarded.
        delegate = self.obj
        yield from delegate
35
36
cdef class AsyncYield:
    """Awaitable that suspends its awaiter exactly once, yielding ``value``."""
    cdef object value

    def __init__(self, value):
        self.value = value

    def __await__(self):
        # Generator-based __await__: the single yield hands the value up
        # through the await chain to the caller of the outermost send().
        pending = self.value
        yield pending
44
45
def run_async(coro, check_type='coroutine'):
    """
    Drive ``coro`` to completion by repeatedly sending None into it.

    If ``check_type`` is truthy, first assert that the class name of
    ``coro`` equals it.  Returns ``(yielded, result)``: the list of values
    the coroutine yielded on each suspension, and its return value
    (None if it returned nothing).
    """
    if check_type:
        actual_type = coro.__class__
        assert actual_type.__name__ == check_type, \
            'type(%s) != %s' % (actual_type, check_type)

    yielded = []
    try:
        while True:
            yielded.append(coro.send(None))
    except StopIteration as stop:
        # StopIteration.value is args[0] when given, otherwise None.
        return yielded, stop.value
60
61
cdef class AsyncIter:
    """
    Async iterator producing the pairs (1, 1), (2, 2), ..., (100, 100).

    During a full pass, every 10th step first suspends the awaiting
    coroutine with the value i * 10 (100, 200, ..., 1000).  __anext__
    asserts that __aiter__ has been called at most max_iter_calls times.
    """
    cdef long i
    cdef long aiter_calls
    cdef long max_iter_calls

    def __init__(self, long max_iter_calls=1):
        self.max_iter_calls = max_iter_calls
        self.aiter_calls = 0
        self.i = 0

    def __aiter__(self):
        self.aiter_calls += 1
        return self

    async def __anext__(self):
        self.i += 1
        assert self.aiter_calls <= self.max_iter_calls

        # Suspend the caller on multiples of 10, passing i * 10 upward.
        if self.i % 10 == 0:
            await AsyncYield(self.i * 10)

        if self.i <= 100:
            return self.i, self.i
        raise StopAsyncIteration
87
88
cdef class AsyncIterOld(AsyncIter):
    """
    Same as AsyncIter, but with the old async-def interface for __aiter__().
    """
    # Old protocol: __aiter__ is itself a coroutine, so the async-for
    # machinery must await it to obtain the iterator.  In this file it is
    # selected only by the module-level doctest, on Python < 3.5.2.
    async def __aiter__(self):
        # Counted so that the inherited __anext__ can assert that __aiter__
        # was not called more than max_iter_calls times.
        self.aiter_calls += 1
        return self
96
97
def test_for_1():
    """
    Basic 'async for' with tuple unpacking of each item.

    Returns the coroutine function (not yet called) together with the list
    it appends to, so the doctest can drive it and inspect both results.

    >>> testfunc, buffer = test_for_1()
    >>> buffer
    []
    >>> yielded, _ = run_async(testfunc())
    >>> yielded == [i * 100 for i in range(1, 11)] or yielded
    True
    >>> buffer == [i*2 for i in range(1, 101)] or buffer
    True
    """
    buffer = []
    async def test1():
        # Each item from AsyncIter is an (i, i) pair; record their sum.
        async for i1, i2 in AsyncIter(1):
            buffer.append(i1 + i2)
    return test1, buffer
114
115
def test_for_2():
    """
    'async for' left via 'break': the loop's 'else' clause must not run.

    Only the suspensions for items 10 and 20 happen before the break.

    >>> testfunc, buffer = test_for_2()
    >>> buffer
    []
    >>> yielded, _ = run_async(testfunc())
    >>> yielded == [100, 200] or yielded
    True
    >>> buffer == [i for i in range(1, 21)] + ['end'] or buffer
    True
    """
    buffer = []
    async def test2():
        # 'buffer' is only mutated, never rebound, so this declaration is
        # not required; presumably kept to exercise 'nonlocal' inside an
        # async closure.
        nonlocal buffer
        async for i in AsyncIter(2):
            buffer.append(i[0])
            if i[0] == 20:
                break
        else:
            # Never reached: the loop always ends through 'break'.
            buffer.append('what?')
        buffer.append('end')
    return test2, buffer
138
139
140
def test_for_3():
    """
    'async for' with 'continue': the loop runs to exhaustion, so its 'else'
    clause does run.

    >>> testfunc, buffer = test_for_3()
    >>> buffer
    []
    >>> yielded, _ = run_async(testfunc())
    >>> yielded == [i * 100 for i in range(1, 11)] or yielded
    True
    >>> buffer == [i for i in range(1, 21)] + ['what?', 'end'] or buffer
    True
    """
    buffer = []
    async def test3():
        # Not strictly needed ('buffer' is never rebound); kept as in the
        # other tests, presumably to exercise 'nonlocal' in async closures.
        nonlocal buffer
        async for i in AsyncIter(3):
            # Items past 20 are skipped, but the iterator is still drained,
            # so all ten suspensions (100 .. 1000) still occur.
            if i[0] > 20:
                continue
            buffer.append(i[0])
        else:
            # Runs because the loop finished without 'break'.
            buffer.append('what?')
        buffer.append('end')
    return test3, buffer
163
164
cdef class NonAwaitableFromAnext:
    """
    Deliberately broken async iterator: __anext__ is a plain method that
    returns an int instead of an awaitable.
    """
    def __aiter__(self):
        return self

    def __anext__(self):
        # 123 is not awaitable; test_broken_anext expects 'async for' to
        # raise a TypeError whose message mentions 'int'.
        return 123
171
172
def test_broken_anext():
    """
    An __anext__ that returns a non-awaitable (an int) must make
    'async for' raise TypeError before the loop body runs.

    >>> testfunc = test_broken_anext()
    >>> try: run_async(testfunc())
    ... except TypeError as exc:
    ...     assert ' int' in str(exc)
    ... else:
    ...     print("NOT RAISED!")
    """
    async def foo():
        async for i in NonAwaitableFromAnext():
            # Unreachable: fetching the first item already fails.
            print('never going to happen')
    return foo
186
187
cdef class Manager:
    """
    Async context manager that adds 10000 to counter[0] on entry and
    100000 on exit.
    """
    cdef readonly list counter

    def __init__(self, counter):
        self.counter = counter

    async def __aenter__(self):
        cells = self.counter
        cells[0] = cells[0] + 10000

    async def __aexit__(self, *args):
        # Exception details in *args are ignored; returning None lets any
        # exception from the 'async with' body propagate.
        cells = self.counter
        cells[0] = cells[0] + 100000
198
199
cdef class Iterable:
    """Async iterator producing 1, 2, ..., 11, then StopAsyncIteration."""
    cdef long i

    def __init__(self):
        self.i = 0

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.i <= 10:
            self.i += 1
            return self.i
        raise StopAsyncIteration
213
214
def test_with_for():
    """
    'async with' wrapping 'async for' (with and without an 'else' clause),
    printing the shared counter after each of three scenarios.  The first
    scenario also checks that the reference counts of the manager and the
    iterator are unchanged afterwards.

    >>> test_with_for()
    111011
    333033
    20555255
    """
    # Shared counter; Manager adds 10000 on enter and 100000 on exit, and
    # each Iterable() produces 11 items.
    I = [0]

    manager = Manager(I)
    iterable = Iterable()
    # Baseline reference counts, compared after the first run to catch a
    # leaked or over-released reference in the async with/for code.
    mrefs_before = sys.getrefcount(manager)
    irefs_before = sys.getrefcount(iterable)

    # Scenario 1: 10000 + 11 + 100000 + 1000 = 111011.
    async def main():
        async with manager:
            async for i in iterable:
                I[0] += 1
        I[0] += 1000

    run_async(main())
    print(I[0])

    assert sys.getrefcount(manager) == mrefs_before
    assert sys.getrefcount(iterable) == irefs_before

    ##############

    # Scenario 2: two rounds of 111011 each -> 111011 + 222022 = 333033.
    async def main():
        # 'I' is only mutated, never rebound, so this is not required;
        # presumably kept to exercise 'nonlocal' in an async closure.
        nonlocal I

        async with Manager(I):
            async for i in Iterable():
                I[0] += 1
        I[0] += 1000

        async with Manager(I):
            async for i in Iterable():
                I[0] += 1
        I[0] += 1000

    run_async(main())
    print(I[0])

    ##############

    # Scenario 3: the for-loops finish normally, so each 'else' runs.  Each
    # round adds 10000 + 100 + 11 + 10000000 + 100000 + 1000 = 10111111,
    # so the total is 333033 + 2 * 10111111 = 20555255.
    async def main():
        async with Manager(I):
            I[0] += 100
            async for i in Iterable():
                I[0] += 1
            else:
                I[0] += 10000000
        I[0] += 1000

        async with Manager(I):
            I[0] += 100
            async for i in Iterable():
                I[0] += 1
            else:
                I[0] += 10000000
        I[0] += 1000

    run_async(main())
    print(I[0])
280
281
# Old-style (pre-3.5.2) __aiter__ protocol, where __aiter__ is an async def - no longer supported, so this case is disabled
283#cdef class AI_old:
284#    async def __aiter__(self):
285#        1/0
286
287
cdef class AI_new:
    """
    Async iterable whose (new-style, non-async) __aiter__() raises
    ZeroDivisionError; used by test_aiter_raises.
    """
    def __aiter__(self):
        # Deliberately fails before any iterator is returned.
        1/0
291
292
def test_aiter_raises(AI):
    """
    An exception raised by AI().__aiter__() must propagate out of
    'async for' before the loop body or the code after the loop runs, so
    the returned counter stays 0.  The AI_old case is disabled together
    with the old __aiter__ protocol.

    #>>> test_aiter_raises(AI_old)
    #RAISED
    #0
    >>> test_aiter_raises(AI_new)
    RAISED
    0
    """
    CNT = 0

    async def foo():
        # Required here: CNT is rebound by the augmented assignments below.
        nonlocal CNT
        async for i in AI():
            CNT += 1
        # Reached only if __aiter__ did not raise.
        CNT += 10

    try:
        run_async(foo())
    except ZeroDivisionError:
        print("RAISED")
    else:
        print("NOT RAISED")
    return CNT
317