1# cython: language_level=3, binding=True
2# mode: run
3# tag: pep492, pep530, asyncfor, await
4
5###########
6# This file is a copy of the corresponding test file in CPython.
7# Please keep in sync and do not add non-upstream tests.
8###########
9
10import re
11import gc
12import sys
13import copy
14#import types
15import pickle
16import os.path
17#import inspect
18import unittest
19import warnings
20import contextlib
21
22from Cython.Compiler import Errors
23
24
try:
    from types import coroutine as types_coroutine
except ImportError:
    # duck typed types_coroutine() decorator copied from types.py in Py3.5
    # Fallback for Pythons whose ``types`` module has no ``coroutine``.
    # Calling a decorated function returns a _GeneratorWrapper instance;
    # that class name is one of the two accepted by run_async() and
    # run_async__await__() in this file, so it must not be renamed.
    class types_coroutine(object):
        """Decorator that makes a generator function's results awaitable."""
        def __init__(self, gen):
            # ``gen`` is the decorated generator *function*; it is called
            # lazily in __call__().
            self._gen = gen

        class _GeneratorWrapper(object):
            """Proxy that exposes a generator under gi_* and cr_* names."""
            def __init__(self, gen):
                self.__wrapped__ = gen
                # Bind the generator protocol methods directly.
                self.send = gen.send
                self.throw = gen.throw
                self.close = gen.close
                self.__name__ = getattr(gen, '__name__', None)
                self.__qualname__ = getattr(gen, '__qualname__', None)
            @property
            def gi_code(self):
                return self.__wrapped__.gi_code
            @property
            def gi_frame(self):
                return self.__wrapped__.gi_frame
            @property
            def gi_running(self):
                return self.__wrapped__.gi_running
            # Coroutine-style aliases for the generator introspection
            # properties.
            cr_code = gi_code
            cr_frame = gi_frame
            cr_running = gi_running
            def __next__(self):
                return next(self.__wrapped__)
            def __iter__(self):
                # Iterating (or awaiting) hands out the wrapped generator
                # itself, not this proxy.
                return self.__wrapped__
            __await__ = __iter__

        def __call__(self, *args, **kwargs):
            # Create the generator and wrap it so it can be awaited.
            return self._GeneratorWrapper(self._gen(*args, **kwargs))
61
try:
    from sys import getrefcount
except ImportError:
    # Interpreters without sys.getrefcount() (the comment below mentions
    # PyPy): read the reference count straight from the object header.
    from cpython.ref cimport PyObject
    def getrefcount(obj):
        # Collect garbage first, presumably so that pending finalizers have
        # released their references before the count is read.
        gc.collect()
        # PyPy needs to execute a bytecode to run the finalizers
        # NOTE(review): ``exec`` here looks up the module-level name, i.e.
        # the compiled exec() replacement defined in this module
        # (cython.inline based), not the builtin -- confirm that is intended.
        exec('', {}, {})
        return (<PyObject*>obj).ob_refcnt
71
72
73# compiled exec()
def exec(code_string, l, g):
    """Compile and run ``code_string`` with ``cython.inline()``.

    Replaces the builtin ``exec()`` for this test module. ``l`` and ``g``
    are handed to ``inline()`` as its ``locals``/``globals`` arguments.
    The compiled module is written next to this file and built at
    language level 3.

    Compiler chatter on stderr is swallowed. Whatever ``inline()`` returns
    (a mapping) is merged into ``g``, so names defined by the code become
    visible to the caller.
    """
    from Cython.Shadow import inline

    build_dir = os.path.dirname(__file__)
    with captured_stderr():
        namespace = inline(code_string, locals=l, globals=g,
                           lib_dir=build_dir, language_level=3)
    g.update(namespace)
88
89
def compile(code_string, module, level):
    """Stand-in for the builtin ``compile()`` used by the syntax tests.

    No code object is produced. The source is run through this module's
    ``exec()`` with two fresh, separate namespaces, so any compile error
    surfaces as an exception. ``module`` and ``level`` are accepted only for
    call compatibility and are otherwise ignored. Returns None.
    """
    fresh_locals, fresh_globals = {}, {}
    exec(code_string, fresh_locals, fresh_globals)
    return None
92
93
class AsyncYieldFrom(object):
    """Awaitable whose ``__await__`` delegates to ``obj`` via ``yield from``.

    The delegation result is discarded, so awaiting it evaluates to None.
    """

    def __init__(self, obj):
        self.obj = obj

    def __await__(self):
        delegate = self.obj
        yield from delegate
100
101
class AsyncYield(object):
    """Awaitable that suspends exactly once, yielding ``value``."""

    def __init__(self, value):
        self.value = value

    def __await__(self):
        item = self.value
        yield item
108
109
def run_async(coro):
    """Drive ``coro`` to completion by repeatedly sending None.

    Returns ``(yielded, result)``. ``yielded`` lists every value the
    coroutine suspended with. ``result`` is its return value, taken from the
    terminating StopIteration (``.value`` on Py3.5+, else ``args[0]`` or
    None). Only objects whose class is named 'coroutine' or
    '_GeneratorWrapper' are accepted.
    """
    #assert coro.__class__ is types.GeneratorType
    kind = coro.__class__.__name__
    assert kind in ('coroutine', '_GeneratorWrapper'), kind

    yielded = []
    while True:
        try:
            item = coro.send(None)
        except StopIteration as stop:
            if sys.version_info >= (3, 5):
                outcome = stop.value
            else:
                outcome = stop.args[0] if stop.args else None
            return yielded, outcome
        yielded.append(item)
123
124
def run_async__await__(coro):
    """Like run_async(), but drive the iterator returned by ``__await__()``.

    Steps alternate between ``send(None)`` (first, third, ...) and
    ``next()`` (second, fourth, ...), to exercise both entry points of the
    await iterator. Returns ``(yielded, result)`` exactly like run_async().
    """
    kind = coro.__class__.__name__
    assert kind in ('coroutine', '_GeneratorWrapper'), kind

    awaiter = coro.__await__()
    yielded = []
    use_send = True
    while True:
        try:
            item = awaiter.send(None) if use_send else next(awaiter)
        except StopIteration as stop:
            if sys.version_info >= (3, 5):
                outcome = stop.value
            else:
                outcome = stop.args[0] if stop.args else None
            return yielded, outcome
        yielded.append(item)
        use_send = not use_send
142
143
@contextlib.contextmanager
def silence_coro_gc():
    """Suppress all warnings inside the block, then run a GC pass.

    The collection happens while warnings are still ignored. Coroutines
    dropped in the block are therefore finalized without leaking
    "never awaited" RuntimeWarnings into later tests.
    """
    guard = warnings.catch_warnings()
    with guard:
        warnings.simplefilter("ignore")
        yield
        gc.collect()
150
151
def min_py27(method):
    """Return ``method`` on Python 2.7+, otherwise None (disabling it)."""
    if sys.version_info < (2, 7):
        return None
    return method
154
155
def ignore_py26(manager):
    """Return ``manager`` on Python 2.7+; on older versions a no-op one."""
    if sys.version_info >= (2, 7):
        return manager

    @contextlib.contextmanager
    def noop():
        yield

    return noop()
161
162
@contextlib.contextmanager
def captured_stderr():
    """Temporarily replace ``sys.stderr`` with an in-memory text buffer.

    Yields the buffer. The previous stream is restored on exit, even if
    the block raises.
    """
    try:
        # StringIO.StringIO() also accepts str in Py2, io.StringIO() does not
        from StringIO import StringIO as _Buffer
    except ImportError:
        from io import StringIO as _Buffer

    saved_stream = sys.stderr
    sys.stderr = _Buffer()
    try:
        yield sys.stderr
    finally:
        sys.stderr = saved_stream
