1import contextlib
2import os
3import threading
4from textwrap import dedent
5import unittest
6import time
7
8import _xxsubinterpreters as _interpreters
9from test.support import interpreters
10
11
12def _captured_script(script):
13    r, w = os.pipe()
14    indented = script.replace('\n', '\n                ')
15    wrapped = dedent(f"""
16        import contextlib
17        with open({w}, 'w', encoding='utf-8') as spipe:
18            with contextlib.redirect_stdout(spipe):
19                {indented}
20        """)
21    return wrapped, open(r, encoding='utf-8')
22
23
def clean_up_interpreters():
    """Close every interpreter reported by list_all() except id 0 (main).

    Interpreters whose close() raises RuntimeError are skipped silently.
    """
    for interp in interpreters.list_all():
        is_main = interp.id == 0
        if not is_main:
            with contextlib.suppress(RuntimeError):  # already destroyed
                interp.close()
32
33
def _run_output(interp, request, channels=None):
    """Run *request* in *interp* and return the text it wrote to stdout.

    *channels* is passed through unchanged to ``interp.run()``.
    """
    wrapped, reader = _captured_script(request)
    try:
        interp.run(wrapped, channels=channels)
        captured = reader.read()
    finally:
        # Release the read end of the pipe whether or not the run worked.
        reader.close()
    return captured
39
40
41@contextlib.contextmanager
42def _running(interp):
43    r, w = os.pipe()
44    def run():
45        interp.run(dedent(f"""
46            # wait for "signal"
47            with open({r}) as rpipe:
48                rpipe.read()
49            """))
50
51    t = threading.Thread(target=run)
52    t.start()
53
54    yield
55
56    with open(w, 'w') as spipe:
57        spipe.write('done')
58    t.join()
59
60
class TestBase(unittest.TestCase):
    """Shared base class that cleans up interpreters after every test."""

    def tearDown(self):
        # Close whatever non-main interpreters the test left behind.
        clean_up_interpreters()
65
66
class CreateTests(TestBase):
    """Tests for interpreters.create()."""

    def test_in_main(self):
        created = interpreters.create()
        self.assertIsInstance(created, interpreters.Interpreter)
        self.assertIn(created, interpreters.list_all())

    def test_in_thread(self):
        gate = threading.Lock()
        interp = None

        def worker():
            nonlocal interp
            interp = interpreters.create()
            # Wait until the main thread has let go of the gate.
            with gate:
                pass

        thread = threading.Thread(target=worker)
        gate.acquire()
        try:
            thread.start()
        finally:
            gate.release()
        thread.join()
        self.assertIn(interp, interpreters.list_all())

    def test_in_subinterpreter(self):
        # Unpacking also checks that only one interpreter exists up front.
        [main] = interpreters.list_all()
        interp = interpreters.create()
        script = dedent("""
            from test.support import interpreters
            interp = interpreters.create()
            print(interp.id)
            """)
        out = _run_output(interp, script)
        nested = interpreters.Interpreter(int(out))
        self.assertEqual(interpreters.list_all(), [main, interp, nested])

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters, then destroy every one of them.
        doomed = [interpreters.create() for _ in range(3)]
        for victim in doomed:
            victim.close()
        # Finally, create another.
        survivor = interpreters.create()
        remaining = set(interpreters.list_all())
        self.assertEqual(remaining, before | {survivor})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        first, second, third = (
            interpreters.create(),
            interpreters.create(),
            interpreters.create(),
        )
        # Now destroy 2 of them.
        for victim in (first, second):
            victim.close()
        # Finally, create another.
        newest = interpreters.create()
        remaining = set(interpreters.list_all())
        self.assertEqual(remaining, before | {third, newest})
