1import contextlib
2import copy
3import inspect
4import pickle
5import re
6import sys
7import types
8import unittest
9import warnings
10from test import support
11from test.support.script_helper import assert_python_ok
12
13
class AsyncYieldFrom:
    """Awaitable that delegates to an arbitrary iterable.

    Awaiting an instance yields every item of the wrapped iterable to
    whoever drives the coroutine, in order, via ``yield from``.
    """

    def __init__(self, obj):
        # Iterable (or generator) delegated to when the instance is awaited.
        self.obj = obj

    def __await__(self):
        source = self.obj
        yield from source
20
21
class AsyncYield:
    """Awaitable that suspends exactly once, yielding a single value."""

    def __init__(self, value):
        # The one value handed to the coroutine driver on suspension.
        self.value = value

    def __await__(self):
        payload = self.value
        yield payload
28
29
def run_async(coro):
    """Drive *coro* to completion by repeatedly sending None into it.

    *coro* must be a plain generator or a native coroutine.  Returns a
    tuple ``(yielded, result)``: *yielded* lists every value the object
    yielded, and *result* is its return value (``None`` when it returned
    nothing).  Any exception other than StopIteration propagates.
    """
    assert coro.__class__ in {types.GeneratorType, types.CoroutineType}

    yielded = []
    while True:
        try:
            value = coro.send(None)
        except StopIteration as stop:
            # The return value travels as StopIteration's first argument.
            return yielded, (stop.args[0] if stop.args else None)
        yielded.append(value)
42
43
def run_async__await__(coro):
    """Drive a native coroutine through its ``__await__()`` iterator.

    Successive steps alternate between ``send(None)`` (first step) and
    ``next()``, so both entry points of the coroutine wrapper are
    exercised.  Returns ``(yielded, result)`` in the same shape as
    run_async().
    """
    assert coro.__class__ is types.CoroutineType
    wrapper = coro.__await__()
    yielded = []
    use_next = False
    while True:
        try:
            value = next(wrapper) if use_next else wrapper.send(None)
        except StopIteration as stop:
            return yielded, (stop.args[0] if stop.args else None)
        yielded.append(value)
        # Switch entry point only after a successful step.
        use_next = not use_next
61
62
@contextlib.contextmanager
def silence_coro_gc():
    """Ignore all warnings in the managed block, then force a GC pass.

    The collection runs while the "ignore" filter is still installed, so
    warnings emitted during collection are suppressed as well.  If the
    managed block raises, the exception propagates from the ``yield``
    and the collection is skipped.
    """
    guard = warnings.catch_warnings()
    with guard:
        warnings.simplefilter("ignore")
        yield
        support.gc_collect()
69
70
71class AsyncBadSyntaxTest(unittest.TestCase):
72
73    def test_badsyntax_1(self):
74        samples = [
75            """def foo():
76                await something()
77            """,
78
79            """await something()""",
80
81            """async def foo():
82                yield from []
83            """,
84
85            """async def foo():
86                await await fut
87            """,
88
89            """async def foo(a=await something()):
90                pass
91            """,
92
93            """async def foo(a:await something()):
94                pass
95            """,
96
97            """async def foo():
98                def bar():
99                 [i async for i in els]
100            """,
101
102            """async def foo():
103                def bar():
104                 [await i for i in els]
105            """,
106
107            """async def foo():
108                def bar():
109                 [i for i in els
110                    async for b in els]
111            """,
112
113            """async def foo():
114                def bar():
115                 [i for i in els
116                    for c in b
117                    async for b in els]
118            """,
119
120            """async def foo():
121                def bar():
122                 [i for i in els
123                    async for b in els
124                    for c in b]
125            """,
126
127            """async def foo():
128                def bar():
129                 [i for i in els
130                    for b in await els]
131            """,
132
133            """async def foo():
134                def bar():
135                 [i for i in els
136                    for b in els
137                        if await b]
138            """,
139
140            """async def foo():
141                def bar():
142                 [i for i in await els]
143            """,
144
145            """async def foo():
146                def bar():
147                 [i for i in els if await i]
148            """,
149
150            """def bar():
151                 [i async for i in els]
152            """,
153
154            """def bar():
155                 {i: i async for i in els}
156            """,
157
158            """def bar():
159                 {i async for i in els}
160            """,
161
162            """def bar():
163                 [await i for i in els]
164            """,
165
166            """def bar():
167                 [i for i in els
168                    async for b in els]
169            """,
170
171            """def bar():
172                 [i for i in els
173                    for c in b
174                    async for b in els]
175            """,
176
177            """def bar():
178                 [i for i in els
179                    async for b in els
180                    for c in b]
181            """,
182
183            """def bar():
184                 [i for i in els
185                    for b in await els]
186            """,
187
188            """def bar():
189                 [i for i in els
190                    for b in els
191                        if await b]
192            """,
193
194            """def bar():
195                 [i for i in await els]
196            """,
197
198            """def bar():
199                 [i for i in els if await i]
200            """,
201
202            """async def foo():
203                await
204            """,
205
206            """async def foo():
207                   def bar(): pass
208                   await = 1
209            """,
210
211            """async def foo():
212
213                   def bar(): pass
214                   await = 1
215            """,
216
217            """async def foo():
218                   def bar(): pass
219                   if 1:
220                       await = 1
221            """,
222
223            """def foo():
224                   async def bar(): pass
225                   if 1:
226                       await a
227            """,
228
229            """def foo():
230                   async def bar(): pass
231                   await a
232            """,
233
234            """def foo():
235                   def baz(): pass
236                   async def bar(): pass
237                   await a
238            """,
239
240            """def foo():
241                   def baz(): pass
242                   # 456
243                   async def bar(): pass
244                   # 123
245                   await a
246            """,
247
248            """async def foo():
249                   def baz(): pass
250                   # 456
251                   async def bar(): pass
252                   # 123
253                   await = 2
254            """,
255
256            """def foo():
257
258                   def baz(): pass
259
260                   async def bar(): pass
261
262                   await a
263            """,
264
265            """async def foo():
266
267                   def baz(): pass
268
269                   async def bar(): pass
270
271                   await = 2
272            """,
273
274            """async def foo():
275                   def async(): pass
276            """,
277
278            """async def foo():
279                   def await(): pass
280            """,
281
282            """async def foo():
283                   def bar():
284                       await
285            """,
286
287            """async def foo():
288                   return lambda async: await
289            """,
290
291            """async def foo():
292                   return lambda a: await
293            """,
294
295            """await a()""",
296
297            """async def foo(a=await b):
298                   pass
299            """,
300
301            """async def foo(a:await b):
302                   pass
303            """,
304
305            """def baz():
306                   async def foo(a=await b):
307                       pass
308            """,
309
310            """async def foo(async):
311                   pass
312            """,
313
314            """async def foo():
315                   def bar():
316                        def baz():
317                            async = 1
318            """,
319
320            """async def foo():
321                   def bar():
322                        def baz():
323                            pass
324                        async = 1
325            """,
326
327            """def foo():
328                   async def bar():
329
330                        async def baz():
331                            pass
332
333                        def baz():
334                            42
335
336                        async = 1
337            """,
338
339            """async def foo():
340                   def bar():
341                        def baz():
342                            pass\nawait foo()
343            """,
344
345            """def foo():
346                   def bar():
347                        async def baz():
348                            pass\nawait foo()
349            """,
350
351            """async def foo(await):
352                   pass
353            """,
354
355            """def foo():
356
357                   async def bar(): pass
358
359                   await a
360            """,
361
362            """def foo():
363                   async def bar():
364                        pass\nawait a
365            """,
366            """def foo():
367                   async for i in arange(2):
368                       pass
369            """,
370            """def foo():
371                   async with resource:
372                       pass
373            """,
374            """async with resource:
375                   pass
376            """,
377            """async for i in arange(2):
378                   pass
379            """,
380            ]
381
382        for code in samples:
383            with self.subTest(code=code), self.assertRaises(SyntaxError):
384                compile(code, "<test>", "exec")
385
386    def test_badsyntax_2(self):
387        samples = [
388            """def foo():
389                await = 1
390            """,
391
392            """class Bar:
393                def async(): pass
394            """,
395
396            """class Bar:
397                async = 1
398            """,
399
400            """class async:
401                pass
402            """,
403
404            """class await:
405                pass
406            """,
407
408            """import math as await""",
409
410            """def async():
411                pass""",
412
413            """def foo(*, await=1):
414                pass"""
415
416            """async = 1""",
417
418            """print(await=1)"""
419        ]
420
421        for code in samples:
422            with self.subTest(code=code), self.assertRaises(SyntaxError):
423                compile(code, "<test>", "exec")
424
425    def test_badsyntax_3(self):
426        with self.assertRaises(SyntaxError):
427            compile("async = 1", "<test>", "exec")
428
429    def test_badsyntax_4(self):
430        samples = [
431            '''def foo(await):
432                async def foo(): pass
433                async def foo():
434                    pass
435                return await + 1
436            ''',
437
438            '''def foo(await):
439                async def foo(): pass
440                async def foo(): pass
441                return await + 1
442            ''',
443
444            '''def foo(await):
445
446                async def foo(): pass
447
448                async def foo(): pass
449
450                return await + 1
451            ''',
452
453            '''def foo(await):
454                """spam"""
455                async def foo(): \
456                    pass
457                # 123
458                async def foo(): pass
459                # 456
460                return await + 1
461            ''',
462
463            '''def foo(await):
464                def foo(): pass
465                def foo(): pass
466                async def bar(): return await_
467                await_ = await
468                try:
469                    bar().send(None)
470                except StopIteration as ex:
471                    return ex.args[0] + 1
472            '''
473        ]
474
475        for code in samples:
476            with self.subTest(code=code), self.assertRaises(SyntaxError):
477                compile(code, "<test>", "exec")
478
479
class TokenizerRegrTest(unittest.TestCase):
    """Tokenizer regression tests built from many consecutive definitions."""

    def test_oneline_defs(self):
        source = '\n'.join('def i{i}(): return {i}'.format(i=i)
                           for i in range(500))

        # 500 consecutive one-line defs must compile and run.
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)

        # The same 500 defs followed by one 'async def' must also work.
        source += '\nasync def foo():\n    return'
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)
        self.assertTrue(inspect.iscoroutinefunction(namespace['foo']))