177
178
179class AsyncBadSyntaxTest(unittest.TestCase):
180
181    @contextlib.contextmanager
182    def assertRaisesRegex(self, exc_type, regex):
183        class Holder(object):
184            exception = None
185
186        holder = Holder()
187        # the error messages usually don't match, so we just ignore them
188        try:
189            yield holder
190        except exc_type as exc:
191            holder.exception = exc
192            self.assertTrue(True)
193        else:
194            self.assertTrue(False)
195
196    def test_badsyntax_1(self):
197        samples = [
198            """def foo():
199                await something()
200            """,
201
202            """await something()""",
203
204            """async def foo():
205                yield from []
206            """,
207
208            """async def foo():
209                await await fut
210            """,
211
212            """async def foo(a=await something()):
213                pass
214            """,
215
216            """async def foo(a:await something()):
217                pass
218            """,
219
220            """async def foo():
221                def bar():
222                 [i async for i in els]
223            """,
224
225            """async def foo():
226                def bar():
227                 [await i for i in els]
228            """,
229
230            """async def foo():
231                def bar():
232                 [i for i in els
233                    async for b in els]
234            """,
235
236            """async def foo():
237                def bar():
238                 [i for i in els
239                    for c in b
240                    async for b in els]
241            """,
242
243            """async def foo():
244                def bar():
245                 [i for i in els
246                    async for b in els
247                    for c in b]
248            """,
249
250            """async def foo():
251                def bar():
252                 [i for i in els
253                    for b in await els]
254            """,
255
256            """async def foo():
257                def bar():
258                 [i for i in els
259                    for b in els
260                        if await b]
261            """,
262
263            """async def foo():
264                def bar():
265                 [i for i in await els]
266            """,
267
268            """async def foo():
269                def bar():
270                 [i for i in els if await i]
271            """,
272
273            """def bar():
274                 [i async for i in els]
275            """,
276
277            """def bar():
278                 [await i for i in els]
279            """,
280
281            """def bar():
282                 [i for i in els
283                    async for b in els]
284            """,
285
286            """def bar():
287                 [i for i in els
288                    for c in b
289                    async for b in els]
290            """,
291
292            """def bar():
293                 [i for i in els
294                    async for b in els
295                    for c in b]
296            """,
297
298            """def bar():
299                 [i for i in els
300                    for b in await els]
301            """,
302
303            """def bar():
304                 [i for i in els
305                    for b in els
306                        if await b]
307            """,
308
309            """def bar():
310                 [i for i in await els]
311            """,
312
313            """def bar():
314                 [i for i in els if await i]
315            """,
316
317            """async def foo():
318                await
319            """,
320
321            """async def foo():
322                   def bar(): pass
323                   await = 1
324            """,
325
326            """async def foo():
327
328                   def bar(): pass
329                   await = 1
330            """,
331
332            """async def foo():
333                   def bar(): pass
334                   if 1:
335                       await = 1
336            """,
337
338            """def foo():
339                   async def bar(): pass
340                   if 1:
341                       await a
342            """,
343
344            """def foo():
345                   async def bar(): pass
346                   await a
347            """,
348
349            """def foo():
350                   def baz(): pass
351                   async def bar(): pass
352                   await a
353            """,
354
355            """def foo():
356                   def baz(): pass
357                   # 456
358                   async def bar(): pass
359                   # 123
360                   await a
361            """,
362
363            """async def foo():
364                   def baz(): pass
365                   # 456
366                   async def bar(): pass
367                   # 123
368                   await = 2
369            """,
370
371            """def foo():
372
373                   def baz(): pass
374
375                   async def bar(): pass
376
377                   await a
378            """,
379
380            """async def foo():
381
382                   def baz(): pass
383
384                   async def bar(): pass
385
386                   await = 2
387            """,
388
389            """async def foo():
390                   def async(): pass
391            """,
392
393            """async def foo():
394                   def await(): pass
395            """,
396
397            """async def foo():
398                   def bar():
399                       await
400            """,
401
402            """async def foo():
403                   return lambda async: await
404            """,
405
406            """async def foo():
407                   return lambda a: await
408            """,
409
410            """await a()""",
411
412            """async def foo(a=await b):
413                   pass
414            """,
415
416            """async def foo(a:await b):
417                   pass
418            """,
419
420            """def baz():
421                   async def foo(a=await b):
422                       pass
423            """,
424
425            """async def foo(async):
426                   pass
427            """,
428
429            """async def foo():
430                   def bar():
431                        def baz():
432                            async = 1
433            """,
434
435            """async def foo():
436                   def bar():
437                        def baz():
438                            pass
439                        async = 1
440            """,
441
442            """def foo():
443                   async def bar():
444
445                        async def baz():
446                            pass
447
448                        def baz():
449                            42
450
451                        async = 1
452            """,
453
454            """async def foo():
455                   def bar():
456                        def baz():
457                            pass\nawait foo()
458            """,
459
460            """def foo():
461                   def bar():
462                        async def baz():
463                            pass\nawait foo()
464            """,
465
466            """async def foo(await):
467                   pass
468            """,
469
470            """def foo():
471
472                   async def bar(): pass
473
474                   await a
475            """,
476
477            """def foo():
478                   async def bar():
479                        pass\nawait a
480            """]
481
482        for code in samples:
483            with self.subTest(code=code), self.assertRaisesRegex(Errors.CompileError, '.'):
484                compile(code, "<test>", "exec")
485
486    def test_badsyntax_2(self):
487        samples = [
488            """def foo():
489                await = 1
490            """,
491
492            """class Bar:
493                def async(): pass
494            """,
495
496            """class Bar:
497                async = 1
498            """,
499
500            """class async:
501                pass
502            """,
503
504            """class await:
505                pass
506            """,
507
508            """import math as await""",
509
510            """def async():
511                pass""",
512
513            """def foo(*, await=1):
514                pass"""
515
516            """async = 1""",
517
518            # FIXME: cannot currently request Py3 syntax in cython.inline()
519            #"""print(await=1)"""
520        ]
521
522        for code in samples:
523            with self.subTest(code=code):  # , self.assertRaisesRegex(Errors.CompileError, '.'):
524                compile(code, "<test>", "exec")
525
526    def test_badsyntax_3(self):
527        #with self.assertRaises(DeprecationWarning):
528            with warnings.catch_warnings():
529                warnings.simplefilter("error")
530                compile("async = 1", "<test>", "exec")
531
532    def test_badsyntax_10(self):
533        # Tests for issue 24619
534
535        samples = [
536            """async def foo():
537                   def bar(): pass
538                   await = 1
539            """,
540
541            """async def foo():
542
543                   def bar(): pass
544                   await = 1
545            """,
546
547            """async def foo():
548                   def bar(): pass
549                   if 1:
550                       await = 1
551            """,
552
553            """def foo():
554                   async def bar(): pass
555                   if 1:
556                       await a
557            """,
558
559            """def foo():
560                   async def bar(): pass
561                   await a
562            """,
563
564            """def foo():
565                   def baz(): pass
566                   async def bar(): pass
567                   await a
568            """,
569
570            """def foo():
571                   def baz(): pass
572                   # 456
573                   async def bar(): pass
574                   # 123
575                   await a
576            """,
577
578            """async def foo():
579                   def baz(): pass
580                   # 456
581                   async def bar(): pass
582                   # 123
583                   await = 2
584            """,
585
586            """def foo():
587
588                   def baz(): pass
589
590                   async def bar(): pass
591
592                   await a
593            """,
594
595            """async def foo():
596
597                   def baz(): pass
598
599                   async def bar(): pass
600
601                   await = 2
602            """,
603
604            """async def foo():
605                   def async(): pass
606            """,
607
608            """async def foo():
609                   def await(): pass
610            """,
611
612            """async def foo():
613                   def bar():
614                       await
615            """,
616
617            """async def foo():
618                   return lambda async: await
619            """,
620
621            """async def foo():
622                   return lambda a: await
623            """,
624
625            """await a()""",
626
627            """async def foo(a=await b):
628                   pass
629            """,
630
631            """async def foo(a:await b):
632                   pass
633            """,
634
635            """def baz():
636                   async def foo(a=await b):
637                       pass
638            """,
639
640            """async def foo(async):
641                   pass
642            """,
643
644            """async def foo():
645                   def bar():
646                        def baz():
647                            async = 1
648            """,
649
650            """async def foo():
651                   def bar():
652                        def baz():
653                            pass
654                        async = 1
655            """,
656
657            """def foo():
658                   async def bar():
659
660                        async def baz():
661                            pass
662
663                        def baz():
664                            42
665
666                        async = 1
667            """,
668
669            """async def foo():
670                   def bar():
671                        def baz():
672                            pass\nawait foo()
673            """,
674
675            """def foo():
676                   def bar():
677                        async def baz():
678                            pass\nawait foo()
679            """,
680
681            """async def foo(await):
682                   pass
683            """,
684
685            """def foo():
686
687                   async def bar(): pass
688
689                   await a
690            """,
691
692            """def foo():
693                   async def bar():
694                        pass\nawait a
695            """]
696
697        for code in samples:
698            # assertRaises() differs in Py2.6, so use our own assertRaisesRegex() instead
699            with self.subTest(code=code), self.assertRaisesRegex(Errors.CompileError, '.'):
700                exec(code, {}, {})
701
702    if not hasattr(unittest.TestCase, 'subTest'):
703        @contextlib.contextmanager
704        def subTest(self, code, **kwargs):
705            try:
706                yield
707            except Exception:
708                print(code)
709                raise
710
711    def test_goodsyntax_1(self):
712        # Tests for issue 24619
713
714        def foo(await):
715            async def foo(): pass
716            async def foo():
717                pass
718            return await + 1
719        self.assertEqual(foo(10), 11)
720
721        def foo(await):
722            async def foo(): pass
723            async def foo(): pass
724            return await + 2
725        self.assertEqual(foo(20), 22)
726
727        def foo(await):
728
729            async def foo(): pass
730
731            async def foo(): pass
732
733            return await + 2
734        self.assertEqual(foo(20), 22)
735
736        def foo(await):
737            """spam"""
738            async def foo(): \
739                pass
740            # 123
741            async def foo(): pass
742            # 456
743            return await + 2
744        self.assertEqual(foo(20), 22)
745
746        def foo(await):
747            def foo(): pass
748            def foo(): pass
749            async def bar(): return await_
750            await_ = await
751            try:
752                bar().send(None)
753            except StopIteration as ex:
754                return ex.args[0]
755        self.assertEqual(foo(42), 42)
756
757        async def f(z):
758            async def g(): pass
759            await z
760        await = 1
761        #self.assertTrue(inspect.iscoroutinefunction(f))
762
763
class TokenizerRegrTest(unittest.TestCase):
    # Regression test: a long run of one-line definitions, optionally
    # followed by an ``async def``, must still compile and run.

    def test_oneline_defs(self):
        source = '\n'.join(
            'def i{i}(): return {i}'.format(i=i) for i in range(500))

        # Test that 500 consequent, one-line defs is OK
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)

        # Test that 500 consequent, one-line defs *and*
        # one 'async def' following them is OK
        source += '\nasync def foo():\n    return'
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)
        self.assertEqual(type(namespace['foo']()).__name__, 'coroutine')
        #self.assertTrue(inspect.iscoroutinefunction(ns['foo']))