125
126
class GetCurrentTests(TestBase):
    """Tests for interpreters.get_current()."""

    def test_main(self):
        main = interpreters.get_main()
        current = interpreters.get_current()
        self.assertEqual(current, main)

    def test_subinterpreter(self):
        # Compare like with like: take "main" from the high-level API so
        # that both sides of the comparison come from the same module
        # (the low-level module's get_main() presumably hands back a bare
        # id, against which assertNotEqual could not meaningfully fail).
        main = interpreters.get_main()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            cur = interpreters.get_current()
            print(cur.id)
            """))
        current = interpreters.Interpreter(int(out))
        # The script ran in "interp", so that is what it must report.
        self.assertEqual(current, interp)
        self.assertNotEqual(current, main)
144
145
class ListAllTests(TestBase):
    """Tests for interpreters.list_all()."""

    def test_initial(self):
        count = len(interpreters.list_all())
        self.assertEqual(1, count)

    def test_after_creating(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()

        listed = [each.id for each in interpreters.list_all()]

        expected = [main.id, first.id, second.id]
        self.assertEqual(listed, expected)

    def test_after_destroying(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()
        first.close()

        listed = [each.id for each in interpreters.list_all()]

        expected = [main.id, second.id]
        self.assertEqual(listed, expected)
174
175
class TestInterpreterAttrs(TestBase):
    """Tests for the attributes and equality of Interpreter objects."""

    def test_id_type(self):
        candidates = (
            interpreters.get_main(),
            interpreters.get_current(),
            interpreters.create(),
        )
        for candidate in candidates:
            self.assertIsInstance(candidate.id, _interpreters.InterpreterID)

    def test_main_id(self):
        self.assertEqual(interpreters.get_main().id, 0)

    def test_custom_id(self):
        wrapper = interpreters.Interpreter(1)
        self.assertEqual(wrapper.id, 1)

        # A str id must be rejected.
        self.assertRaises(TypeError, interpreters.Interpreter, '1')

    def test_id_readonly(self):
        wrapper = interpreters.Interpreter(1)
        self.assertRaises(AttributeError, setattr, wrapper, 'id', 2)

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_main_isolated(self):
        self.assertFalse(interpreters.get_main().isolated)

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_subinterpreter_isolated_default(self):
        self.assertFalse(interpreters.create().isolated)

    def test_subinterpreter_isolated_explicit(self):
        isolated = interpreters.create(isolated=True)
        shared = interpreters.create(isolated=False)
        self.assertTrue(isolated.isolated)
        self.assertFalse(shared.isolated)

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_custom_isolated_default(self):
        self.assertFalse(interpreters.Interpreter(1).isolated)

    def test_custom_isolated_explicit(self):
        isolated = interpreters.Interpreter(1, isolated=True)
        shared = interpreters.Interpreter(1, isolated=False)
        self.assertTrue(isolated.isolated)
        self.assertFalse(shared.isolated)

    def test_isolated_readonly(self):
        wrapper = interpreters.Interpreter(1)
        self.assertRaises(AttributeError, setattr, wrapper, 'isolated', True)

    def test_equality(self):
        first = interpreters.create()
        second = interpreters.create()
        self.assertEqual(first, first)
        self.assertNotEqual(first, second)
239
240
class TestInterpreterIsRunning(TestBase):
    """Tests for Interpreter.is_running()."""

    def test_main(self):
        self.assertTrue(interpreters.get_main().is_running())

    @unittest.skip('Fails on FreeBSD')
    def test_subinterpreter(self):
        interp = interpreters.create()
        # Idle before, busy inside the helper, idle again afterwards.
        self.assertFalse(interp.is_running())

        with _running(interp):
            self.assertTrue(interp.is_running())
        self.assertFalse(interp.is_running())

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        # The script asks, from inside the interpreter, about itself.
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp.id}):
                print(True)
            else:
                print(False)
            """)
        out = _run_output(interp, script)
        self.assertEqual(out.strip(), 'True')

    def test_already_destroyed(self):
        gone = interpreters.create()
        gone.close()
        self.assertRaises(RuntimeError, gone.is_running)

    def test_does_not_exist(self):
        missing = interpreters.Interpreter(1_000_000)
        self.assertRaises(RuntimeError, missing.is_running)

    def test_bad_id(self):
        invalid = interpreters.Interpreter(-1)
        self.assertRaises(ValueError, invalid.is_running)