500
501
502class CoroutineTest(unittest.TestCase):
503
504    def test_gen_1(self):
505        def gen(): yield
506        self.assertFalse(hasattr(gen, '__await__'))
507
508    def test_func_1(self):
509        async def foo():
510            return 10
511
512        f = foo()
513        self.assertIsInstance(f, types.CoroutineType)
514        self.assertTrue(bool(foo.__code__.co_flags & inspect.CO_COROUTINE))
515        self.assertFalse(bool(foo.__code__.co_flags & inspect.CO_GENERATOR))
516        self.assertTrue(bool(f.cr_code.co_flags & inspect.CO_COROUTINE))
517        self.assertFalse(bool(f.cr_code.co_flags & inspect.CO_GENERATOR))
518        self.assertEqual(run_async(f), ([], 10))
519
520        self.assertEqual(run_async__await__(foo()), ([], 10))
521
522        def bar(): pass
523        self.assertFalse(bool(bar.__code__.co_flags & inspect.CO_COROUTINE))
524
525    def test_func_2(self):
526        async def foo():
527            raise StopIteration
528
529        with self.assertRaisesRegex(
530                RuntimeError, "coroutine raised StopIteration"):
531
532            run_async(foo())
533
534    def test_func_3(self):
535        async def foo():
536            raise StopIteration
537
538        coro = foo()
539        self.assertRegex(repr(coro), '^<coroutine object.* at 0x.*>$')
540        coro.close()
541
542    def test_func_4(self):
543        async def foo():
544            raise StopIteration
545        coro = foo()
546
547        check = lambda: self.assertRaisesRegex(
548            TypeError, "'coroutine' object is not iterable")
549
550        with check():
551            list(coro)
552
553        with check():
554            tuple(coro)
555
556        with check():
557            sum(coro)
558
559        with check():
560            iter(coro)
561
562        with check():
563            for i in coro:
564                pass
565
566        with check():
567            [i for i in coro]
568
569        coro.close()
570
571    def test_func_5(self):
572        @types.coroutine
573        def bar():
574            yield 1
575
576        async def foo():
577            await bar()
578
579        check = lambda: self.assertRaisesRegex(
580            TypeError, "'coroutine' object is not iterable")
581
582        coro = foo()
583        with check():
584            for el in coro:
585                pass
586        coro.close()
587
588        # the following should pass without an error
589        for el in bar():
590            self.assertEqual(el, 1)
591        self.assertEqual([el for el in bar()], [1])
592        self.assertEqual(tuple(bar()), (1,))
593        self.assertEqual(next(iter(bar())), 1)
594
595    def test_func_6(self):
596        @types.coroutine
597        def bar():
598            yield 1
599            yield 2
600
601        async def foo():
602            await bar()
603
604        f = foo()
605        self.assertEqual(f.send(None), 1)
606        self.assertEqual(f.send(None), 2)
607        with self.assertRaises(StopIteration):
608            f.send(None)
609
610    def test_func_7(self):
611        async def bar():
612            return 10
613        coro = bar()
614
615        def foo():
616            yield from coro
617
618        with self.assertRaisesRegex(
619                TypeError,
620                "cannot 'yield from' a coroutine object in "
621                "a non-coroutine generator"):
622            list(foo())
623
624        coro.close()
625
626    def test_func_8(self):
627        @types.coroutine
628        def bar():
629            return (yield from coro)
630
631        async def foo():
632            return 'spam'
633
634        coro = foo()
635        self.assertEqual(run_async(bar()), ([], 'spam'))
636        coro.close()
637
    def test_func_9(self):
        """Discarding a never-awaited coroutine emits a RuntimeWarning.

        The expected message names the coroutine including its enclosing
        function, which is why each regex mentions both 'test_func_9' and
        'foo'.
        """
        async def foo():
            pass

        with self.assertWarnsRegex(
                RuntimeWarning,
                r"coroutine '.*test_func_9.*foo' was never awaited"):

            foo()
            # Collect the discarded coroutine now so the warning is raised
            # while the assertWarnsRegex context is still active.
            support.gc_collect()

        with self.assertWarnsRegex(
                RuntimeWarning,
                r"coroutine '.*test_func_9.*foo' was never awaited"):

            with self.assertRaises(TypeError):
                # See bpo-32703.
                # Plain iteration over a coroutine raises TypeError, and
                # the coroutine created here is never awaited either.
                for _ in foo():
                    pass

            support.gc_collect()