785
786
787class CoroutineTest(unittest.TestCase):
788
789    @classmethod
790    def setUpClass(cls):
791        # never mark warnings as "already seen" to prevent them from being suppressed
792        from warnings import simplefilter
793        simplefilter("always")
794
795    @contextlib.contextmanager
796    def assertRaises(self, exc_type):
797        try:
798            yield
799        except exc_type:
800            self.assertTrue(True)
801        else:
802            self.assertTrue(False)
803
804    @contextlib.contextmanager
805    def assertRaisesRegex(self, exc_type, regex):
806        class Holder(object):
807            exception = None
808
809        holder = Holder()
810        # the error messages usually don't match, so we just ignore them
811        try:
812            yield holder
813        except exc_type as exc:
814            holder.exception = exc
815            self.assertTrue(True)
816        else:
817            self.assertTrue(False)
818
819    @contextlib.contextmanager
820    def assertWarnsRegex(self, exc_type, regex):
821        from warnings import catch_warnings
822        with catch_warnings(record=True) as log:
823            yield
824
825        first_match = None
826        for warning in log:
827            w = warning.message
828            if not isinstance(w, exc_type):
829                continue
830            if first_match is None:
831                first_match = w
832            if re.search(regex, str(w)):
833                self.assertTrue(True)
834                return
835
836        if first_match is None:
837            self.assertTrue(False, "no warning was raised of type '%s'" % exc_type.__name__)
838        else:
839            self.assertTrue(False, "'%s' did not match '%s'" % (first_match, regex))
840
841    if not hasattr(unittest.TestCase, 'assertRegex'):
842        def assertRegex(self, value, regex):
843            self.assertTrue(re.search(regex, str(value)),
844                            "'%s' did not match '%s'" % (value, regex))
845
846    if not hasattr(unittest.TestCase, 'assertIn'):
847        def assertIn(self, member, container, msg=None):
848            self.assertTrue(member in container, msg)
849
850    if not hasattr(unittest.TestCase, 'assertIsNone'):
851        def assertIsNone(self, value, msg=None):
852            self.assertTrue(value is None, msg)
853
854    if not hasattr(unittest.TestCase, 'assertIsNotNone'):
855        def assertIsNotNone(self, value, msg=None):
856            self.assertTrue(value is not None, msg)
857
858    if not hasattr(unittest.TestCase, 'assertIsInstance'):
859        def assertIsInstance(self, obj, cls, msg=None):
860            self.assertTrue(isinstance(obj, cls), msg)
861
    def test_gen_1(self):
        # A plain generator function object must not expose __await__.
        def gen(): yield
        self.assertFalse(hasattr(gen, '__await__'))
865
    def test_func_attributes(self):
        # The coroutine object carries the name, qualified name and module
        # of the ``async def`` that created it.
        async def foo():
            return 10

        # NOTE: ``f`` is never awaited; only its attributes are inspected.
        f = foo()
        self.assertEqual(f.__name__, 'foo')
        self.assertEqual(f.__qualname__, 'CoroutineTest.test_func_attributes.<locals>.foo')
        self.assertEqual(f.__module__, 'test_coroutines_pep492')
874
    def test_func_1(self):
        # Calling an ``async def`` yields a 'coroutine' object. Driven by
        # send() or via __await__(), it returns its value without
        # suspending.
        async def foo():
            return 10

        f = foo()
        self.assertEqual(f.__class__.__name__, 'coroutine')
        # The code-flag checks below are disabled in this copy of the test.
        #self.assertIsInstance(f, types.CoroutineType)
        #self.assertTrue(bool(foo.__code__.co_flags & 0x80))
        #self.assertTrue(bool(foo.__code__.co_flags & 0x20))
        #self.assertTrue(bool(f.cr_code.co_flags & 0x80))
        #self.assertTrue(bool(f.cr_code.co_flags & 0x20))
        self.assertEqual(run_async(f), ([], 10))

        self.assertEqual(run_async__await__(foo()), ([], 10))

        # A plain function must not carry the coroutine code flag.
        def bar(): pass
        self.assertFalse(bool(bar.__code__.co_flags & 0x80))  # inspect.CO_COROUTINE
892
893    # TODO
    def __test_func_2(self):
        # Disabled: the leading underscores keep unittest from collecting it.
        # It expects a StopIteration raised inside a coroutine to surface as
        # RuntimeError("coroutine raised StopIteration").
        async def foo():
            raise StopIteration

        with self.assertRaisesRegex(
                RuntimeError, "coroutine raised StopIteration"):

            run_async(foo())
902
    def test_func_3(self):
        # repr() of a coroutine object has the form
        # "<coroutine object ... at 0x...>". The unawaited coroutine is
        # created under silence_coro_gc() so its warning is suppressed.
        async def foo():
            raise StopIteration

        with silence_coro_gc():
            self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$')
909
    def test_func_4(self):
        # Coroutine objects are not iterable. Every way of iterating one
        # (list/tuple/sum/iter/next/for/comprehension) must raise TypeError.
        async def foo():
            raise StopIteration

        # Fresh context manager per use; the message regex is ignored by
        # this class's assertRaisesRegex().
        check = lambda: self.assertRaisesRegex(
            TypeError, "'coroutine' object is not iterable")

        with check():
            list(foo())

        with check():
            tuple(foo())

        with check():
            sum(foo())

        with check():
            iter(foo())

        with check():
            next(foo())

        with silence_coro_gc(), check():
            for i in foo():
                pass

        with silence_coro_gc(), check():
            [i for i in foo()]
938
    def test_func_5(self):
        # A native coroutine stays non-iterable even when it awaits a
        # types_coroutine generator. The generator-based coroutine itself
        # must remain an ordinary iterable.
        @types_coroutine
        def bar():
            yield 1

        async def foo():
            await bar()

        check = lambda: self.assertRaisesRegex(
            TypeError, "'coroutine' object is not iterable")

        with check():
            for el in foo(): pass

        # the following should pass without an error
        for el in bar():
            self.assertEqual(el, 1)
        self.assertEqual([el for el in bar()], [1])
        self.assertEqual(tuple(bar()), (1,))
        self.assertEqual(next(iter(bar())), 1)
959
    def test_func_6(self):
        # Values yielded by an awaited generator-based coroutine pass through
        # the outer coroutine's send(). It then finishes with StopIteration.
        @types_coroutine
        def bar():
            yield 1
            yield 2

        async def foo():
            await bar()

        f = foo()
        self.assertEqual(f.send(None), 1)
        self.assertEqual(f.send(None), 2)
        with self.assertRaises(StopIteration):
            f.send(None)
974
975    # TODO (or not? see test_func_8() below)
    def __test_func_7(self):
        # Disabled: the leading underscores keep unittest from collecting it
        # (compare test_func_8). It expects ``yield from`` on a coroutine
        # inside a plain generator to raise TypeError.
        async def bar():
            return 10

        def foo():
            yield from bar()

        with silence_coro_gc(), self.assertRaisesRegex(
            TypeError,
            "cannot 'yield from' a coroutine object in a non-coroutine generator"):

            list(foo())
988
    def test_func_8(self):
        # A types_coroutine-decorated generator may ``yield from`` a native
        # coroutine and receives its return value.
        @types_coroutine
        def bar():
            return (yield from foo())

        async def foo():
            return 'spam'

        self.assertEqual(run_async(bar()), ([], 'spam') )
998
    def test_func_9(self):
        # Dropping a coroutine that was never awaited emits a RuntimeWarning
        # naming it. The gc.collect() calls make sure the object is finalized
        # inside the assertWarnsRegex() block.
        async def foo(): pass

        gc.collect()
        with self.assertWarnsRegex(
            RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):

            foo()
            gc.collect()
1008
    def test_func_10(self):
        # The iterator returned by __await__() is its own iterator. It
        # forwards send()/close()/throw() to the awaited generator; ``N``
        # records which cleanup paths of that generator ran.
        N = 0

        @types_coroutine
        def gen():
            nonlocal N
            try:
                a = yield
                yield (a ** 2)
            except ZeroDivisionError:
                N += 100
                raise
            finally:
                N += 1

        async def foo():
            await gen()

        coro = foo()
        aw = coro.__await__()
        self.assertTrue(aw is iter(aw))
        next(aw)
        self.assertEqual(aw.send(10), 100)
        # Unbound methods called with a non-wrapper ``self`` must be rejected.
        with self.assertRaises(TypeError):   # removed from CPython test suite?
            type(aw).send(None, None)

        # close() runs only the generator's ``finally`` clause (+1).
        self.assertEqual(N, 0)
        aw.close()
        self.assertEqual(N, 1)
        with self.assertRaises(TypeError):   # removed from CPython test suite?
            type(aw).close(None)

        # throw() runs the ``except`` clause (+100) and ``finally`` (+1).
        coro = foo()
        aw = coro.__await__()
        next(aw)
        with self.assertRaises(ZeroDivisionError):
            aw.throw(ZeroDivisionError, None, None)
        self.assertEqual(N, 102)
        with self.assertRaises(TypeError):   # removed from CPython test suite?
            type(aw).throw(None, None, None, None)