282
283
class TestInterpreterClose(TestBase):
    """Tests for Interpreter.close()."""

    def _run_in_thread(self, func):
        # Run func() in a new thread and re-raise, in the calling thread,
        # whatever it raised.  An exception (including a failed assertion)
        # that stays inside a worker thread would otherwise never reach
        # the test runner, and the test would pass regardless.
        failures = []

        def target():
            try:
                func()
            except BaseException as exc:
                failures.append(exc)

        t = threading.Thread(target=target)
        t.start()
        t.join()
        if failures:
            raise failures[0]

    def test_basic(self):
        main = interpreters.get_main()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp2, interp3})
        interp2.close()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp3})

    def test_all(self):
        before = set(interpreters.list_all())
        interps = {interpreters.create() for _ in range(3)}
        self.assertEqual(set(interpreters.list_all()), before | interps)
        for interp in interps:
            interp.close()
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            main.close()

        def f():
            with self.assertRaises(RuntimeError):
                main.close()

        # Closing main must fail from a non-main thread as well; the
        # helper makes a failed assertion in the thread fail this test.
        self._run_in_thread(f)

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_does_not_exist(self):
        interp = interpreters.Interpreter(1_000_000)
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_bad_id(self):
        interp = interpreters.Interpreter(-1)
        with self.assertRaises(ValueError):
            interp.close()

    def test_from_current(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        # The script tries to close the interpreter it is running in.
        out = _run_output(interp, dedent(f"""
            from test.support import interpreters
            interp = interpreters.Interpreter({int(interp.id)})
            try:
                interp.close()
            except RuntimeError:
                print('failed')
            """))
        self.assertEqual(out.strip(), 'failed')
        self.assertEqual(set(interpreters.list_all()), {main, interp})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp2})
        interp1.run(dedent(f"""
            from test.support import interpreters
            interp2 = interpreters.Interpreter(int({interp2.id}))
            interp2.close()
            interp3 = interpreters.create()
            interp3.close()
            """))
        self.assertEqual(set(interpreters.list_all()), {main, interp1})

    def test_from_other_thread(self):
        interp = interpreters.create()
        # Any error raised by close() in the thread is re-raised here.
        self._run_in_thread(interp.close)

    @unittest.skip('Fails on FreeBSD')
    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interp.close()
            self.assertTrue(interp.is_running())