659
    def test_func_10(self):
        """send(), close() and throw() on the __await__ wrapper reach the awaited generator.

        N records how gen() exits: +1 each time its ``finally`` clause
        runs, +100 each time it catches ZeroDivisionError.
        """
        N = 0

        @types.coroutine
        def gen():
            nonlocal N
            try:
                a = yield
                yield (a ** 2)
            except ZeroDivisionError:
                N += 100
                raise
            finally:
                N += 1

        async def foo():
            await gen()

        coro = foo()
        aw = coro.__await__()
        # The wrapper is its own iterator.
        self.assertIs(aw, iter(aw))
        next(aw)
        # The sent value reaches gen(); its square is yielded back out.
        self.assertEqual(aw.send(10), 100)

        self.assertEqual(N, 0)
        # Closing the wrapper finalizes gen(), running its finally clause.
        aw.close()
        self.assertEqual(N, 1)

        coro = foo()
        aw = coro.__await__()
        next(aw)
        # The thrown exception is delivered into gen(): the except clause
        # adds 100 and the finally clause adds 1, so N goes 1 -> 102.
        with self.assertRaises(ZeroDivisionError):
            aw.throw(ZeroDivisionError, None, None)
        self.assertEqual(N, 102)
694
695    def test_func_11(self):
696        async def func(): pass
697        coro = func()
698        # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly
699        # initialized
700        self.assertIn('__await__', dir(coro))
701        self.assertIn('__iter__', dir(coro.__await__()))
702        self.assertIn('coroutine_wrapper', repr(coro.__await__()))
703        coro.close() # avoid RuntimeWarning
704
    def test_func_12(self):
        """Sending into a coroutine that is currently running raises ValueError."""
        async def g():
            # g() re-enters itself while executing; the assertion below
            # expects this send() to raise, so the next line is never
            # reached.
            i = me.send(None)
            await foo
        me = g()
        with self.assertRaisesRegex(ValueError,
                                    "coroutine already executing"):
            me.send(None)
713
714    def test_func_13(self):
715        async def g():
716            pass
717
718        coro = g()
719        with self.assertRaisesRegex(
720                TypeError,
721                "can't send non-None value to a just-started coroutine"):
722            coro.send('spam')
723
724        coro.close()
725
    def test_func_14(self):
        """Suspending again after catching GeneratorExit makes close() fail."""
        @types.coroutine
        def gen():
            yield
        async def coro():
            try:
                await gen()
            except GeneratorExit:
                # Awaiting (and thus suspending) here instead of exiting is
                # what close() reports as "ignored GeneratorExit".
                await gen()
        c = coro()
        c.send(None)
        with self.assertRaisesRegex(RuntimeError,
                                    "coroutine ignored GeneratorExit"):
            c.close()
740
    def test_func_15(self):
        """A coroutine that already ran to completion cannot be awaited again."""
        # See http://bugs.python.org/issue25887 for details

        async def spammer():
            return 'spam'
        async def reader(coro):
            return await coro

        spammer_coro = spammer()

        # The first await runs spammer_coro to completion; the value then
        # propagates out of reader() as StopIteration('spam').
        with self.assertRaisesRegex(StopIteration, 'spam'):
            reader(spammer_coro).send(None)

        # Awaiting the exhausted coroutine a second time is an error.
        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader(spammer_coro).send(None)
757
    def test_func_16(self):
        """A coroutine terminated by an exception cannot be awaited again."""
        # See http://bugs.python.org/issue25887 for details

        @types.coroutine
        def nop():
            yield
        async def send():
            await nop()
            return 'spam'
        async def read(coro):
            await nop()
            return await coro

        spammer = send()

        reader = read(spammer)
        # First send() suspends in read()'s own nop(); the second suspends
        # inside spammer's nop().
        reader.send(None)
        reader.send(None)
        # The thrown exception propagates through spammer and out of
        # reader, leaving spammer finished.
        with self.assertRaisesRegex(Exception, 'ham'):
            reader.throw(Exception('ham'))

        reader = read(spammer)
        reader.send(None)
        # reader now awaits the finished spammer; both send() and throw()
        # report the reuse error.
        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader.throw(Exception('wat'))
788
    def test_func_17(self):
        """An exhausted coroutine rejects send()/throw() but tolerates close()."""
        # See http://bugs.python.org/issue25887 for details

        async def coroutine():
            return 'spam'

        coro = coroutine()
        # The first send() runs the body to completion; the return value
        # is carried by StopIteration.
        with self.assertRaisesRegex(StopIteration, 'spam'):
            coro.send(None)

        # Any further resumption is reported as reuse.
        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            coro.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            coro.throw(Exception('wat'))

        # Closing a coroutine shouldn't raise any exception even if it's
        # already closed/exhausted (similar to generators)
        coro.close()
        coro.close()
811
    def test_func_18(self):
        """The __await__ iterator of an exhausted coroutine rejects reuse.

        Driving the coroutine through ``iter(coro.__await__())``: after
        completion, send(), next() and throw() all raise RuntimeError,
        while close() stays a silent no-op.
        """
        # See http://bugs.python.org/issue25887 for details

        async def coroutine():
            return 'spam'

        coro = coroutine()
        await_iter = coro.__await__()
        it = iter(await_iter)

        # First step runs the coroutine to completion.
        with self.assertRaisesRegex(StopIteration, 'spam'):
            it.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            # Although the iterator protocol requires iterators to
            # raise another StopIteration here, we don't want to do
            # that.  In this particular case, the iterator will raise
            # a RuntimeError, so that 'yield from' and 'await'
            # expressions will trigger the error, instead of silently
            # ignoring the call.
            next(it)

        # throw() is checked twice: the error must be reported on every
        # attempt, not only the first.
        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.throw(Exception('wat'))

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.throw(Exception('wat'))

        # Closing a coroutine shouldn't raise any exception even if it's
        # already closed/exhausted (similar to generators)
        it.close()
        it.close()
851
    def test_func_19(self):
        """close() delivers GeneratorExit to the awaited generator exactly once."""
        # Counts how many times foo() catches GeneratorExit.
        CHK = 0

        @types.coroutine
        def foo():
            nonlocal CHK
            yield
            try:
                yield
            except GeneratorExit:
                CHK += 1

        async def coroutine():
            await foo()

        coro = coroutine()

        # After two sends the coroutine is parked at foo()'s second yield,
        # inside the try block.
        coro.send(None)
        coro.send(None)

        self.assertEqual(CHK, 0)
        coro.close()
        self.assertEqual(CHK, 1)

        for _ in range(3):
            # Closing a coroutine shouldn't raise any exception even if it's
            # already closed/exhausted (similar to generators)
            coro.close()
            self.assertEqual(CHK, 1)
881
882    def test_coro_wrapper_send_tuple(self):
883        async def foo():
884            return (10,)
885
886        result = run_async__await__(foo())
887        self.assertEqual(result, ([], (10,)))
888
889    def test_coro_wrapper_send_stop_iterator(self):
890        async def foo():
891            return StopIteration(10)
892
893        result = run_async__await__(foo())
894        self.assertIsInstance(result[1], StopIteration)
895        self.assertEqual(result[1].value, 10)
896
    def test_cr_await(self):
        """cr_await exposes the awaited object only while the coroutine is suspended.

        Chain under test: b() awaits c(), which awaits the generator-based
        coroutine a().  The inspect.getcoroutinestate() checks follow b
        through the created, running, suspended and closed states.
        """
        @types.coroutine
        def a():
            # Code here runs inside b's await chain, so b is RUNNING and
            # reports no awaited object.
            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)
            yield
            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)

        async def c():
            await a()

        async def b():
            self.assertIsNone(coro_b.cr_await)
            await c()
            self.assertIsNone(coro_b.cr_await)

        coro_b = b()
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CREATED)
        self.assertIsNone(coro_b.cr_await)

        coro_b.send(None)
        # Suspended at a()'s yield: b.cr_await is the c() coroutine, whose
        # cr_await is the a() generator.
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_SUSPENDED)
        self.assertEqual(coro_b.cr_await.cr_await.gi_code.co_name, 'a')

        with self.assertRaises(StopIteration):
            coro_b.send(None)  # complete coroutine
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CLOSED)
        self.assertIsNone(coro_b.cr_await)