1049
    def test_func_11(self):
        # Introspection sanity checks on a coroutine and its await iterator.
        async def func(): pass
        coro = func()
        # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly
        # initialized
        self.assertIn('__await__', dir(coro))
        self.assertIn('__iter__', dir(coro.__await__()))
        self.assertIn('coroutine_wrapper', repr(coro.__await__()))
        coro.close() # avoid RuntimeWarning
1059
    def test_func_12(self):
        # Re-entering a running coroutine through send() must raise
        # ValueError. The ``await None`` line is never reached.
        async def g():
            i = me.send(None)
            await None
        me = g()
        with self.assertRaisesRegex(ValueError,
                                    "coroutine already executing"):
            me.send(None)
1068
    def test_func_13(self):
        # The first send() to a fresh coroutine must be None; anything else
        # raises TypeError.
        async def g():
            pass
        with self.assertRaisesRegex(
            TypeError,
            "can't send non-None value to a just-started coroutine"):

            g().send('spam')
1077
    def test_func_14(self):
        # A coroutine that answers close() (GeneratorExit) by suspending
        # again must make close() raise RuntimeError.
        @types_coroutine
        def gen():
            yield
        async def coro():
            try:
                await gen()
            except GeneratorExit:
                await gen()
        c = coro()
        c.send(None)
        with self.assertRaisesRegex(RuntimeError,
                                    "coroutine ignored GeneratorExit"):
            c.close()
1092
    def test_func_15(self):
        # See http://bugs.python.org/issue25887 for details
        # Awaiting an already-finished coroutine a second time (from another
        # coroutine) must raise RuntimeError instead of silently returning.

        async def spammer():
            return 'spam'
        async def reader(coro):
            return await coro

        spammer_coro = spammer()

        with self.assertRaisesRegex(StopIteration, 'spam'):
            reader(spammer_coro).send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader(spammer_coro).send(None)
1109
    def test_func_16(self):
        # See http://bugs.python.org/issue25887 for details
        # Like test_func_15, but the inner coroutine finishes via an
        # exception thrown through the first reader. Reusing it through both
        # send() and throw() from a second reader must raise RuntimeError.

        @types_coroutine
        def nop():
            yield
        async def send():
            await nop()
            return 'spam'
        async def read(coro):
            await nop()
            return await coro

        spammer = send()

        reader = read(spammer)
        reader.send(None)
        reader.send(None)
        with self.assertRaisesRegex(Exception, 'ham'):
            reader.throw(Exception('ham'))

        reader = read(spammer)
        reader.send(None)
        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader.throw(Exception('wat'))
1140
1141    def test_func_17(self):
1142        # See http://bugs.python.org/issue25887 for details
1143
1144        async def coroutine():
1145            return 'spam'
1146
1147        coro = coroutine()
1148        with self.assertRaisesRegex(StopIteration, 'spam'):
1149            coro.send(None)
1150
1151        with self.assertRaisesRegex(RuntimeError,
1152                                    'cannot reuse already awaited coroutine'):
1153            coro.send(None)
1154
1155        with self.assertRaisesRegex(RuntimeError,
1156                                    'cannot reuse already awaited coroutine'):
1157            coro.throw(Exception('wat'))
1158
1159        # Closing a coroutine shouldn't raise any exception even if it's
1160        # already closed/exhausted (similar to generators)
1161        coro.close()
1162        coro.close()
1163
    def test_func_18(self):
        # See http://bugs.python.org/issue25887 for details
        # The iterator obtained from coro.__await__() must refuse to be
        # resumed after the coroutine finished: send(), next() and throw()
        # all raise RuntimeError, while close() stays a silent no-op.

        async def coroutine():
            return 'spam'

        coro = coroutine()
        await_iter = coro.__await__()
        it = iter(await_iter)

        with self.assertRaisesRegex(StopIteration, 'spam'):
            it.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            # Although the iterator protocol requires iterators to
            # raise another StopIteration here, we don't want to do
            # that.  In this particular case, the iterator will raise
            # a RuntimeError, so that 'yield from' and 'await'
            # expressions will trigger the error, instead of silently
            # ignoring the call.
            next(it)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.throw(Exception('wat'))

        # Repeated on purpose: the error must not change on a second throw().
        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.throw(Exception('wat'))

        # Closing a coroutine shouldn't raise any exception even if it's
        # already closed/exhausted (similar to generators)
        it.close()
        it.close()
1203
    def test_func_19(self):
        # close() on a coroutine suspended inside an awaited generator
        # delivers GeneratorExit to that generator exactly once; later
        # close() calls are no-ops.
        CHK = 0  # counts GeneratorExit deliveries to foo()

        @types_coroutine
        def foo():
            nonlocal CHK
            yield
            try:
                yield
            except GeneratorExit:
                CHK += 1

        async def coroutine():
            await foo()

        coro = coroutine()

        coro.send(None)  # suspends at foo()'s first yield
        coro.send(None)  # suspends at the yield inside the try block

        self.assertEqual(CHK, 0)
        coro.close()
        self.assertEqual(CHK, 1)

        for _ in range(3):
            # Closing a coroutine shouldn't raise any exception even if it's
            # already closed/exhausted (similar to generators)
            coro.close()
            self.assertEqual(CHK, 1)
1233
1234    def test_coro_wrapper_send_tuple(self):
1235        async def foo():
1236            return (10,)
1237
1238        result = run_async__await__(foo())
1239        self.assertEqual(result, ([], (10,)))
1240
1241    def test_coro_wrapper_send_stop_iterator(self):
1242        async def foo():
1243            return StopIteration(10)
1244
1245        result = run_async__await__(foo())
1246        self.assertIsInstance(result[1], StopIteration)
1247        if sys.version_info >= (3, 3):
1248            self.assertEqual(result[1].value, 10)
1249        else:
1250            self.assertEqual(result[1].args[0], 10)
1251
    def test_cr_await(self):
        # cr_await is None whenever coro_b is not suspended in an await
        # (before start, while running, after completion) and references the
        # awaited object chain (coro_b -> c() -> a()) while it is suspended.
        @types_coroutine
        def a():
            #self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)
            yield
            #self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            # FIXME: no idea why the following works in CPython:
            #self.assertIsNone(coro_b.cr_await)

        async def c():
            await a()

        async def b():
            self.assertIsNone(coro_b.cr_await)
            await c()
            self.assertIsNone(coro_b.cr_await)

        coro_b = b()
        #self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CREATED)
        self.assertIsNone(coro_b.cr_await)

        coro_b.send(None)  # suspends at a()'s yield
        #self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_SUSPENDED)
        #self.assertEqual(coro_b.cr_await.cr_await.gi_code.co_name, 'a')
        self.assertIsNotNone(coro_b.cr_await.cr_await)
        self.assertEqual(coro_b.cr_await.cr_await.__name__, 'a')

        with self.assertRaises(StopIteration):
            coro_b.send(None)  # complete coroutine
        #self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CLOSED)
        self.assertIsNone(coro_b.cr_await)
1284
1285    def test_corotype_1(self):
1286        async def f(): pass
1287        ct = type(f())
1288        self.assertIn('into coroutine', ct.send.__doc__)
1289        self.assertIn('inside coroutine', ct.close.__doc__)
1290        self.assertIn('in coroutine', ct.throw.__doc__)
1291        self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__)
1292        self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__)
1293        self.assertEqual(ct.__name__, 'coroutine')
1294
1295        async def f(): pass
1296        c = f()
1297        self.assertIn('coroutine object', repr(c))
1298        c.close()
1299
1300    def test_await_1(self):
1301
1302        async def foo():
1303            await 1
1304        with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
1305            run_async(foo())
1306
1307    def test_await_2(self):
1308        async def foo():
1309            await []
1310        with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
1311            run_async(foo())
1312
1313    def test_await_3(self):
1314        async def foo():
1315            await AsyncYieldFrom([1, 2, 3])
1316
1317        self.assertEqual(run_async(foo()), ([1, 2, 3], None))
1318        self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None))
1319
1320    def test_await_4(self):
1321        async def bar():
1322            return 42
1323
1324        async def foo():
1325            return await bar()
1326
1327        self.assertEqual(run_async(foo()), ([], 42))
1328
    def test_await_5(self):
        # __await__ returning None (not an iterator) must raise TypeError.
        class Awaitable(object):
            def __await__(self):
                return

        async def foo():
            return (await Awaitable())

        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())