383
384
class TestInterpreterRun(TestBase):
    """Tests for Interpreter.run()."""

    def test_success(self):
        interp = interpreters.create()
        wrapped, reader = _captured_script('print("it worked!", end="")')
        with reader:
            interp.run(wrapped)
            captured = reader.read()

        self.assertEqual(captured, 'it worked!')

    def test_in_thread(self):
        interp = interpreters.create()
        wrapped, reader = _captured_script('print("it worked!", end="")')
        with reader:
            # Run the script from a worker thread instead of this one.
            worker = threading.Thread(target=interp.run, args=(wrapped,))
            worker.start()
            worker.join()
            captured = reader.read()

        self.assertEqual(captured, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        interp = interpreters.create()
        import tempfile
        with tempfile.NamedTemporaryFile('w+', encoding='utf-8') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            # The marker text is written only by the RuntimeError handler,
            # so a match below means os.fork() raised in the interpreter.
            script = dedent(f"""
                import os
                try:
                    os.fork()
                except RuntimeError:
                    with open('{file.name}', 'w', encoding='utf-8') as out:
                        out.write('{expected}')
                """)
            interp.run(script)

            file.seek(0)
            self.assertEqual(file.read(), expected)

    @unittest.skip('Fails on FreeBSD')
    def test_already_running(self):
        interp = interpreters.create()
        with _running(interp):
            self.assertRaises(RuntimeError, interp.run, 'print("spam")')

    def test_does_not_exist(self):
        missing = interpreters.Interpreter(1_000_000)
        self.assertRaises(RuntimeError, missing.run, 'print("spam")')

    def test_bad_id(self):
        invalid = interpreters.Interpreter(-1)
        self.assertRaises(ValueError, invalid.run, 'print("spam")')

    def test_bad_script(self):
        interp = interpreters.create()
        self.assertRaises(TypeError, interp.run, 10)

    def test_bytes_for_script(self):
        interp = interpreters.create()
        self.assertRaises(TypeError, interp.run, b'print("spam")')

    # test_xxsubinterpreters covers the remaining Interpreter.run() behavior.
461
462
class TestIsShareable(TestBase):
    """Tests for interpreters.is_shareable()."""

    def test_default_shareables(self):
        shareables = [
            # singletons
            None,
            # builtin objects
            b'spam',
            'spam',
            10,
            -10,
        ]
        for candidate in shareables:
            with self.subTest(candidate):
                self.assertTrue(interpreters.is_shareable(candidate))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        singletons = [True, False, NotImplemented, ...]
        builtin_types_and_objects = [
            type,
            object,
            object(),
            Exception(),
            100.0,
        ]
        user_defined = [
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        ]
        not_shareables = singletons + builtin_types_and_objects + user_defined
        for candidate in not_shareables:
            with self.subTest(repr(candidate)):
                verdict = interpreters.is_shareable(candidate)
                self.assertFalse(verdict)
511
512
class TestChannels(TestBase):
    """Tests for create_channel() and list_all_channels()."""

    def test_create(self):
        recv_end, send_end = interpreters.create_channel()
        self.assertIsInstance(recv_end, interpreters.RecvChannel)
        self.assertIsInstance(send_end, interpreters.SendChannel)

    def test_list_all(self):
        # Nothing may be listed before this test creates its channels.
        self.assertEqual(interpreters.list_all_channels(), [])
        created = {interpreters.create_channel() for _ in range(3)}
        listed = set(interpreters.list_all_channels())
        self.assertEqual(listed, created)
528
529
class TestRecvChannelAttrs(TestBase):
    """Tests for the attributes and equality of RecvChannel objects."""

    def test_id_type(self):
        recv_end, _send_end = interpreters.create_channel()
        self.assertIsInstance(recv_end.id, _interpreters.ChannelID)

    def test_custom_id(self):
        recv_end = interpreters.RecvChannel(1)
        self.assertEqual(recv_end.id, 1)

        # A str id must be rejected.
        self.assertRaises(TypeError, interpreters.RecvChannel, '1')

    def test_id_readonly(self):
        recv_end = interpreters.RecvChannel(1)
        self.assertRaises(AttributeError, setattr, recv_end, 'id', 2)

    def test_equality(self):
        first, _send1 = interpreters.create_channel()
        second, _send2 = interpreters.create_channel()
        self.assertEqual(first, first)
        self.assertNotEqual(first, second)
553
554
class TestSendChannelAttrs(TestBase):
    """Tests for the attributes and equality of SendChannel objects."""

    def test_id_type(self):
        _recv_end, send_end = interpreters.create_channel()
        self.assertIsInstance(send_end.id, _interpreters.ChannelID)

    def test_custom_id(self):
        send_end = interpreters.SendChannel(1)
        self.assertEqual(send_end.id, 1)

        # A str id must be rejected.
        self.assertRaises(TypeError, interpreters.SendChannel, '1')

    def test_id_readonly(self):
        send_end = interpreters.SendChannel(1)
        self.assertRaises(AttributeError, setattr, send_end, 'id', 2)

    def test_equality(self):
        _recv1, first = interpreters.create_channel()
        _recv2, second = interpreters.create_channel()
        self.assertEqual(first, first)
        self.assertNotEqual(first, second)
578
579
class TestSendRecv(TestBase):
    """Tests for sending and receiving objects over channels."""

    def test_send_recv_main(self):
        rch, sch = interpreters.create_channel()
        sent = b'spam'
        sch.send_nowait(sent)
        received = rch.recv()

        # Equal in value, but a different object.
        self.assertEqual(received, sent)
        self.assertIsNot(received, sent)

    def test_send_recv_same_interpreter(self):
        interp = interpreters.create()
        script = dedent("""
            from test.support import interpreters
            r, s = interpreters.create_channel()
            orig = b'spam'
            s.send_nowait(orig)
            obj = r.recv()
            assert obj == orig, 'expected: obj == orig'
            assert obj is not orig, 'expected: obj is not orig'
            """)
        interp.run(script)

    @unittest.skip('broken (see BPO-...)')
    def test_send_recv_different_interpreters(self):
        r1, s1 = interpreters.create_channel()
        r2, s2 = interpreters.create_channel()
        orig1 = b'spam'
        s1.send_nowait(orig1)
        # The script receives on channel 1 and replies on channel 2.
        out = _run_output(
            interpreters.create(),
            dedent(f"""
                obj1 = r.recv()
                assert obj1 == b'spam', 'expected: obj1 == orig1'
                # When going to another interpreter we get a copy.
                assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1'
                orig2 = b'eggs'
                print(id(orig2))
                s.send_nowait(orig2)
                """),
            channels={'r': r1, 's': s2},
            )
        reply = r2.recv()

        self.assertEqual(reply, b'eggs')
        self.assertNotEqual(id(reply), int(out))

    def test_send_recv_different_threads(self):
        rch, sch = interpreters.create_channel()

        def echo():
            # Take one object off the channel and send it straight back.
            while True:
                try:
                    item = rch.recv()
                except interpreters.ChannelEmptyError:
                    time.sleep(0.1)
                else:
                    break
            sch.send(item)

        worker = threading.Thread(target=echo)
        worker.start()

        sent = b'spam'
        sch.send(sent)
        worker.join()
        received = rch.recv()

        self.assertEqual(received, sent)
        self.assertIsNot(received, sent)

    def test_send_recv_nowait_main(self):
        rch, sch = interpreters.create_channel()
        sent = b'spam'
        sch.send_nowait(sent)
        received = rch.recv_nowait()

        self.assertEqual(received, sent)
        self.assertIsNot(received, sent)

    def test_send_recv_nowait_main_with_default(self):
        rch, _sch = interpreters.create_channel()
        # Nothing was sent, so the supplied default is expected back.
        received = rch.recv_nowait(None)

        self.assertIsNone(received)

    def test_send_recv_nowait_same_interpreter(self):
        interp = interpreters.create()
        # NOTE(review): the comment embedded in the script below says the
        # same object comes back, yet the assertion right after it demands
        # a different object.  The script text is left untouched here.
        script = dedent("""
            from test.support import interpreters
            r, s = interpreters.create_channel()
            orig = b'spam'
            s.send_nowait(orig)
            obj = r.recv_nowait()
            assert obj == orig, 'expected: obj == orig'
            # When going back to the same interpreter we get the same object.
            assert obj is not orig, 'expected: obj is not orig'
            """)
        interp.run(script)

    @unittest.skip('broken (see BPO-...)')
    def test_send_recv_nowait_different_interpreters(self):
        r1, s1 = interpreters.create_channel()
        r2, s2 = interpreters.create_channel()
        orig1 = b'spam'
        s1.send_nowait(orig1)
        # The script receives on channel 1 and replies on channel 2.
        out = _run_output(
            interpreters.create(),
            dedent(f"""
                obj1 = r.recv_nowait()
                assert obj1 == b'spam', 'expected: obj1 == orig1'
                # When going to another interpreter we get a copy.
                assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1'
                orig2 = b'eggs'
                print(id(orig2))
                s.send_nowait(orig2)
                """),
            channels={'r': r1, 's': s2},
            )
        reply = r2.recv_nowait()

        self.assertEqual(reply, b'eggs')
        self.assertNotEqual(id(reply), int(out))

    def test_recv_channel_does_not_exist(self):
        missing = interpreters.RecvChannel(1_000_000)
        self.assertRaises(interpreters.ChannelNotFoundError, missing.recv)

    def test_send_channel_does_not_exist(self):
        missing = interpreters.SendChannel(1_000_000)
        self.assertRaises(interpreters.ChannelNotFoundError,
                          missing.send, b'spam')

    def test_recv_nowait_channel_does_not_exist(self):
        missing = interpreters.RecvChannel(1_000_000)
        self.assertRaises(interpreters.ChannelNotFoundError,
                          missing.recv_nowait)

    def test_send_nowait_channel_does_not_exist(self):
        missing = interpreters.SendChannel(1_000_000)
        self.assertRaises(interpreters.ChannelNotFoundError,
                          missing.send_nowait, b'spam')

    def test_recv_nowait_empty(self):
        rch, _sch = interpreters.create_channel()
        self.assertRaises(interpreters.ChannelEmptyError, rch.recv_nowait)

    def test_recv_nowait_default(self):
        default = object()
        rch, sch = interpreters.create_channel()
        # Empty channel: the default comes back.
        before_any = rch.recv_nowait(default)
        for item in (None, 1, b'spam', b'eggs'):
            sch.send_nowait(item)
        got_none = rch.recv_nowait(default)
        got_int = rch.recv_nowait(default)
        got_spam = rch.recv_nowait()
        got_eggs = rch.recv_nowait(default)
        # Drained again: back to the default.
        after_all = rch.recv_nowait(default)

        self.assertIs(before_any, default)
        self.assertIs(got_none, None)
        self.assertEqual(got_int, 1)
        self.assertEqual(got_spam, b'spam')
        self.assertEqual(got_eggs, b'eggs')
        self.assertIs(after_all, default)
746