926
927    def test_corotype_1(self):
928        ct = types.CoroutineType
929        self.assertIn('into coroutine', ct.send.__doc__)
930        self.assertIn('inside coroutine', ct.close.__doc__)
931        self.assertIn('in coroutine', ct.throw.__doc__)
932        self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__)
933        self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__)
934        self.assertEqual(ct.__name__, 'coroutine')
935
936        async def f(): pass
937        c = f()
938        self.assertIn('coroutine object', repr(c))
939        c.close()
940
941    def test_await_1(self):
942
943        async def foo():
944            await 1
945        with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
946            run_async(foo())
947
948    def test_await_2(self):
949        async def foo():
950            await []
951        with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
952            run_async(foo())
953
954    def test_await_3(self):
955        async def foo():
956            await AsyncYieldFrom([1, 2, 3])
957
958        self.assertEqual(run_async(foo()), ([1, 2, 3], None))
959        self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None))
960
961    def test_await_4(self):
962        async def bar():
963            return 42
964
965        async def foo():
966            return await bar()
967
968        self.assertEqual(run_async(foo()), ([], 42))
969
970    def test_await_5(self):
971        class Awaitable:
972            def __await__(self):
973                return
974
975        async def foo():
976            return (await Awaitable())
977
978        with self.assertRaisesRegex(
979            TypeError, "__await__.*returned non-iterator of type"):
980
981            run_async(foo())
982
983    def test_await_6(self):
984        class Awaitable:
985            def __await__(self):
986                return iter([52])
987
988        async def foo():
989            return (await Awaitable())
990
991        self.assertEqual(run_async(foo()), ([52], None))
992
993    def test_await_7(self):
994        class Awaitable:
995            def __await__(self):
996                yield 42
997                return 100
998
999        async def foo():
1000            return (await Awaitable())
1001
1002        self.assertEqual(run_async(foo()), ([42], 100))
1003
1004    def test_await_8(self):
1005        class Awaitable:
1006            pass
1007
1008        async def foo(): return await Awaitable()
1009
1010        with self.assertRaisesRegex(
1011            TypeError, "object Awaitable can't be used in 'await' expression"):
1012
1013            run_async(foo())
1014
    def test_await_9(self):
        """await applies to full call/subscript chains and binds tighter than '*' and '+'.

        Each awaited term ends in a bar() coroutine (returning 42), so the
        expected sum is 42 + 42 + 42 + 42 * 1000 + 42 == 42168.
        """
        def wrap():
            return bar

        async def bar():
            return 42

        async def foo():
            db = {'b':  lambda: wrap}

            class DB:
                b = wrap

            # The exact expression is the subject of the test; do not
            # reformat or parenthesize it.
            return (await bar() + await wrap()() + await db['b']()()() +
                    await bar() * 1000 + await DB.b()())

        async def foo2():
            # Unary minus applied to an await expression.
            return -await bar()

        self.assertEqual(run_async(foo()), ([], 42168))
        self.assertEqual(run_async(foo2()), ([], -42))
1036
1037    def test_await_10(self):
1038        async def baz():
1039            return 42
1040
1041        async def bar():
1042            return baz()
1043
1044        async def foo():
1045            return await (await bar())
1046
1047        self.assertEqual(run_async(foo()), ([], 42))
1048
1049    def test_await_11(self):
1050        def ident(val):
1051            return val
1052
1053        async def bar():
1054            return 'spam'
1055
1056        async def foo():
1057            return ident(val=await bar())
1058
1059        async def foo2():
1060            return await bar(), 'ham'
1061
1062        self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))
1063
1064    def test_await_12(self):
1065        async def coro():
1066            return 'spam'
1067        c = coro()
1068
1069        class Awaitable:
1070            def __await__(self):
1071                return c
1072
1073        async def foo():
1074            return await Awaitable()
1075
1076        with self.assertRaisesRegex(
1077                TypeError, r"__await__\(\) returned a coroutine"):
1078            run_async(foo())
1079
1080        c.close()
1081
1082    def test_await_13(self):
1083        class Awaitable:
1084            def __await__(self):
1085                return self
1086
1087        async def foo():
1088            return await Awaitable()
1089
1090        with self.assertRaisesRegex(
1091            TypeError, "__await__.*returned non-iterator of type"):
1092
1093            run_async(foo())
1094
    def test_await_14(self):
        """Sent values and thrown exceptions pass through CoroutineType.__await__."""
        class Wrapper:
            # Forces the interpreter to use CoroutineType.__await__
            def __init__(self, coro):
                assert coro.__class__ is types.CoroutineType
                self.coro = coro
            def __await__(self):
                return self.coro.__await__()

        class FutureLike:
            def __await__(self):
                # Suspends once and returns whatever value is sent back in.
                return (yield)

        class Marker(Exception):
            pass

        async def coro1():
            try:
                return await FutureLike()
            except ZeroDivisionError:
                raise Marker
        async def coro2():
            return await Wrapper(coro1())

        c = coro2()
        c.send(None)
        # The sent value becomes coro1()'s result, then coro2()'s, and
        # finally surfaces as StopIteration('spam').
        with self.assertRaisesRegex(StopIteration, 'spam'):
            c.send('spam')

        c = coro2()
        c.send(None)
        # The thrown exception reaches coro1(), which turns it into Marker.
        with self.assertRaises(Marker):
            c.throw(ZeroDivisionError)
1128
    def test_await_15(self):
        # A coroutine that is suspended inside its own await must not be
        # awaited a second time by another coroutine.
        @types.coroutine
        def nop():
            yield

        async def coroutine():
            await nop()

        async def waiter(coro):
            await coro

        # Advance coro to its first suspension point (the yield in nop()).
        coro = coroutine()
        coro.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    "coroutine is being awaited already"):
            waiter(coro).send(None)
1146
    def test_await_16(self):
        # See https://bugs.python.org/issue29600 for details.
        # An exception *object* returned (not raised) by a coroutine that
        # is awaited inside an except block must not acquire the handled
        # exception as its __context__.

        async def f():
            return ValueError()

        async def g():
            try:
                raise KeyError
            except:
                return await f()

        _, result = run_async(g())
        self.assertIsNone(result.__context__)
1161
    def test_with_1(self):
        # Several managers in one 'async with': entered left-to-right,
        # exited right-to-left, and each __aenter__/__aexit__ may suspend.
        class Manager:
            def __init__(self, name):
                self.name = name

            async def __aenter__(self):
                await AsyncYieldFrom(['enter-1-' + self.name,
                                      'enter-2-' + self.name])
                return self

            async def __aexit__(self, *args):
                await AsyncYieldFrom(['exit-1-' + self.name,
                                      'exit-2-' + self.name])

                # Only manager 'B' suppresses the body's exception.
                if self.name == 'B':
                    return True


        async def foo():
            async with Manager("A") as a, Manager("B") as b:
                await AsyncYieldFrom([('managers', a.name, b.name)])
                1/0

        f = foo()
        result, _ = run_async(f)

        self.assertEqual(
            result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
                     ('managers', 'A', 'B'),
                     'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
        )

        # With 'C' (whose __aexit__ returns None) nothing suppresses the
        # ZeroDivisionError, so it must propagate.
        async def foo():
            async with Manager("A") as a, Manager("C") as c:
                await AsyncYieldFrom([('managers', a.name, c.name)])
                1/0

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
1201
    def test_with_2(self):
        # A manager without __aexit__ must fail with AttributeError.
        class CM:
            def __aenter__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())
1213
    def test_with_3(self):
        # A manager without __aenter__ must fail with AttributeError.
        class CM:
            def __aexit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aenter__'):
            run_async(foo())