1341
1342    def test_await_6(self):
1343        class Awaitable(object):
1344            def __await__(self):
1345                return iter([52])
1346
1347        async def foo():
1348            return (await Awaitable())
1349
1350        self.assertEqual(run_async(foo()), ([52], None))
1351
1352    def test_await_7(self):
1353        class Awaitable(object):
1354            def __await__(self):
1355                yield 42
1356                return 100
1357
1358        async def foo():
1359            return (await Awaitable())
1360
1361        self.assertEqual(run_async(foo()), ([42], 100))
1362
    def test_await_8(self):
        # An object without __await__ cannot be awaited; the error message
        # names its type (hence the class name is part of the regex).
        class Awaitable(object):
            pass

        async def foo(): return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, "object Awaitable can't be used in 'await' expression"):

            run_async(foo())
1373
    def test_await_9(self):
        # `await` works as an operand of arithmetic and on call results,
        # subscripts and attribute lookups; each await yields 42 here, so
        # foo() returns 42 + 42 + 42 + 42 * 1000 + 42 == 42168.
        def wrap():
            return bar

        async def bar():
            return 42

        async def foo():
            b = bar()  # created but never awaited

            db = {'b':  lambda: wrap}

            class DB(object):
                b = staticmethod(wrap)

            return (await bar() + await wrap()() + await db['b']()()() +
                    await bar() * 1000 + await DB.b()())

        async def foo2():
            # unary minus applied to an await expression
            return -await bar()

        self.assertEqual(run_async(foo()), ([], 42168))
        self.assertEqual(run_async(foo2()), ([], -42))
1397
1398    def test_await_10(self):
1399        async def baz():
1400            return 42
1401
1402        async def bar():
1403            return baz()
1404
1405        async def foo():
1406            return await (await bar())
1407
1408        self.assertEqual(run_async(foo()), ([], 42))
1409
    def test_await_11(self):
        # `await` is accepted inside a keyword argument and as an element of
        # a returned tuple.
        def ident(val):
            return val

        async def bar():
            return 'spam'

        # NOTE(review): foo() is only compiled, never run below - presumably
        # this is a syntax-only check kept as in upstream CPython; confirm.
        async def foo():
            return ident(val=await bar())

        async def foo2():
            return await bar(), 'ham'

        self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))
1424
    def test_await_12(self):
        # __await__ must not return a coroutine object; doing so is a
        # TypeError.
        async def coro():
            return 'spam'

        class Awaitable(object):
            def __await__(self):
                return coro()

        async def foo():
            return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, r"__await__\(\) returned a coroutine"):

            run_async(foo())
1440
    def test_await_13(self):
        # __await__ returning self, which is not an iterator, must raise
        # TypeError.
        class Awaitable(object):
            def __await__(self):
                return self

        async def foo():
            return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())
1453
    def test_await_14(self):
        # Values sent and exceptions thrown into coro2 must pass through the
        # coroutine's __await__() iterator (via Wrapper) down to FutureLike,
        # and the results/exceptions must come back out.
        class Wrapper(object):
            # Forces the interpreter to use CoroutineType.__await__
            def __init__(self, coro):
                self.coro = coro
            def __await__(self):
                return self.coro.__await__()

        class FutureLike(object):
            def __await__(self):
                return (yield)

        class Marker(Exception):
            pass

        async def coro1():
            try:
                return await FutureLike()
            except ZeroDivisionError:
                raise Marker
        async def coro2():
            return await Wrapper(coro1())

        # A sent value becomes FutureLike's result and ends both coroutines.
        c = coro2()
        c.send(None)
        with self.assertRaisesRegex(StopIteration, 'spam'):
            c.send('spam')

        # A thrown exception reaches coro1's except clause.
        c = coro2()
        c.send(None)
        with self.assertRaises(Marker):
            c.throw(ZeroDivisionError)
1486
    def test_await_15(self):
        # Awaiting a coroutine that is currently suspended (it was started
        # directly via send()) from another coroutine raises RuntimeError.
        @types_coroutine
        def nop():
            yield

        async def coroutine():
            await nop()

        async def waiter(coro):
            await coro

        coro = coroutine()
        coro.send(None)  # now suspended inside nop()

        with self.assertRaisesRegex(RuntimeError,
                                    "coroutine is being awaited already"):
            waiter(coro).send(None)
1504
    def test_await_16(self):
        # See https://bugs.python.org/issue29600 for details.
        # An exception object returned (not raised) through an await inside
        # an except block must not pick up the handled exception as its
        # __context__.

        async def f():
            return ValueError()

        async def g():
            try:
                raise KeyError
            except:
                return await f()

        _, result = run_async(g())
        if sys.version_info[0] >= 3:  # __context__ only exists on Py3
            self.assertIsNone(result.__context__)
1520
    # removed from CPython ?
    # Disabled: the name does not start with 'test', so unittest does not
    # collect it.
    def __test_await_iterator(self):
        # coro.__await__() returns a 'coroutine_wrapper' iterator that cannot
        # be instantiated directly and raises StopIteration carrying the
        # coroutine's return value.
        async def foo():
            return 123

        coro = foo()
        it = coro.__await__()
        self.assertEqual(type(it).__name__, 'coroutine_wrapper')

        with self.assertRaisesRegex(TypeError, "cannot instantiate 'coroutine_wrapper' type"):
            type(it)()  # cannot instantiate

        with self.assertRaisesRegex(StopIteration, "123"):
            next(it)
1535
    def test_with_1(self):
        # `async with` over several managers enters them left to right and
        # exits them in reverse order. An __aexit__ returning True (manager
        # "B") suppresses the body's exception; otherwise it propagates.
        class Manager(object):
            def __init__(self, name):
                self.name = name

            async def __aenter__(self):
                await AsyncYieldFrom(['enter-1-' + self.name,
                                      'enter-2-' + self.name])
                return self

            async def __aexit__(self, *args):
                await AsyncYieldFrom(['exit-1-' + self.name,
                                      'exit-2-' + self.name])

                if self.name == 'B':
                    return True


        async def foo():
            async with Manager("A") as a, Manager("B") as b:
                await AsyncYieldFrom([('managers', a.name, b.name)])
                1/0

        f = foo()
        result, _ = run_async(f)

        self.assertEqual(
            result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
                     ('managers', 'A', 'B'),
                     'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
        )

        # Neither "A" nor "C" suppresses, so the ZeroDivisionError escapes.
        async def foo():
            async with Manager("A") as a, Manager("C") as c:
                await AsyncYieldFrom([('managers', a.name, c.name)])
                1/0

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
1575
1576    def test_with_2(self):
1577        class CM(object):
1578            def __aenter__(self):
1579                pass
1580
1581        async def foo():
1582            async with CM():
1583                pass
1584
1585        with self.assertRaisesRegex(AttributeError, '__aexit__'):
1586            run_async(foo())
1587
1588    def test_with_3(self):
1589        class CM(object):
1590            def __aexit__(self):
1591                pass
1592
1593        async def foo():
1594            async with CM():
1595                pass
1596
1597        with self.assertRaisesRegex(AttributeError, '__aenter__'):
1598            run_async(foo())
1599
1600    def test_with_4(self):
1601        class CM(object):
1602            def __enter__(self):
1603                pass
1604
1605            def __exit__(self):
1606                pass
1607
1608        async def foo():
1609            async with CM():
1610                pass
1611
1612        with self.assertRaisesRegex(AttributeError, '__aexit__'):
1613            run_async(foo())
1614
    def test_with_5(self):
        # While this test doesn't make a lot of sense,
        # it's a regression test for an early bug with opcodes
        # generation
        # (A failing assert inside `async with` must surface as
        # AssertionError.)

        class CM(object):
            async def __aenter__(self):
                return self

            async def __aexit__(self, *exc):
                pass

        async def func():
            async with CM():
                assert (1, ) == 1

        with self.assertRaises(AssertionError):
            run_async(func())
1633
    def test_with_6(self):
        # __aenter__ returning a non-awaitable (an int) raises TypeError
        # when `async with` tries to await it.
        class CM(object):
            def __aenter__(self):
                return 123

            def __aexit__(self, *e):
                return 456

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(
            TypeError, "object int can't be used in 'await' expression"):
            # it's important that __aexit__ wasn't called
            run_async(foo())
1650
    def test_with_7(self):
        # __aexit__ returning a non-awaitable while the body is raising:
        # the resulting TypeError must chain the body's ZeroDivisionError
        # as its __context__ (checked on Py3 only).
        class CM(object):
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 444

        async def foo():
            async with CM():
                1/0

        try:
            run_async(foo())
        except TypeError as exc:
            self.assertRegex(
                exc.args[0], "object int can't be used in 'await' expression")
            if sys.version_info[0] >= 3:
                self.assertTrue(exc.__context__ is not None)
                self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
        else:
            self.fail('invalid asynchronous context manager did not fail')
1673
1674
    def test_with_8(self):
        # On the normal (non-exceptional) exit path, a non-awaitable result
        # from __aexit__ raises TypeError; the body ran exactly once.
        CNT = 0

        class CM(object):
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 456

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1


        with self.assertRaisesRegex(
            TypeError, "object int can't be used in 'await' expression"):

            run_async(foo())

        self.assertEqual(CNT, 1)