1225
    def test_with_4(self):
        # A synchronous context manager (__enter__/__exit__ only) is not
        # accepted by 'async with'; the AttributeError must name __aexit__.
        class CM:
            def __enter__(self):
                pass

            def __exit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())
1240
    def test_with_5(self):
        # While this test doesn't make a lot of sense,
        # it's a regression test for an early bug with opcodes
        # generation

        # An AssertionError raised in the body must propagate through a
        # manager whose __aexit__ returns None.
        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *exc):
                pass

        async def func():
            async with CM():
                assert (1, ) == 1

        with self.assertRaises(AssertionError):
            run_async(func())
1259
1260    def test_with_6(self):
1261        class CM:
1262            def __aenter__(self):
1263                return 123
1264
1265            def __aexit__(self, *e):
1266                return 456
1267
1268        async def foo():
1269            async with CM():
1270                pass
1271
1272        with self.assertRaisesRegex(
1273                TypeError,
1274                "'async with' received an object from __aenter__ "
1275                "that does not implement __await__: int"):
1276            # it's important that __aexit__ wasn't called
1277            run_async(foo())
1278
1279    def test_with_7(self):
1280        class CM:
1281            async def __aenter__(self):
1282                return self
1283
1284            def __aexit__(self, *e):
1285                return 444
1286
1287        # Exit with exception
1288        async def foo():
1289            async with CM():
1290                1/0
1291
1292        try:
1293            run_async(foo())
1294        except TypeError as exc:
1295            self.assertRegex(
1296                exc.args[0],
1297                "'async with' received an object from __aexit__ "
1298                "that does not implement __await__: int")
1299            self.assertTrue(exc.__context__ is not None)
1300            self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
1301        else:
1302            self.fail('invalid asynchronous context manager did not fail')
1303
1304
    def test_with_8(self):
        # __aexit__ returning a non-awaitable (int) must raise TypeError on
        # every way of leaving the block (normal exit, break, continue,
        # return). CNT confirms the body ran exactly once per scenario.
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 456

        # Normal exit
        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1
        with self.assertRaisesRegex(
                TypeError,
                "'async with' received an object from __aexit__ "
                "that does not implement __await__: int"):
            run_async(foo())
        self.assertEqual(CNT, 1)

        # Exit with 'break'
        async def foo():
            nonlocal CNT
            for i in range(2):
                async with CM():
                    CNT += 1
                    break
        with self.assertRaisesRegex(
                TypeError,
                "'async with' received an object from __aexit__ "
                "that does not implement __await__: int"):
            run_async(foo())
        self.assertEqual(CNT, 2)

        # Exit with 'continue'
        async def foo():
            nonlocal CNT
            for i in range(2):
                async with CM():
                    CNT += 1
                    continue
        with self.assertRaisesRegex(
                TypeError,
                "'async with' received an object from __aexit__ "
                "that does not implement __await__: int"):
            run_async(foo())
        # Only +1: the TypeError on the first iteration ends the loop.
        self.assertEqual(CNT, 3)

        # Exit with 'return'
        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1
                return
        with self.assertRaisesRegex(
                TypeError,
                "'async with' received an object from __aexit__ "
                "that does not implement __await__: int"):
            run_async(foo())
        self.assertEqual(CNT, 4)
1367
1368
    def test_with_9(self):
        # An exception raised inside __aexit__ after a normal body exit
        # must propagate; the body itself ran exactly once.
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

        self.assertEqual(CNT, 1)
1388
1389    def test_with_10(self):
1390        CNT = 0
1391
1392        class CM:
1393            async def __aenter__(self):
1394                return self
1395
1396            async def __aexit__(self, *e):
1397                1/0
1398
1399        async def foo():
1400            nonlocal CNT
1401            async with CM():
1402                async with CM():
1403                    raise RuntimeError
1404
1405        try:
1406            run_async(foo())
1407        except ZeroDivisionError as exc:
1408            self.assertTrue(exc.__context__ is not None)
1409            self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
1410            self.assertTrue(isinstance(exc.__context__.__context__,
1411                                       RuntimeError))
1412        else:
1413            self.fail('exception from __aexit__ did not propagate')
1414
1415    def test_with_11(self):
1416        CNT = 0
1417
1418        class CM:
1419            async def __aenter__(self):
1420                raise NotImplementedError
1421
1422            async def __aexit__(self, *e):
1423                1/0
1424
1425        async def foo():
1426            nonlocal CNT
1427            async with CM():
1428                raise RuntimeError
1429
1430        try:
1431            run_async(foo())
1432        except NotImplementedError as exc:
1433            self.assertTrue(exc.__context__ is None)
1434        else:
1435            self.fail('exception from __aenter__ did not propagate')
1436
1437    def test_with_12(self):
1438        CNT = 0
1439
1440        class CM:
1441            async def __aenter__(self):
1442                return self
1443
1444            async def __aexit__(self, *e):
1445                return True
1446
1447        async def foo():
1448            nonlocal CNT
1449            async with CM() as cm:
1450                self.assertIs(cm.__class__, CM)
1451                raise RuntimeError
1452
1453        run_async(foo())
1454
    def test_with_13(self):
        # When __aenter__ raises, __aexit__ is not consulted (otherwise its
        # True result would suppress the error): the exception propagates,
        # and neither the body nor the code after the block runs (CNT == 1).
        CNT = 0

        class CM:
            async def __aenter__(self):
                1/0

            async def __aexit__(self, *e):
                return True

        async def foo():
            nonlocal CNT
            CNT += 1
            async with CM():
                CNT += 1000
            CNT += 10000

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 1)
1475
    def test_for_1(self):
        # Basic 'async for': __anext__ may suspend (every 10th step yields
        # i * 10 to the driver), StopAsyncIteration ends the loop, and
        # break / continue / else behave as in a regular for loop.
        aiter_calls = 0

        class AsyncIter:
            def __init__(self):
                self.i = 0

            def __aiter__(self):
                nonlocal aiter_calls
                aiter_calls += 1
                return self

            async def __anext__(self):
                self.i += 1

                if not (self.i % 10):
                    await AsyncYield(self.i * 10)

                if self.i > 100:
                    raise StopAsyncIteration

                return self.i, self.i


        buffer = []
        async def test1():
            async for i1, i2 in AsyncIter():
                buffer.append(i1 + i2)

        yielded, _ = run_async(test1())
        # Make sure that __aiter__ was called only once
        self.assertEqual(aiter_calls, 1)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i*2 for i in range(1, 101)])


        buffer = []
        async def test2():
            nonlocal buffer
            async for i in AsyncIter():
                buffer.append(i[0])
                if i[0] == 20:
                    break
            else:
                buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test2())
        # aiter_calls is cumulative across loops: exactly one more
        # __aiter__ call for this loop.
        self.assertEqual(aiter_calls, 2)
        self.assertEqual(yielded, [100, 200])
        self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])


        buffer = []
        async def test3():
            nonlocal buffer
            async for i in AsyncIter():
                if i[0] > 20:
                    continue
                buffer.append(i[0])
            else:
                buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test3())
        # aiter_calls is cumulative across loops: exactly one more
        # __aiter__ call for this loop.
        self.assertEqual(aiter_calls, 3)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i for i in range(1, 21)] +
                                 ['what?', 'end'])
1547
    def test_for_2(self):
        # 'async for' over an object without __aiter__ (a tuple) must raise
        # TypeError, and the error path must not leak a reference to it.
        tup = (1, 2, 3)
        refs_before = sys.getrefcount(tup)

        async def foo():
            async for i in tup:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError, "async for' requires an object.*__aiter__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(tup), refs_before)