1697
    def test_with_9(self):
        # An exception raised by __aexit__ after a successful body
        # propagates; the body ran exactly once.
        CNT = 0

        class CM(object):
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

        self.assertEqual(CNT, 1)
1717
    def test_with_10(self):
        # Nested `async with` whose __aexit__ raises: the ZeroDivisionError
        # must propagate out. The upstream checks of the exception's
        # __context__ chain are disabled below (FIXME).
        CNT = 0

        class CM(object):
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                async with CM():
                    raise RuntimeError

        try:
            run_async(foo())
        except ZeroDivisionError as exc:
            pass  # FIXME!
            #if sys.version_info[0] >= 3:
            #    self.assertTrue(exc.__context__ is not None)
            #    self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
            #    self.assertTrue(isinstance(exc.__context__.__context__, RuntimeError))
        else:
            self.fail('exception from __aexit__ did not propagate')
1744
    def test_with_11(self):
        # An exception from __aenter__ propagates unchanged and without a
        # __context__. Because __aexit__ would raise ZeroDivisionError,
        # receiving NotImplementedError shows __aexit__ was not run.
        CNT = 0

        class CM(object):
            async def __aenter__(self):
                raise NotImplementedError

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                raise RuntimeError

        try:
            run_async(foo())
        except NotImplementedError as exc:
            if sys.version_info[0] >= 3:
                self.assertTrue(exc.__context__ is None)
        else:
            self.fail('exception from __aenter__ did not propagate')
1767
    def test_with_12(self):
        # `as` binds the awaited __aenter__ result, and an __aexit__
        # returning True suppresses the body's RuntimeError.
        CNT = 0

        class CM(object):
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                return True

        async def foo():
            nonlocal CNT
            async with CM() as cm:
                self.assertIs(cm.__class__, CM)
                raise RuntimeError

        run_async(foo())
1785
    def test_with_13(self):
        # An exception in __aenter__ skips both the body and the code after
        # the `async with` (only the initial CNT += 1 runs). This holds even
        # though this __aexit__ would return True.
        CNT = 0

        class CM(object):
            async def __aenter__(self):
                1/0

            async def __aexit__(self, *e):
                return True

        async def foo():
            nonlocal CNT
            CNT += 1
            async with CM():
                CNT += 1000
            CNT += 10000

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 1)
1806
    # old-style pre-Py3.5.2 protocol - no longer supported
    # (disabled: the name does not start with 'test', so unittest does not
    # collect it)
    def __test_for_1(self):
        # `async for` with an awaitable __aiter__ (legacy protocol) covering
        # plain iteration, `break` and `continue`, plus the for-else clause.
        aiter_calls = 0

        class AsyncIter(object):
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                nonlocal aiter_calls
                aiter_calls += 1
                return self

            async def __anext__(self):
                self.i += 1

                if not (self.i % 10):
                    await AsyncYield(self.i * 10)

                if self.i > 100:
                    raise StopAsyncIteration

                return self.i, self.i


        buffer = []
        async def test1():
            with ignore_py26(self.assertWarnsRegex(DeprecationWarning, "legacy")):
                async for i1, i2 in AsyncIter():
                    buffer.append(i1 + i2)

        yielded, _ = run_async(test1())
        # Make sure that __aiter__ was called only once
        self.assertEqual(aiter_calls, 1)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i*2 for i in range(1, 101)])


        # `break` stops iteration and skips the else clause.
        buffer = []
        async def test2():
            nonlocal buffer
            with ignore_py26(self.assertWarnsRegex(DeprecationWarning, "legacy")):
                async for i in AsyncIter():
                    buffer.append(i[0])
                    if i[0] == 20:
                        break
                else:
                    buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test2())
        # Make sure that __aiter__ was called only once
        self.assertEqual(aiter_calls, 2)
        self.assertEqual(yielded, [100, 200])
        self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])


        # `continue` keeps iterating, so the else clause runs at the end.
        buffer = []
        async def test3():
            nonlocal buffer
            with ignore_py26(self.assertWarnsRegex(DeprecationWarning, "legacy")):
                async for i in AsyncIter():
                    if i[0] > 20:
                        continue
                    buffer.append(i[0])
                else:
                    buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test3())
        # Make sure that __aiter__ was called only once
        self.assertEqual(aiter_calls, 3)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i for i in range(1, 21)] +
                                 ['what?', 'end'])
1882
    def test_for_2(self):
        # `async for` over an object without __aiter__ (a tuple) raises
        # TypeError, and the failure must not leave the tuple's refcount
        # changed (no leaked reference).
        tup = (1, 2, 3)
        refs_before = getrefcount(tup)

        async def foo():
            async for i in tup:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError, "async for' requires an object.*__aiter__.*tuple"):

            run_async(foo())

        self.assertEqual(getrefcount(tup), refs_before)
1897
1898    def test_for_3(self):
1899        class I(object):
1900            def __aiter__(self):
1901                return self
1902
1903        aiter = I()
1904        refs_before = getrefcount(aiter)
1905
1906        async def foo():
1907            async for i in aiter:
1908                print('never going to happen')
1909
1910        with self.assertRaisesRegex(
1911                TypeError,
1912                "async for' received an invalid object.*__aiter.*\: I"):
1913
1914            run_async(foo())
1915
1916        self.assertEqual(getrefcount(aiter), refs_before)
1917
    def test_for_4(self):
        # __anext__ returning a non-awaitable (an empty tuple) raises
        # TypeError without leaking a reference to the iterator.
        class I(object):
            def __aiter__(self):
                return self

            def __anext__(self):
                return ()

        aiter = I()
        refs_before = getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                "async for' received an invalid object.*__anext__.*tuple"):

            run_async(foo())

        self.assertEqual(getrefcount(aiter), refs_before)
1940
    def test_for_5(self):
        # __anext__ returning a non-awaitable int must raise TypeError.
        # NOTE(review): this uses an awaitable (async def) __aiter__ and
        # expects a "legacy" DeprecationWarning. That is the pre-Py3.5.2
        # protocol the disabled `__test_for_*` variants treat as no longer
        # supported, yet this test is still enabled - confirm intended.
        class I(object):
            async def __aiter__(self):
                return self

            def __anext__(self):
                return 123

        async def foo():
            with self.assertWarnsRegex(DeprecationWarning, "legacy"):
                async for i in I():
                    print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                "async for' received an invalid object.*__anext.*int"):

            run_async(foo())
1959
    def test_for_6(self):
        # `async for` nested inside `async with`, including the for-else
        # clause. I accumulates distinct decimal "digits" so the final value
        # encodes exactly which parts ran and how often. The iterable is
        # exhausted after 11 items; __aenter__ adds 10000, __aexit__ adds
        # 100000.
        I = 0

        class Manager(object):
            async def __aenter__(self):
                nonlocal I
                I += 10000

            async def __aexit__(self, *args):
                nonlocal I
                I += 100000

        class Iterable(object):
            def __init__(self):
                self.i = 0

            def __aiter__(self):
                return self

            async def __anext__(self):
                if self.i > 10:
                    raise StopAsyncIteration
                self.i += 1
                return self.i

        ##############

        # One pass: 10000 + 11 + 100000 + 1000 == 111011, and no references
        # to the manager or iterable may be leaked.
        manager = Manager()
        iterable = Iterable()
        mrefs_before = getrefcount(manager)
        irefs_before = getrefcount(iterable)

        async def main():
            nonlocal I

            async with manager:
                async for i in iterable:
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 111011)

        self.assertEqual(getrefcount(manager), mrefs_before)
        self.assertEqual(getrefcount(iterable), irefs_before)

        ##############

        # Two more passes with fresh objects: 111011 + 2 * 111011 == 333033.
        async def main():
            nonlocal I

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 333033)

        ##############

        # Two passes that also run the for-else clause (no break):
        # 333033 + 2 * (10000 + 100 + 11 + 10000000 + 100000 + 1000)
        # == 20555255.
        async def main():
            nonlocal I

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

        run_async(main())
        self.assertEqual(I, 20555255)
2047
    # old-style pre-Py3.5.2 protocol - no longer supported
    # (disabled: the name does not start with 'test', so unittest does not
    # collect it)
    def __test_for_7(self):
        # An exception from an awaitable (legacy) __aiter__ propagates before
        # the loop body or the code after the loop runs.
        CNT = 0
        class AI(object):
            async def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            with self.assertWarnsRegex(DeprecationWarning, "legacy"):
                async for i in AI():
                    CNT += 1
            CNT += 10
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 0)
2063
    def test_for_8(self):
        # An exception from a synchronous __aiter__ propagates before the
        # loop body or the code after the loop runs (CNT stays 0).
        CNT = 0
        class AI(object):
            def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            async for i in AI():
                CNT += 1
            CNT += 10
        with self.assertRaises(ZeroDivisionError):
            #run_async(foo())
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                # Test that if __aiter__ raises an exception it propagates
                # without any kind of warning.
                run_async(foo())
        self.assertEqual(CNT, 0)
2082
    # old-style pre-Py3.5.2 protocol - no longer supported
    # (disabled: the name does not start with 'test', so unittest does not
    # collect it)
    @min_py27
    def __test_for_9(self):
        # Test that DeprecationWarning can safely be converted into
        # an exception (__aiter__ should not have a chance to raise
        # a ZeroDivisionError.)
        class AI(object):
            async def __aiter__(self):
                1/0
        async def foo():
            async for i in AI():
                pass

        with self.assertRaises(DeprecationWarning):
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                run_async(foo())
2100
    # old-style pre-Py3.5.2 protocol - no longer supported
    # (disabled: the name does not start with 'test', so unittest does not
    # collect it)
    @min_py27
    def __test_for_10(self):
        # Test that DeprecationWarning can safely be converted into
        # an exception.
        class AI(object):
            async def __aiter__(self):
                pass
        async def foo():
            async for i in AI():
                pass

        with self.assertRaises(DeprecationWarning):
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                run_async(foo())
2117
    def test_for_11(self):
        # __anext__ returns an object whose __await__ raises. On Py3 the
        # error is a TypeError with the ZeroDivisionError as __cause__; on
        # Py2 the ZeroDivisionError surfaces directly.
        class F(object):
            def __aiter__(self):
                return self
            def __anext__(self):
                return self
            def __await__(self):
                1 / 0

        async def main():
            async for _ in F():
                pass

        if sys.version_info[0] < 3:
            with self.assertRaises(ZeroDivisionError) as c:
                main().send(None)
        else:
            with self.assertRaisesRegex(TypeError,
                                        'an invalid object from __anext__') as c:
                main().send(None)

            err = c.exception
            self.assertIsInstance(err.__cause__, ZeroDivisionError)
2141
    # old-style pre-Py3.5.2 protocol - no longer supported
    # (Disabled: without a leading "test" in its name, unittest's default
    # method prefix does not collect it.)
    def __test_for_12(self):
        # Presumably: under the old protocol the result of __aiter__ was
        # awaited, so F.__await__ runs and raises.  On Py3 that was expected
        # to surface as a TypeError mentioning __aiter__, chained to the
        # ZeroDivisionError.
        class F(object):
            def __aiter__(self):
                return self
            def __await__(self):
                1 / 0

        async def main():
            async for _ in F():
                pass

        if sys.version_info[0] < 3:
            # Py2 has no __cause__ chaining, so only the raw error is checked.
            with self.assertRaises(ZeroDivisionError) as c:
                main().send(None)
        else:
            with self.assertRaisesRegex(TypeError,
                                        'an invalid object from __aiter__') as c:
                main().send(None)

            err = c.exception
            self.assertIsInstance(err.__cause__, ZeroDivisionError)
2164
    def test_for_tuple(self):
        # ``async for`` over a tuple subclass that also implements the async
        # iteration protocol: the loop must be driven by __aiter__/__anext__.
        # (Presumably guards against a special-cased path for tuples.)
        class Done(Exception): pass

        class AIter(tuple):
            i = 0
            def __aiter__(self):
                return self
            async def __anext__(self):
                if self.i >= len(self):
                    raise StopAsyncIteration
                self.i += 1
                return self[self.i - 1]

        result = []
        async def foo():
            async for i in AIter([42]):
                result.append(i)
            # Signals that the loop ended normally.
            raise Done

        # __anext__ never suspends, so a single send() runs foo() to the end.
        with self.assertRaises(Done):
            foo().send(None)
        self.assertEqual(result, [42])
2187
    def test_for_stop_iteration(self):
        # The async iterator object is itself a StopIteration instance.
        # Presumably this checks that the implementation does not confuse
        # such an object with a StopIteration raised to end iteration.
        class Done(Exception): pass

        class AIter(StopIteration):
            i = 0
            def __aiter__(self):
                return self
            async def __anext__(self):
                if self.i:
                    raise StopAsyncIteration
                self.i += 1
                # StopIteration.value only exists from Python 3.3 on.
                if sys.version_info >= (3, 3):
                    return self.value
                else:
                    return self.args[0]

        result = []
        async def foo():
            async for i in AIter(42):
                result.append(i)
            # Signals that the loop ended normally.
            raise Done

        # __anext__ never suspends, so a single send() runs foo() to the end.
        with self.assertRaises(Done):
            foo().send(None)
        self.assertEqual(result, [42])
2213
    def test_comp_1(self):
        # ``await`` in the element, key and value expressions of list, set
        # and dict comprehensions inside a coroutine.
        # run_async() is defined elsewhere in this file.  The expected values
        # pair an empty list with the result -- presumably "nothing was
        # yielded to the caller" plus the coroutine's return value.
        async def f(i):
            return i

        async def run_list():
            return [await c for c in [f(1), f(41)]]

        async def run_set():
            return {await c for c in [f(1), f(41)]}

        async def run_dict1():
            return {await c: 'a' for c in [f(1), f(41)]}

        async def run_dict2():
            return {i: await c for i, c in enumerate([f(1), f(41)])}

        self.assertEqual(run_async(run_list()), ([], [1, 41]))
        self.assertEqual(run_async(run_set()), ([], {1, 41}))
        self.assertEqual(run_async(run_dict1()), ([], {1: 'a', 41: 'a'}))
        self.assertEqual(run_async(run_dict2()), ([], {0: 1, 1: 41}))
2234
    def test_comp_2(self):
        # ``await`` in the iterable of nested (non-first) ``for`` clauses,
        # including awaiting objects produced by an earlier clause.
        async def f(i):
            return i

        async def run_list():
            return [s for c in [f(''), f('abc'), f(''), f(['de', 'fg'])]
                    for s in await c]

        self.assertEqual(
            run_async(run_list()),
            ([], ['a', 'b', 'c', 'de', 'fg']))

        # Two levels of awaited iterables: c resolves to a list of
        # awaitables s, and each s resolves to a list of numbers.
        async def run_set():
            return {d
                    for c in [f([f([10, 30]),
                                 f([20])])]
                    for s in await c
                    for d in await s}

        self.assertEqual(
            run_async(run_set()),
            ([], {10, 20, 30}))

        # ``await`` both in the element expression and in a nested iterable.
        async def run_set2():
            return {await s
                    for c in [f([f(10), f(20)])]
                    for s in await c}

        self.assertEqual(
            run_async(run_set2()),
            ([], {10, 20}))
2266
    def test_comp_3(self):
        # ``async for`` in list, set and dict comprehensions and in an async
        # generator expression, all iterating an async generator.
        async def f(it):
            # ``async def`` containing ``yield`` makes this an async generator.
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for i in f([10, 20])]
        self.assertEqual(
            run_async(run_list()),
            ([], [11, 21]))

        async def run_set():
            return {i + 1 async for i in f([10, 20])}
        self.assertEqual(
            run_async(run_set()),
            ([], {11, 21}))

        async def run_dict():
            return {i + 1: i + 2 async for i in f([10, 20])}
        self.assertEqual(
            run_async(run_dict()),
            ([], {11: 12, 21: 22}))

        # An async generator expression consumed by another async
        # comprehension.
        async def run_gen():
            gen = (i + 1 async for i in f([10, 20]))
            return [g + 100 async for g in gen]
        self.assertEqual(
            run_async(run_gen()),
            ([], [111, 121]))
2296
    def test_comp_4(self):
        # Same shapes as test_comp_3, plus a trailing ``if`` filter on the
        # ``async for`` clause (only i == 20 passes ``i > 10``).
        async def f(it):
            # ``async def`` containing ``yield`` makes this an async generator.
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for i in f([10, 20]) if i > 10]
        self.assertEqual(
            run_async(run_list()),
            ([], [21]))

        async def run_set():
            return {i + 1 async for i in f([10, 20]) if i > 10}
        self.assertEqual(
            run_async(run_set()),
            ([], {21}))

        async def run_dict():
            return {i + 1: i + 2 async for i in f([10, 20]) if i > 10}
        self.assertEqual(
            run_async(run_dict()),
            ([], {21: 22}))

        async def run_gen():
            gen = (i + 1 async for i in f([10, 20]) if i > 10)
            return [g + 100 async for g in gen]
        self.assertEqual(
            run_async(run_gen()),
            ([], [121]))
2326
    def test_comp_5(self):
        # A synchronous ``for ... if`` clause followed by an
        # ``async for ... if`` clause in one comprehension.  Only the pair
        # [30, 40] passes ``pair[0] > 10`` and only 40 passes ``i > 30``,
        # giving [41].
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 for pair in ([10, 20], [30, 40]) if pair[0] > 10
                    async for i in f(pair) if i > 30]
        self.assertEqual(
            run_async(run_list()),
            ([], [41]))
2338
    def test_comp_6(self):
        # An outer ``async for`` clause followed by an inner synchronous
        # ``for`` clause that iterates each yielded tuple.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for seq in f([(10, 20), (30,)])
                    for i in seq]

        self.assertEqual(
            run_async(run_list()),
            ([], [11, 21, 31]))
2351
    def test_comp_7(self):
        # An exception raised inside the async generator (after two values)
        # must propagate out of the async list comprehension.
        async def f():
            yield 1
            yield 2
            raise Exception('aaa')

        async def run_list():
            return [i async for i in f()]

        with self.assertRaisesRegex(Exception, 'aaa'):
            run_async(run_list())