1562
    def test_for_3(self):
        # __aiter__ returns an object lacking __anext__: TypeError, with no
        # reference leaked on the error path.
        class I:
            def __aiter__(self):
                return self

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                r"that does not implement __anext__"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)
1582
    def test_for_4(self):
        # __anext__ returns a non-awaitable (a tuple): TypeError naming the
        # type, with no reference leaked on the error path.
        class I:
            def __aiter__(self):
                return self

            def __anext__(self):
                return ()

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                "async for' received an invalid object.*__anext__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)
1605
    def test_for_6(self):
        # 'async for' nested inside 'async with'. I encodes what ran, one
        # decimal digit group per event: __aenter__ +10000, __aexit__
        # +100000, each loop iteration +1 (Iterable yields 11 items), code
        # after the block +1000, pre-loop +100, loop 'else' +10000000.
        I = 0

        class Manager:
            async def __aenter__(self):
                nonlocal I
                I += 10000

            async def __aexit__(self, *args):
                nonlocal I
                I += 100000

        class Iterable:
            def __init__(self):
                self.i = 0

            def __aiter__(self):
                return self

            async def __anext__(self):
                if self.i > 10:
                    raise StopAsyncIteration
                self.i += 1
                return self.i

        ##############

        manager = Manager()
        iterable = Iterable()
        mrefs_before = sys.getrefcount(manager)
        irefs_before = sys.getrefcount(iterable)

        async def main():
            nonlocal I

            async with manager:
                async for i in iterable:
                    I += 1
            I += 1000

        with warnings.catch_warnings():
            warnings.simplefilter("error")
            # Test that __aiter__ that returns an asynchronous iterator
            # directly does not throw any warnings.
            run_async(main())
        self.assertEqual(I, 111011)

        # Neither the manager nor the iterable may be leaked by the loop.
        self.assertEqual(sys.getrefcount(manager), mrefs_before)
        self.assertEqual(sys.getrefcount(iterable), irefs_before)

        ##############

        # Two sequential blocks: adds 2 * 111011 to the previous total.
        async def main():
            nonlocal I

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 333033)

        ##############

        # Same, plus +100 before each loop and the loop 'else' clause
        # (+10000000), which runs because the loops finish without break.
        async def main():
            nonlocal I

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

        run_async(main())
        self.assertEqual(I, 20555255)
1697
    def test_for_7(self):
        # An exception raised by __aiter__ propagates out of 'async for';
        # neither the body nor the code after the loop runs.
        CNT = 0
        class AI:
            def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            async for i in AI():
                CNT += 1
            CNT += 10
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 0)
1711
    def test_for_8(self):
        # Same scenario as test_for_7, but with warnings turned into errors:
        # the ZeroDivisionError itself must be what escapes.
        CNT = 0
        class AI:
            def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            async for i in AI():
                CNT += 1
            CNT += 10
        with self.assertRaises(ZeroDivisionError):
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                # Test that if __aiter__ raises an exception it propagates
                # without any kind of warning.
                run_async(foo())
        self.assertEqual(CNT, 0)
1729
    def test_for_11(self):
        # __anext__ returns an object whose __await__ raises: 'async for'
        # reports a TypeError and chains the original error as __cause__.
        class F:
            def __aiter__(self):
                return self
            def __anext__(self):
                return self
            def __await__(self):
                1 / 0

        async def main():
            async for _ in F():
                pass

        with self.assertRaisesRegex(TypeError,
                                    'an invalid object from __anext__') as c:
            main().send(None)

        err = c.exception
        self.assertIsInstance(err.__cause__, ZeroDivisionError)
1749
    def test_for_tuple(self):
        # An async iterator whose class subclasses tuple must be driven
        # through __aiter__/__anext__ like any other: one item, then the
        # code after the loop runs (raising Done).
        class Done(Exception): pass

        class AIter(tuple):
            i = 0
            def __aiter__(self):
                return self
            async def __anext__(self):
                if self.i >= len(self):
                    raise StopAsyncIteration
                self.i += 1
                return self[self.i - 1]

        result = []
        async def foo():
            async for i in AIter([42]):
                result.append(i)
            raise Done

        with self.assertRaises(Done):
            foo().send(None)
        self.assertEqual(result, [42])
1772
    def test_for_stop_iteration(self):
        # An async iterator that is itself a StopIteration instance
        # (value 42) must still be iterated normally: one item, then the
        # code after the loop runs (raising Done).
        class Done(Exception): pass

        class AIter(StopIteration):
            i = 0
            def __aiter__(self):
                return self
            async def __anext__(self):
                if self.i:
                    raise StopAsyncIteration
                self.i += 1
                return self.value

        result = []
        async def foo():
            async for i in AIter(42):
                result.append(i)
            raise Done

        with self.assertRaises(Done):
            foo().send(None)
        self.assertEqual(result, [42])
1795
    def test_comp_1(self):
        # 'await' in the element/key/value expression of list, set and dict
        # comprehensions inside a coroutine.
        async def f(i):
            return i

        async def run_list():
            return [await c for c in [f(1), f(41)]]

        async def run_set():
            return {await c for c in [f(1), f(41)]}

        async def run_dict1():
            return {await c: 'a' for c in [f(1), f(41)]}

        async def run_dict2():
            return {i: await c for i, c in enumerate([f(1), f(41)])}

        self.assertEqual(run_async(run_list()), ([], [1, 41]))
        self.assertEqual(run_async(run_set()), ([], {1, 41}))
        self.assertEqual(run_async(run_dict1()), ([], {1: 'a', 41: 'a'}))
        self.assertEqual(run_async(run_dict2()), ([], {0: 1, 1: 41}))
1816
    def test_comp_2(self):
        # 'await' in the iterable of nested 'for' clauses, including awaits
        # whose results are themselves awaitables awaited one level deeper.
        async def f(i):
            return i

        async def run_list():
            return [s for c in [f(''), f('abc'), f(''), f(['de', 'fg'])]
                    for s in await c]

        self.assertEqual(
            run_async(run_list()),
            ([], ['a', 'b', 'c', 'de', 'fg']))

        async def run_set():
            return {d
                    for c in [f([f([10, 30]),
                                 f([20])])]
                    for s in await c
                    for d in await s}

        self.assertEqual(
            run_async(run_set()),
            ([], {10, 20, 30}))

        async def run_set2():
            return {await s
                    for c in [f([f(10), f(20)])]
                    for s in await c}

        self.assertEqual(
            run_async(run_set2()),
            ([], {10, 20}))
1848
    def test_comp_3(self):
        # 'async for' in list/set/dict comprehensions and in an asynchronous
        # generator expression, all driven by the async generator f.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for i in f([10, 20])]
        self.assertEqual(
            run_async(run_list()),
            ([], [11, 21]))

        async def run_set():
            return {i + 1 async for i in f([10, 20])}
        self.assertEqual(
            run_async(run_set()),
            ([], {11, 21}))

        async def run_dict():
            return {i + 1: i + 2 async for i in f([10, 20])}
        self.assertEqual(
            run_async(run_dict()),
            ([], {11: 12, 21: 22}))

        async def run_gen():
            gen = (i + 1 async for i in f([10, 20]))
            return [g + 100 async for g in gen]
        self.assertEqual(
            run_async(run_gen()),
            ([], [111, 121]))
1878
    def test_comp_4(self):
        # Like test_comp_3, with an 'if' filter after the 'async for'.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for i in f([10, 20]) if i > 10]
        self.assertEqual(
            run_async(run_list()),
            ([], [21]))

        async def run_set():
            return {i + 1 async for i in f([10, 20]) if i > 10}
        self.assertEqual(
            run_async(run_set()),
            ([], {21}))

        async def run_dict():
            return {i + 1: i + 2 async for i in f([10, 20]) if i > 10}
        self.assertEqual(
            run_async(run_dict()),
            ([], {21: 22}))

        async def run_gen():
            gen = (i + 1 async for i in f([10, 20]) if i > 10)
            return [g + 100 async for g in gen]
        self.assertEqual(
            run_async(run_gen()),
            ([], [121]))
1908
    def test_comp_4_2(self):
        # Like test_comp_4, with a chained comparison as the filter.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 10 async for i in f(range(5)) if 0 < i < 4]
        self.assertEqual(
            run_async(run_list()),
            ([], [11, 12, 13]))

        async def run_set():
            return {i + 10 async for i in f(range(5)) if 0 < i < 4}
        self.assertEqual(
            run_async(run_set()),
            ([], {11, 12, 13}))

        async def run_dict():
            return {i + 10: i + 100 async for i in f(range(5)) if 0 < i < 4}
        self.assertEqual(
            run_async(run_dict()),
            ([], {11: 101, 12: 102, 13: 103}))

        async def run_gen():
            gen = (i + 10 async for i in f(range(5)) if 0 < i < 4)
            return [g + 100 async for g in gen]
        self.assertEqual(
            run_async(run_gen()),
            ([], [111, 112, 113]))
1938
    def test_comp_5(self):
        # A synchronous outer 'for' with a filter, followed by a filtered
        # inner 'async for', in one comprehension.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 for pair in ([10, 20], [30, 40]) if pair[0] > 10
                    async for i in f(pair) if i > 30]
        self.assertEqual(
            run_async(run_list()),
            ([], [41]))
1950
    def test_comp_6(self):
        # An outer 'async for' followed by a synchronous inner 'for'.
        async def f(it):
            for i in it:
                yield i

        async def run_list():
            return [i + 1 async for seq in f([(10, 20), (30,)])
                    for i in seq]

        self.assertEqual(
            run_async(run_list()),
            ([], [11, 21, 31]))
1963
    def test_comp_7(self):
        # An exception raised by the async generator mid-iteration must
        # propagate out of the async comprehension.
        async def f():
            yield 1
            yield 2
            raise Exception('aaa')

        async def run_list():
            return [i async for i in f()]

        with self.assertRaisesRegex(Exception, 'aaa'):
            run_async(run_list())
1975
    def test_comp_8(self):
        # A plain (synchronous) list comprehension inside a coroutine.
        async def f():
            return [i for i in [1, 2, 3]]

        self.assertEqual(
            run_async(f()),
            ([], [1, 2, 3]))
1983
    def test_comp_9(self):
        # An async comprehension followed by a synchronous one in the same
        # coroutine.
        async def gen():
            yield 1
            yield 2
        async def f():
            l = [i async for i in gen()]
            return [i for i in l]

        self.assertEqual(
            run_async(f()),
            ([], [1, 2]))
1995
    def test_comp_10(self):
        # Synchronous set and dict comprehensions inside a coroutine.
        async def f():
            xx = {i for i in [1, 2, 3]}
            return {x: x for x in xx}

        self.assertEqual(
            run_async(f()),
            ([], {1: 1, 2: 2, 3: 3}))
2004
2005    def test_copy(self):
2006        async def func(): pass
2007        coro = func()
2008        with self.assertRaises(TypeError):
2009            copy.copy(coro)
2010
2011        aw = coro.__await__()
2012        try:
2013            with self.assertRaises(TypeError):
2014                copy.copy(aw)
2015        finally:
2016            aw.close()
2017
2018    def test_pickle(self):
2019        async def func(): pass
2020        coro = func()
2021        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2022            with self.assertRaises((TypeError, pickle.PicklingError)):
2023                pickle.dumps(coro, proto)
2024
2025        aw = coro.__await__()
2026        try:
2027            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2028                with self.assertRaises((TypeError, pickle.PicklingError)):
2029                    pickle.dumps(aw, proto)
2030        finally:
2031            aw.close()
2032
2033    def test_fatal_coro_warning(self):
2034        # Issue 27811
2035        async def func(): pass
2036        with warnings.catch_warnings(), support.captured_stderr() as stderr:
2037            warnings.filterwarnings("error")
2038            func()
2039            support.gc_collect()
2040        self.assertIn("was never awaited", stderr.getvalue())
2041
2042
class CoroAsyncIOCompatTest(unittest.TestCase):
    # Native coroutines and 'async with' driven by a real asyncio loop.

    def test_asyncio_1(self):
        # asyncio cannot be imported when Python is compiled without thread
        # support, hence the support.import_module() indirection.
        asyncio = support.import_module('asyncio')

        class MyException(Exception):
            pass

        events = []

        class CM:
            async def __aenter__(self):
                events.append(1)
                await asyncio.sleep(0.01)
                events.append(2)
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await asyncio.sleep(0.01)
                events.append(exc_type.__name__)

        async def f():
            async with CM() as c:
                await asyncio.sleep(0.01)
                raise MyException
            events.append('unreachable')

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # __aexit__ returns None, so MyException escapes f(); it is the
            # expected outcome and is deliberately ignored here.
            with contextlib.suppress(MyException):
                loop.run_until_complete(f())
        finally:
            loop.close()
            asyncio.set_event_loop(None)

        self.assertEqual(events, [1, 2, 'MyException'])
2083
2084
class SysSetCoroWrapperTest(unittest.TestCase):
    # Tests for sys.set_coroutine_wrapper() / sys.get_coroutine_wrapper().
    # Every call to either function is expected to emit DeprecationWarning.

    def test_set_wrapper_1(self):
        # The installed wrapper sees each newly created native coroutine;
        # after resetting it to None, coroutines are no longer wrapped.
        async def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        with self.assertWarns(DeprecationWarning):
            self.assertIsNone(sys.get_coroutine_wrapper())

        with self.assertWarns(DeprecationWarning):
            sys.set_coroutine_wrapper(wrap)
        # Bind f up front: if anything below fails before foo() returns,
        # the finally clause must not die with UnboundLocalError (masking
        # the real failure) and must still uninstall the wrapper so it
        # does not leak into later tests.
        f = None
        try:
            with self.assertWarns(DeprecationWarning):
                self.assertIs(sys.get_coroutine_wrapper(), wrap)
            f = foo()
            self.assertTrue(wrapped)

            self.assertEqual(run_async(f), ([], 'spam'))
        finally:
            with self.assertWarns(DeprecationWarning):
                sys.set_coroutine_wrapper(None)
            if f is not None:
                f.close()

        with self.assertWarns(DeprecationWarning):
            self.assertIsNone(sys.get_coroutine_wrapper())

        wrapped = None
        coro = foo()
        self.assertFalse(wrapped)
        coro.close()

    def test_set_wrapper_2(self):
        # A non-callable wrapper is rejected and leaves no wrapper set.
        with self.assertWarns(DeprecationWarning):
            self.assertIsNone(sys.get_coroutine_wrapper())
        with self.assertRaisesRegex(TypeError, "callable expected, got int"):
            with self.assertWarns(DeprecationWarning):
                sys.set_coroutine_wrapper(1)
        with self.assertWarns(DeprecationWarning):
            self.assertIsNone(sys.get_coroutine_wrapper())

    def test_set_wrapper_3(self):
        # A wrapper that itself creates a coroutine would recurse; this must
        # be detected and reported as a RuntimeError.
        async def foo():
            return 'spam'

        def wrapper(coro):
            async def wrap(coro):
                return await coro
            return wrap(coro)

        with self.assertWarns(DeprecationWarning):
            sys.set_coroutine_wrapper(wrapper)
        try:
            with silence_coro_gc(), self.assertRaisesRegex(
                    RuntimeError,
                    r"coroutine wrapper.*\.wrapper at 0x.*attempted to "
                    r"recursively wrap .* wrap .*"):

                foo()

        finally:
            with self.assertWarns(DeprecationWarning):
                sys.set_coroutine_wrapper(None)

    def test_set_wrapper_4(self):
        # Generator-based coroutines (@types.coroutine) must not be passed
        # to the wrapper.
        @types.coroutine
        def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        with self.assertWarns(DeprecationWarning):
            sys.set_coroutine_wrapper(wrap)
        try:
            foo()
            self.assertIs(
                wrapped, None,
                "generator-based coroutine was wrapped via "
                "sys.set_coroutine_wrapper")
        finally:
            with self.assertWarns(DeprecationWarning):
                sys.set_coroutine_wrapper(None)