2363
    def test_comp_8(self):
        # A plain (non-async) list comprehension inside a coroutine.
        async def f():
            return [i for i in [1, 2, 3]]

        self.assertEqual(
            run_async(f()),
            ([], [1, 2, 3]))
2371
    def test_comp_9(self):
        # An async comprehension followed by a plain comprehension over its
        # result, in the same coroutine.
        async def gen():
            yield 1
            yield 2
        async def f():
            l = [i async for i in gen()]
            return [i for i in l]

        self.assertEqual(
            run_async(f()),
            ([], [1, 2]))
2383
    def test_comp_10(self):
        # A plain set comprehension feeding a plain dict comprehension
        # inside a coroutine.
        async def f():
            xx = {i for i in [1, 2, 3]}
            return {x: x for x in xx}

        self.assertEqual(
            run_async(f()),
            ([], {1: 1, 2: 2, 3: 3}))
2392
    def test_copy(self):
        # Neither a coroutine object nor the iterator returned by its
        # __await__() may be copied with copy.copy(); both raise TypeError.
        async def func(): pass
        coro = func()
        with self.assertRaises(TypeError):
            copy.copy(coro)

        aw = coro.__await__()
        try:
            with self.assertRaises(TypeError):
                copy.copy(aw)
        finally:
            # Close the never-started wrapper, presumably so the coroutine
            # does not trigger a "never awaited" warning when collected.
            aw.close()
2405
    def test_pickle(self):
        # For every pickle protocol, neither a coroutine object nor the
        # iterator returned by its __await__() may be pickled.
        async def func(): pass
        coro = func()
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.assertRaises((TypeError, pickle.PicklingError)):
                pickle.dumps(coro, proto)

        aw = coro.__await__()
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.assertRaises((TypeError, pickle.PicklingError)):
                    pickle.dumps(aw, proto)
        finally:
            # Close the never-started wrapper, presumably so the coroutine
            # does not trigger a "never awaited" warning when collected.
            aw.close()
2420
    def test_fatal_coro_warning(self):
        # Issue 27811
        # With every warning turned into an error, dropping a never-awaited
        # coroutine must still only report "was never awaited" on stderr
        # (presumably the original issue was a crash in this situation).
        # captured_stderr() is defined elsewhere -- presumably it redirects
        # sys.stderr to an in-memory stream returned as ``stderr``.
        async def func(): pass

        # Presumed intent: flush unrelated garbage before the measured part.
        gc.collect()
        with warnings.catch_warnings(), captured_stderr() as stderr:
            warnings.filterwarnings("error")
            func()
            gc.collect()
        self.assertIn("was never awaited", stderr.getvalue())
2431
2432
class CoroAsyncIOCompatTest(unittest.TestCase):
    # Interoperability of coroutines defined in this module with asyncio's
    # event loop.  The class name is rebound to None at the bottom of this
    # module on Python < 3.5 or when asyncio cannot be imported.

    def test_asyncio_1(self):
        # ``async with`` whose __aenter__ and __aexit__ both suspend inside
        # asyncio.  __aexit__ returns None, so MyException is not suppressed
        # and the statement after the ``async with`` block never runs.
        import asyncio

        class MyException(Exception):
            pass

        buffer = []

        class CM(object):
            async def __aenter__(self):
                buffer.append(1)
                await asyncio.sleep(0.01)
                buffer.append(2)
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await asyncio.sleep(0.01)
                buffer.append(exc_type.__name__)

        async def f():
            async with CM() as c:
                await asyncio.sleep(0.01)
                raise MyException
            buffer.append('unreachable')

        # Run on a private loop and always reset the current loop afterwards.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(f())
        except MyException:
            pass
        finally:
            loop.close()
            asyncio.set_event_loop(None)

        self.assertEqual(buffer, [1, 2, 'MyException'])

    def test_asyncio_cython_crash_gh1999(self):
        # Regression test named after Cython GitHub issue #1999.  A coroutine
        # compiled at runtime via exec() awaits coroutines defined in this
        # module, which in turn await an asyncio Future resolved by a timer.
        async def await_future(loop):
            fut = loop.create_future()
            # NOTE: 1-second delay, so this test takes at least one second.
            loop.call_later(1, lambda: fut.set_result(1))
            await fut

        async def delegate_to_await_future(loop):
            await await_future(loop)

        ns = {}
        # The source string is compiled by exec() when this line runs, not as
        # part of this module; presumably __builtins__.exec is spelled out to
        # reach the runtime's own builtin exec().
        __builtins__.exec("""
        async def call(loop, await_func):  # requires Py3.5+
            await await_func(loop)
        """.strip(), ns, ns)
        call = ns['call']

        import asyncio
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(call(loop, delegate_to_await_future))
        finally:
            loop.close()
            asyncio.set_event_loop(None)
2496
2497
class SysSetCoroWrapperTest(unittest.TestCase):
    # Tests for sys.set_coroutine_wrapper() / sys.get_coroutine_wrapper().
    # NOTE: both functions were deprecated in Python 3.7 and removed in 3.8.
    # This class name is rebound to None at the bottom of this module, so
    # unittest does not collect it.

    def test_set_wrapper_1(self):
        # An installed wrapper sees each newly created coroutine; once the
        # wrapper is reset to None it is no longer called.
        async def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        self.assertIsNone(sys.get_coroutine_wrapper())

        sys.set_coroutine_wrapper(wrap)
        self.assertIs(sys.get_coroutine_wrapper(), wrap)
        try:
            f = foo()
            self.assertTrue(wrapped)

            self.assertEqual(run_async(f), ([], 'spam'))
        finally:
            # Always uninstall so a failure cannot leak the wrapper.
            sys.set_coroutine_wrapper(None)

        self.assertIsNone(sys.get_coroutine_wrapper())

        wrapped = None
        with silence_coro_gc():
            foo()
        self.assertFalse(wrapped)

    def test_set_wrapper_2(self):
        # A non-callable wrapper is rejected and leaves no wrapper installed.
        self.assertIsNone(sys.get_coroutine_wrapper())
        with self.assertRaisesRegex(TypeError, "callable expected, got int"):
            sys.set_coroutine_wrapper(1)
        self.assertIsNone(sys.get_coroutine_wrapper())

    def test_set_wrapper_3(self):
        # The wrapper creates a new coroutine (``wrap(coro)``), which would
        # itself be wrapped; this recursion must raise RuntimeError.
        async def foo():
            return 'spam'

        def wrapper(coro):
            async def wrap(coro):
                return await coro
            return wrap(coro)

        sys.set_coroutine_wrapper(wrapper)
        try:
            with silence_coro_gc(), self.assertRaisesRegex(
                RuntimeError,
                # Raw string: "\." is an invalid escape in a plain literal
                # (DeprecationWarning since 3.6, SyntaxWarning since 3.12).
                r"coroutine wrapper.*\.wrapper at 0x.*attempted to "
                "recursively wrap .* wrap .*"):

                foo()
        finally:
            sys.set_coroutine_wrapper(None)

    def test_set_wrapper_4(self):
        # Generator-based coroutines (types.coroutine) must not be passed to
        # the coroutine wrapper.
        @types_coroutine
        def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        sys.set_coroutine_wrapper(wrap)
        try:
            foo()
            self.assertIs(
                wrapped, None,
                "generator-based coroutine was wrapped via "
                "sys.set_coroutine_wrapper")
        finally:
            sys.set_coroutine_wrapper(None)
2575
2576
class CAPITest(unittest.TestCase):
    # Tests for the C-level tp_await slot, using awaitType from CPython's
    # private _testcapi module.  awaitType(x) presumably returns x from its
    # tp_await slot.  This class name is rebound to None at the bottom of
    # this module, so these tests are not collected.

    def test_tp_await_1(self):
        # Awaiting an awaitType that wraps an iterator yields the iterator's
        # first item, which comes back out of send(None).
        from _testcapi import awaitType as at

        async def foo():
            future = at(iter([1]))
            return (await future)

        self.assertEqual(foo().send(None), 1)

    def test_tp_await_2(self):
        # Test tp_await to __await__ mapping
        from _testcapi import awaitType as at
        future = at(iter([1]))
        self.assertEqual(next(future.__await__()), 1)

    def test_tp_await_3(self):
        # When the wrapped object is not an iterator (an int here), awaiting
        # it must raise TypeError.
        from _testcapi import awaitType as at

        async def foo():
            future = at(1)
            return (await future)

        with self.assertRaisesRegex(
                TypeError, "__await__.*returned non-iterator of type 'int'"):
            self.assertEqual(foo().send(None), 1)
2604
2605
# Test classes that only make sense on CPython are switched off by rebinding
# their names to None: unittest's loader only collects module attributes that
# are TestCase subclasses.
# TODO?: the former guard ``True or sys.version_info < (3, 5)`` always held,
# so these two classes are disabled unconditionally.
SysSetCoroWrapperTest = None
CAPITest = None

# The asyncio compatibility tests need Python 3.5+ (possibly (3, 4, 4)) and an
# importable asyncio package.
if sys.version_info >= (3, 5):
    try:
        import asyncio
    except ImportError:
        CoroAsyncIOCompatTest = None
else:
    CoroAsyncIOCompatTest = None

if __name__ == "__main__":
    unittest.main()
2623