2176
2177
2178class OriginTrackingTest(unittest.TestCase):
2179    def here(self):
2180        info = inspect.getframeinfo(inspect.currentframe().f_back)
2181        return (info.filename, info.lineno)
2182
2183    def test_origin_tracking(self):
2184        orig_depth = sys.get_coroutine_origin_tracking_depth()
2185        try:
2186            async def corofn():
2187                pass
2188
2189            sys.set_coroutine_origin_tracking_depth(0)
2190            self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)
2191
2192            with contextlib.closing(corofn()) as coro:
2193                self.assertIsNone(coro.cr_origin)
2194
2195            sys.set_coroutine_origin_tracking_depth(1)
2196            self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 1)
2197
2198            fname, lineno = self.here()
2199            with contextlib.closing(corofn()) as coro:
2200                self.assertEqual(coro.cr_origin,
2201                                 ((fname, lineno + 1, "test_origin_tracking"),))
2202
2203            sys.set_coroutine_origin_tracking_depth(2)
2204            self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 2)
2205
2206            def nested():
2207                return (self.here(), corofn())
2208            fname, lineno = self.here()
2209            ((nested_fname, nested_lineno), coro) = nested()
2210            with contextlib.closing(coro):
2211                self.assertEqual(coro.cr_origin,
2212                                 ((nested_fname, nested_lineno, "nested"),
2213                                  (fname, lineno + 1, "test_origin_tracking")))
2214
2215            # Check we handle running out of frames correctly
2216            sys.set_coroutine_origin_tracking_depth(1000)
2217            with contextlib.closing(corofn()) as coro:
2218                self.assertTrue(2 < len(coro.cr_origin) < 1000)
2219
2220            # We can't set depth negative
2221            with self.assertRaises(ValueError):
2222                sys.set_coroutine_origin_tracking_depth(-1)
2223            # And trying leaves it unchanged
2224            self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 1000)
2225
2226        finally:
2227            sys.set_coroutine_origin_tracking_depth(orig_depth)
2228
2229    def test_origin_tracking_warning(self):
2230        async def corofn():
2231            pass
2232
2233        a1_filename, a1_lineno = self.here()
2234        def a1():
2235            return corofn()  # comment in a1
2236        a1_lineno += 2
2237
2238        a2_filename, a2_lineno = self.here()
2239        def a2():
2240            return a1()  # comment in a2
2241        a2_lineno += 2
2242
2243        def check(depth, msg):
2244            sys.set_coroutine_origin_tracking_depth(depth)
2245            with self.assertWarns(RuntimeWarning) as cm:
2246                a2()
2247                support.gc_collect()
2248            self.assertEqual(msg, str(cm.warning))
2249
2250        orig_depth = sys.get_coroutine_origin_tracking_depth()
2251        try:
2252            msg = check(0, f"coroutine '{corofn.__qualname__}' was never awaited")
2253            check(1, "".join([
2254                f"coroutine '{corofn.__qualname__}' was never awaited\n",
2255                "Coroutine created at (most recent call last)\n",
2256                f'  File "{a1_filename}", line {a1_lineno}, in a1\n',
2257                f'    return corofn()  # comment in a1',
2258            ]))
2259            check(2, "".join([
2260                f"coroutine '{corofn.__qualname__}' was never awaited\n",
2261                "Coroutine created at (most recent call last)\n",
2262                f'  File "{a2_filename}", line {a2_lineno}, in a2\n',
2263                f'    return a1()  # comment in a2\n',
2264                f'  File "{a1_filename}", line {a1_lineno}, in a1\n',
2265                f'    return corofn()  # comment in a1',
2266            ]))
2267
2268        finally:
2269            sys.set_coroutine_origin_tracking_depth(orig_depth)
2270
2271    def test_unawaited_warning_when_module_broken(self):
2272        # Make sure we don't blow up too bad if
2273        # warnings._warn_unawaited_coroutine is broken somehow (e.g. because
2274        # of shutdown problems)
2275        async def corofn():
2276            pass
2277
2278        orig_wuc = warnings._warn_unawaited_coroutine
2279        try:
2280            warnings._warn_unawaited_coroutine = lambda coro: 1/0
2281            with support.captured_stderr() as stream:
2282                corofn()
2283                support.gc_collect()
2284            self.assertIn("Exception ignored in", stream.getvalue())
2285            self.assertIn("ZeroDivisionError", stream.getvalue())
2286            self.assertIn("was never awaited", stream.getvalue())
2287
2288            del warnings._warn_unawaited_coroutine
2289            with support.captured_stderr() as stream:
2290                corofn()
2291                support.gc_collect()
2292            self.assertIn("was never awaited", stream.getvalue())
2293
2294        finally:
2295            warnings._warn_unawaited_coroutine = orig_wuc
2296
2297
class UnawaitedWarningDuringShutdownTest(unittest.TestCase):
    # https://bugs.python.org/issue32591#msg310726
    def test_unawaited_warning_during_shutdown(self):
        # Each snippet creates a coroutine that is never awaited and leaves
        # it reachable when the child interpreter shuts down: wrapped by
        # asyncio.gather, stored on the sys module, or held in a reference
        # cycle hanging off sys.  Every child run must succeed
        # (assert_python_ok).
        snippets = (
            ("import asyncio\n"
             "async def f(): pass\n"
             "asyncio.gather(f())\n"),
            ("import sys\n"
             "async def f(): pass\n"
             "sys.coro = f()\n"),
            ("import sys\n"
             "async def f(): pass\n"
             "sys.corocycle = [f()]\n"
             "sys.corocycle.append(sys.corocycle)\n"),
        )
        for snippet in snippets:
            assert_python_ok("-c", snippet)
2316
2317
@support.cpython_only
class CAPITest(unittest.TestCase):
    # Tests for the C-level tp_await slot, driven through
    # _testcapi.awaitType.  Judging by these tests, an awaitType instance's
    # await slot hands back whatever object it was constructed with.

    def test_tp_await_1(self):
        # Awaiting an awaitType built around an iterator: the first send()
        # yields that iterator's first item.
        from _testcapi import awaitType as at

        async def foo():
            future = at(iter([1]))
            return (await future)

        # Close the suspended coroutine explicitly rather than leaving it
        # for the garbage collector.
        with contextlib.closing(foo()) as coro:
            self.assertEqual(coro.send(None), 1)

    def test_tp_await_2(self):
        # Test tp_await to __await__ mapping
        from _testcapi import awaitType as at
        future = at(iter([1]))
        self.assertEqual(next(future.__await__()), 1)

    def test_tp_await_3(self):
        # A non-iterator coming back from tp_await is rejected with TypeError.
        from _testcapi import awaitType as at

        async def foo():
            future = at(1)
            return (await future)

        with self.assertRaisesRegex(
                TypeError, "__await__.*returned non-iterator of type 'int'"):
            # send() raises here, so there is no return value to compare.
            foo().send(None)
2346
2347
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()
2350