1from collections import namedtuple
2import contextlib
3import itertools
4import os
5import pickle
6import sys
7from textwrap import dedent
8import threading
9import time
10import unittest
11
12from test import support
13from test.support import script_helper
14
15
16interpreters = support.import_module('_xxsubinterpreters')
17
18
19##################################
20# helpers
21
def powerset(*sets):
    """Return an iterator over every subset of the positional arguments.

    Each subset is a tuple.  Subsets come in order of increasing size,
    starting with the empty tuple and ending with a tuple holding all of
    the arguments; within one size the order is that of
    itertools.combinations().
    """
    return itertools.chain.from_iterable(
        itertools.combinations(sets, r)
        for r in range(len(sets) + 1))
26
27
28def _captured_script(script):
29    r, w = os.pipe()
30    indented = script.replace('\n', '\n                ')
31    wrapped = dedent(f"""
32        import contextlib
33        with open({w}, 'w') as spipe:
34            with contextlib.redirect_stdout(spipe):
35                {indented}
36        """)
37    return wrapped, open(r)
38
39
def _run_output(interp, request, shared=None):
    """Run *request* in interpreter *interp* and return what it printed.

    *shared* is passed unchanged to interpreters.run_string().
    """
    wrapped, reader = _captured_script(request)
    with reader:
        interpreters.run_string(interp, wrapped, shared)
        output = reader.read()
    return output
45
46
@contextlib.contextmanager
def _running(interp):
    """Keep interpreter *interp* busy for the duration of the with-block.

    A background thread runs a script in *interp* that blocks reading from
    a pipe.  On exit from the block, whether normal or via an exception,
    the pipe is written to and closed so the script can finish, and the
    thread is joined.
    """
    r, w = os.pipe()
    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}) as rpipe:
                rpipe.read()
            """))

    t = threading.Thread(target=run)
    t.start()

    # NOTE(review): nothing here waits for the script to actually start, so
    # the interpreter may not yet report as running when the body begins.
    try:
        yield
    finally:
        # Always release the blocked script.  Otherwise an exception in the
        # with-block would leave the thread waiting on the pipe forever.
        with open(w, 'w') as spipe:
            spipe.write('done')
        t.join()
65
66
67#@contextmanager
68#def run_threaded(id, source, **shared):
69#    def run():
70#        run_interp(id, source, **shared)
71#    t = threading.Thread(target=run)
72#    t.start()
73#    yield
74#    t.join()
75
76
def run_interp(id, source, **shared):
    """Execute *source* in interpreter *id*.

    The keyword arguments are collected into a dict and handed to
    _run_interp() as the values to share with the interpreter.
    """
    _run_interp(id, source, shared=shared)
79
80
def _run_interp(id, source, shared, _mainns={}):
    """Run the dedented *source* in interpreter *id*.

    For the main interpreter, the source is exec()'d directly in the dict
    bound as the default of *_mainns*.  That one dict is reused on every
    call, so names persist between calls, and *shared* is not used.  The
    caller must itself be in the main interpreter, otherwise RuntimeError
    is raised.  For any other interpreter, the source and *shared* are
    passed to interpreters.run_string().
    """
    source = dedent(source)
    main = interpreters.get_main()
    targets_main = main == id
    if not targets_main:
        interpreters.run_string(id, source, shared)
        return
    if interpreters.get_current() != main:
        raise RuntimeError
    # XXX Run a func?
    exec(source, _mainns)
91
92
def run_interp_threaded(id, source, **shared):
    """Run *source* in interpreter *id* from a freshly started thread.

    Blocks until the thread finishes.  Keyword arguments are shared with
    the interpreter, as for run_interp().  An exception raised while
    running the source is re-raised in the calling thread instead of only
    being reported by the worker thread.
    """
    failure = None

    def run():
        nonlocal failure
        try:
            _run_interp(id, source, shared)
        except Exception as exc:
            failure = exc

    t = threading.Thread(target=run)
    t.start()
    t.join()
    if failure is not None:
        raise failure
99
100
class Interpreter(namedtuple('Interpreter', 'name id')):
    """A (name, id) pair describing an interpreter used by the tests.

    The name "main" is reserved for the main interpreter.  Giving a name
    other than "main" without an id creates a brand-new subinterpreter.
    """

    @classmethod
    def from_raw(cls, raw):
        """Coerce *raw* (an Interpreter or a name string) to an Interpreter.

        Raises NotImplementedError for any other kind of value.
        """
        if isinstance(raw, cls):
            return raw
        if isinstance(raw, str):
            return cls(raw)
        raise NotImplementedError

    def __new__(cls, name=None, id=None):
        """Reconcile *name* and *id*.

        - id equal to the main interpreter's id: name defaults to "main"
          and must be "main" if given.
        - any other id: name defaults to "interp" and must not be "main";
          the id is coerced to interpreters.InterpreterID.
        - no id, and name empty or "main": the main interpreter.
        - no id, and any other name: a new subinterpreter is created.

        Raises ValueError when the name and id disagree.
        """
        main = interpreters.get_main()
        if id == main:
            name = name or 'main'
            if name != 'main':
                raise ValueError(
                    'name mismatch (expected "main", got "{}")'.format(name))
            id = main
        elif id is not None:
            name = name or 'interp'
            if name == 'main':
                raise ValueError('name mismatch (unexpected "main")')
            if not isinstance(id, interpreters.InterpreterID):
                id = interpreters.InterpreterID(id)
        elif name and name != 'main':
            id = interpreters.create()
        else:
            name, id = 'main', main
        return super().__new__(cls, name, id)
135
136
137# XXX expect_channel_closed() is unnecessary once we improve exc propagation.
138
@contextlib.contextmanager
def expect_channel_closed():
    """Require the with-block to raise interpreters.ChannelClosedError.

    That exception is swallowed.  If the block finishes without raising
    it, AssertionError is raised.  Any other exception propagates.
    """
    try:
        yield
    except interpreters.ChannelClosedError:
        pass
    else:
        # An explicit raise, rather than ``assert False``, keeps this check
        # active when Python runs with -O.
        raise AssertionError('channel not closed')
147
148
class ChannelAction(namedtuple('ChannelAction', 'action end interp')):
    """One step of a channel scenario: (action, end, interp).

    action is 'use', 'close' or 'force-close'.  end is 'both', 'same',
    'opposite', 'send' or 'recv' ('both' is rejected for 'use').  interp
    is 'main', 'same', 'other' or 'extra'.  Omitted or empty end/interp
    default to 'both'/'main'.  Invalid values raise ValueError.
    """

    def __new__(cls, action, end=None, interp=None):
        return super().__new__(cls, action, end or 'both', interp or 'main')

    def __init__(self, *args, **kwargs):
        if self.action == 'use':
            allowed_ends = ('same', 'opposite', 'send', 'recv')
        elif self.action in ('close', 'force-close'):
            allowed_ends = ('both', 'same', 'opposite', 'send', 'recv')
        else:
            raise ValueError(self.action)
        if self.end not in allowed_ends:
            raise ValueError(self.end)
        if self.interp not in ('main', 'same', 'other', 'extra'):
            raise ValueError(self.interp)

    def resolve_end(self, end):
        """Map a relative end ('same'/'opposite') onto a concrete end.

        *end* is the end currently in use.  Any other value of self.end is
        returned unchanged.
        """
        if self.end == 'opposite':
            return 'recv' if end == 'send' else 'send'
        return end if self.end == 'same' else self.end

    def resolve_interp(self, interp, other, extra):
        """Pick the interpreter object this action should run in.

        RuntimeError is raised when the requested interpreter is missing.
        """
        kind = self.interp
        if kind == 'main':
            if interp.name == 'main':
                return interp
            if other and other.name == 'main':
                return other
            raise RuntimeError
        if kind in ('other', 'extra'):
            chosen = other if kind == 'other' else extra
            if chosen is None:
                raise RuntimeError
            return chosen
        if kind == 'same':
            return interp
        # Per __init__(), there aren't any others.
        return None
198
199
class ChannelState(namedtuple('ChannelState', 'pending closed')):
    """Expected state of a channel: pending item count and closed flag.

    Instances are immutable; every helper returns a new state.
    """

    def __new__(cls, pending=0, *, closed=False):
        return super().__new__(cls, pending, closed)

    def _shifted(self, delta):
        """Return a copy whose pending count is changed by *delta*."""
        return type(self)(self.pending + delta, closed=self.closed)

    def incr(self):
        """Return a copy with one more pending item."""
        return self._shifted(1)

    def decr(self):
        """Return a copy with one fewer pending item."""
        return self._shifted(-1)

    def close(self, *, force=True):
        """Return the state after closing the channel.

        A forced close drops pending items; an unforced one keeps them.  An
        already-closed state is returned unchanged unless a forced close
        still has pending items to drop.

        NOTE(review): *force* defaults to True, so close() and
        close(force=True) behave identically -- confirm this is intended.
        """
        if self.closed and (not force or self.pending == 0):
            return self
        remaining = 0 if force else self.pending
        return type(self)(remaining, closed=True)
217
218
def run_action(cid, action, end, state, *, hideclosed=True):
    """Perform *action* on *end* of channel *cid*; return the new state.

    On a closed channel, the only action expected to succeed is a 'use' of
    the 'recv' end while items are still pending.  Anything else is
    expected to raise ChannelClosedError.

    A ChannelClosedError is normally absorbed and the state is marked
    closed.  With hideclosed=False it is re-raised unless failure was
    expected.  If failure was expected but the action succeeded, Exception
    is raised.
    """
    if state.closed:
        if action == 'use' and end == 'recv' and state.pending:
            expectfail = False
        else:
            expectfail = True
    else:
        expectfail = False

    try:
        result = _run_action(cid, action, end, state)
    except interpreters.ChannelClosedError:
        if not hideclosed and not expectfail:
            raise
        result = state.close()
    else:
        if expectfail:
            raise Exception('expected ChannelClosedError')
    return result
238
239
def _run_action(cid, action, end, state):
    """Apply one channel operation and return the expected resulting state.

    'use' sends b'spam' on the 'send' end or receives on the 'recv' end.
    'close' and 'force-close' close the 'recv' or 'send' end; any other end
    value closes without selecting an end.

    Raises ValueError for an unknown action, or an unknown end for 'use'.
    """
    if action == 'use':
        if end == 'send':
            interpreters.channel_send(cid, b'spam')
            return state.incr()
        if end != 'recv':
            raise ValueError(end)
        if state.pending:
            interpreters.channel_recv(cid)
            return state.decr()
        # Nothing is queued, so receiving must report an empty channel.
        try:
            interpreters.channel_recv(cid)
        except interpreters.ChannelEmptyError:
            return state
        raise Exception('expected ChannelEmptyError')

    if action not in ('close', 'force-close'):
        raise ValueError(action)
    force = action == 'force-close'
    kwargs = {'force': True} if force else {}
    if end in ('recv', 'send'):
        kwargs[end] = True
    interpreters.channel_close(cid, **kwargs)
    return state.close(force=True) if force else state.close()
274
275
def clean_up_interpreters():
    """Destroy every interpreter except the main one.

    A RuntimeError from destroy() (e.g. the interpreter is already gone)
    is ignored, so this is safe to call repeatedly.
    """
    # Ask for the main interpreter's id rather than assuming it is 0.
    main = interpreters.get_main()
    for id in interpreters.list_all():
        if id == main:
            continue
        try:
            interpreters.destroy(id)
        except RuntimeError:
            pass  # already destroyed
284
285
def clean_up_channels():
    """Destroy every channel that currently exists.

    Channels reported as not found by channel_destroy() are skipped.
    """
    for cid in interpreters.channel_list_all():
        with contextlib.suppress(interpreters.ChannelNotFoundError):
            interpreters.channel_destroy(cid)  # may already be destroyed
292
293
class TestBase(unittest.TestCase):
    """Base test case that removes leftover interpreters and channels."""

    def tearDown(self):
        try:
            clean_up_interpreters()
        finally:
            # Clean up channels even if interpreter cleanup failed, so one
            # failure does not leak channels into later tests.
            clean_up_channels()
        super().tearDown()
299
300
301##################################
302# misc. tests
303
class IsShareableTests(unittest.TestCase):
    """Tests for interpreters.is_shareable()."""

    def test_default_shareables(self):
        shareables = [
            None,       # singletons
            b'spam',    # builtin objects
            'spam',
            10,
            -10,
        ]
        for obj in shareables:
            with self.subTest(obj):
                self.assertTrue(interpreters.is_shareable(obj))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        singletons = [True, False, NotImplemented, ...]
        builtin_types_and_objects = [type, object, object(), Exception(), 100.0]
        user_defined = [Cheese, Cheese('Wensleydale'), SubBytes(b'spam')]
        for obj in singletons + builtin_types_and_objects + user_defined:
            with self.subTest(repr(obj)):
                self.assertFalse(interpreters.is_shareable(obj))
352
353
class ShareableTypeTests(unittest.TestCase):
    """Round-trip shareable objects through a channel."""

    def setUp(self):
        super().setUp()
        self.cid = interpreters.channel_create()

    def tearDown(self):
        interpreters.channel_destroy(self.cid)
        super().tearDown()

    def _assert_values(self, values):
        """Send and receive each value; check equality and exact type."""
        for obj in values:
            with self.subTest(obj):
                interpreters.channel_send(self.cid, obj)
                received = interpreters.channel_recv(self.cid)

                self.assertEqual(received, obj)
                self.assertIs(type(received), type(obj))
                # XXX Check the following in the channel tests?
                #self.assertIsNot(received, obj)

    def test_singletons(self):
        for obj in [None]:
            with self.subTest(obj):
                interpreters.channel_send(self.cid, obj)
                received = interpreters.channel_recv(self.cid)

                # XXX What about between interpreters?
                self.assertIs(received, obj)

    def test_types(self):
        self._assert_values([b'spam', 9999, self.cid])

    def test_bytes(self):
        two_byte_values = (n.to_bytes(2, 'little', signed=True)
                           for n in range(-1, 258))
        self._assert_values(two_byte_values)

    def test_strs(self):
        self._assert_values(['hello world', '你好世界', ''])

    def test_int(self):
        self._assert_values([*range(-1, 258), sys.maxsize, -sys.maxsize - 1])

    def test_non_shareable_int(self):
        too_big = [sys.maxsize + 1, -sys.maxsize - 2, 2**1000]
        for value in too_big:
            with self.subTest(value):
                with self.assertRaises(OverflowError):
                    interpreters.channel_send(self.cid, value)
412
413
414##################################
415# interpreter tests
416
class ListAllTests(TestBase):
    """Tests for interpreters.list_all()."""

    def test_initial(self):
        main = interpreters.get_main()
        self.assertEqual(interpreters.list_all(), [main])

    def test_after_creating(self):
        main = interpreters.get_main()
        created = [interpreters.create() for _ in range(2)]
        self.assertEqual(interpreters.list_all(), [main, *created])

    def test_after_destroying(self):
        main = interpreters.get_main()
        first, second = [interpreters.create() for _ in range(2)]
        interpreters.destroy(first)
        self.assertEqual(interpreters.list_all(), [main, second])
438
439
class GetCurrentTests(TestBase):
    """Tests for interpreters.get_current()."""

    def test_main(self):
        main = interpreters.get_main()
        current = interpreters.get_current()
        self.assertEqual(current, main)
        self.assertIsInstance(current, interpreters.InterpreterID)

    def test_subinterpreter(self):
        main = interpreters.get_main()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            cur = _interpreters.get_current()
            print(cur)
            assert isinstance(cur, _interpreters.InterpreterID)
            """))
        current = int(out.strip())
        # Exactly two interpreters exist here: main and the new one.
        _, expected = interpreters.list_all()
        self.assertEqual(current, expected)
        self.assertNotEqual(current, main)
461
462
class GetMainTests(TestBase):
    """Tests for interpreters.get_main()."""

    def test_from_main(self):
        [expected] = interpreters.list_all()
        main = interpreters.get_main()
        self.assertEqual(main, expected)
        self.assertIsInstance(main, interpreters.InterpreterID)

    def test_from_subinterpreter(self):
        [expected] = interpreters.list_all()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            main = _interpreters.get_main()
            print(main)
            assert isinstance(main, _interpreters.InterpreterID)
            """))
        self.assertEqual(int(out.strip()), expected)
482
483
class IsRunningTests(TestBase):
    """Tests for interpreters.is_running()."""

    def test_main(self):
        main = interpreters.get_main()
        self.assertTrue(interpreters.is_running(main))

    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interpreters.is_running(interp))

        with _running(interp):
            self.assertTrue(interpreters.is_running(interp))
        self.assertFalse(interpreters.is_running(interp))

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        out = _run_output(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp}):
                print(True)
            else:
                print(False)
            """))
        self.assertEqual(out.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        self.assertRaises(RuntimeError, interpreters.is_running, interp)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.is_running, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(ValueError, interpreters.is_running, -1)
522
523
class InterpreterIDTests(TestBase):
    """Tests for the interpreters.InterpreterID type."""

    def test_with_int(self):
        id = interpreters.InterpreterID(10, force=True)

        self.assertEqual(int(id), 10)

    def test_coerce_id(self):
        # A str subclass is accepted only because it defines __index__;
        # a plain '10' is rejected (see test_bad_id).
        class Int(str):
            def __index__(self):
                return 10

        id = interpreters.InterpreterID(Int(), force=True)
        self.assertEqual(int(id), 10)

    def test_bad_id(self):
        self.assertRaises(TypeError, interpreters.InterpreterID, object())
        self.assertRaises(TypeError, interpreters.InterpreterID, 10.0)
        self.assertRaises(TypeError, interpreters.InterpreterID, '10')
        self.assertRaises(TypeError, interpreters.InterpreterID, b'10')
        self.assertRaises(ValueError, interpreters.InterpreterID, -1)
        self.assertRaises(OverflowError, interpreters.InterpreterID, 2**64)

    def test_does_not_exist(self):
        # Derive the ID from a newly created *interpreter* (not a channel,
        # whose IDs are unrelated).  The next ID is presumably unused.
        id = interpreters.create()
        with self.assertRaises(RuntimeError):
            interpreters.InterpreterID(int(id) + 1)  # unforced

    def test_str(self):
        id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(str(id), '10')

    def test_repr(self):
        id = interpreters.InterpreterID(10, force=True)
        self.assertEqual(repr(id), 'InterpreterID(10)')

    def test_equality(self):
        id1 = interpreters.create()
        id2 = interpreters.InterpreterID(int(id1))
        id3 = interpreters.create()

        self.assertTrue(id1 == id1)
        self.assertTrue(id1 == id2)
        self.assertTrue(id1 == int(id1))
        self.assertTrue(int(id1) == id1)
        self.assertTrue(id1 == float(int(id1)))
        self.assertTrue(float(int(id1)) == id1)
        self.assertFalse(id1 == float(int(id1)) + 0.1)
        self.assertFalse(id1 == str(int(id1)))
        self.assertFalse(id1 == 2**1000)
        self.assertFalse(id1 == float('inf'))
        self.assertFalse(id1 == 'spam')
        self.assertFalse(id1 == id3)

        self.assertFalse(id1 != id1)
        self.assertFalse(id1 != id2)
        self.assertTrue(id1 != id3)
581
582
class CreateTests(TestBase):
    """Tests for interpreters.create()."""

    def test_in_main(self):
        new_id = interpreters.create()
        self.assertIsInstance(new_id, interpreters.InterpreterID)

        self.assertIn(new_id, interpreters.list_all())

    @unittest.skip('enable this test when working on pystate.c')
    def test_unique_id(self):
        seen = set()
        for _ in range(100):
            new_id = interpreters.create()
            interpreters.destroy(new_id)
            seen.add(new_id)

        self.assertEqual(len(seen), 100)

    def test_in_thread(self):
        lock = threading.Lock()
        created = None

        def create_in_thread():
            nonlocal created
            created = interpreters.create()
            # Block until the main thread releases the lock.
            with lock:
                pass

        worker = threading.Thread(target=create_in_thread)
        with lock:
            worker.start()
        worker.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        outer = interpreters.create()
        out = _run_output(outer, dedent("""
            import _xxsubinterpreters as _interpreters
            id = _interpreters.create()
            print(id)
            assert isinstance(id, _interpreters.InterpreterID)
            """))
        inner = int(out.strip())

        self.assertEqual(set(interpreters.list_all()), {main, outer, inner})

    def test_in_threaded_subinterpreter(self):
        main, = interpreters.list_all()
        outer = interpreters.create()
        inner = None

        def create_from_subinterpreter():
            nonlocal inner
            out = _run_output(outer, dedent("""
                import _xxsubinterpreters as _interpreters
                id = _interpreters.create()
                print(id)
                """))
            inner = int(out.strip())

        worker = threading.Thread(target=create_from_subinterpreter)
        worker.start()
        worker.join()

        self.assertEqual(set(interpreters.list_all()), {main, outer, inner})

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters, then destroy all of them.
        created = [interpreters.create() for _ in range(3)]
        for interp_id in created:
            interpreters.destroy(interp_id)
        # Finally, create another.
        new_id = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {new_id})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters, then destroy the first and the last.
        first, middle, last = [interpreters.create() for _ in range(3)]
        interpreters.destroy(first)
        interpreters.destroy(last)
        # Finally, create another.
        new_id = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         before | {new_id, middle})
674
675
class DestroyTests(TestBase):
    """Tests for interpreters.destroy()."""

    def test_one(self):
        id1, id2, id3 = [interpreters.create() for _ in range(3)]
        self.assertIn(id2, interpreters.list_all())
        interpreters.destroy(id2)
        remaining = interpreters.list_all()
        self.assertNotIn(id2, remaining)
        self.assertIn(id1, remaining)
        self.assertIn(id3, remaining)

    def test_all(self):
        before = set(interpreters.list_all())
        created = {interpreters.create() for _ in range(3)}
        self.assertEqual(set(interpreters.list_all()), before | created)
        for interp_id in created:
            interpreters.destroy(interp_id)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            interpreters.destroy(main)

        def destroy_main_from_thread():
            with self.assertRaises(RuntimeError):
                interpreters.destroy(main)

        worker = threading.Thread(target=destroy_main_from_thread)
        worker.start()
        worker.join()

    def test_already_destroyed(self):
        interp_id = interpreters.create()
        interpreters.destroy(interp_id)
        self.assertRaises(RuntimeError, interpreters.destroy, interp_id)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.destroy, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(ValueError, interpreters.destroy, -1)

    def test_from_current(self):
        main, = interpreters.list_all()
        interp_id = interpreters.create()
        # The script swallows the RuntimeError from destroying itself; the
        # interpreter must still exist afterwards.
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            try:
                _interpreters.destroy({interp_id})
            except RuntimeError:
                pass
            """)

        interpreters.run_string(interp_id, script)
        self.assertEqual(set(interpreters.list_all()), {main, interp_id})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({id2})
            """))

        self.assertEqual(set(interpreters.list_all()), {main, id1})

    def test_from_other_thread(self):
        interp_id = interpreters.create()
        worker = threading.Thread(target=interpreters.destroy,
                                  args=(interp_id,))
        worker.start()
        worker.join()

    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        running_msg = f"Interp {interp} should be running before destruction."
        destroy_msg = (f"Should not be able to destroy interp {interp} "
                       "while it's still running.")
        with _running(interp):
            self.assertTrue(interpreters.is_running(interp), msg=running_msg)

            with self.assertRaises(RuntimeError, msg=destroy_msg):
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))
772
773
774class RunStringTests(TestBase):
775
    # Script template: the first placeholder is a file name and the second
    # is the text written into that file.
    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    # Default file name, presumably meant to be used with SCRIPT.
    FILENAME = 'spam'
781
782    def setUp(self):
783        super().setUp()
784        self.id = interpreters.create()
785        self._fs = None
786
787    def tearDown(self):
788        if self._fs is not None:
789            self._fs.close()
790        super().tearDown()
791
    @property
    def fs(self):
        """Filesystem fixture, created on first access; closed in tearDown()."""
        # NOTE(review): FSFixture must be defined elsewhere in this module;
        # verify that it still exists.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs
797
798    def test_success(self):
799        script, file = _captured_script('print("it worked!", end="")')
800        with file:
801            interpreters.run_string(self.id, script)
802            out = file.read()
803
804        self.assertEqual(out, 'it worked!')
805
806    def test_in_thread(self):
807        script, file = _captured_script('print("it worked!", end="")')
808        with file:
809            def f():
810                interpreters.run_string(self.id, script)
811
812            t = threading.Thread(target=f)
813            t.start()
814            t.join()
815            out = file.read()
816
817        self.assertEqual(out, 'it worked!')
818
819    def test_create_thread(self):
820        script, file = _captured_script("""
821            import threading
822            def f():
823                print('it worked!', end='')
824
825            t = threading.Thread(target=f)
826            t.start()
827            t.join()
828            """)
829        with file:
830            interpreters.run_string(self.id, script)
831            out = file.read()
832
833        self.assertEqual(out, 'it worked!')
834
835    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
836    def test_fork(self):
837        import tempfile
838        with tempfile.NamedTemporaryFile('w+') as file:
839            file.write('')
840            file.flush()
841
842            expected = 'spam spam spam spam spam'
843            script = dedent(f"""
844                import os
845                try:
846                    os.fork()
847                except RuntimeError:
848                    with open('{file.name}', 'w') as out:
849                        out.write('{expected}')
850                """)
851            interpreters.run_string(self.id, script)
852
853            file.seek(0)
854            content = file.read()
855            self.assertEqual(content, expected)
856
857    def test_already_running(self):
858        with _running(self.id):
859            with self.assertRaises(RuntimeError):
860                interpreters.run_string(self.id, 'print("spam")')
861
862    def test_does_not_exist(self):
863        id = 0
864        while id in interpreters.list_all():
865            id += 1
866        with self.assertRaises(RuntimeError):
867            interpreters.run_string(id, 'print("spam")')
868
869    def test_error_id(self):
870        with self.assertRaises(ValueError):
871            interpreters.run_string(-1, 'print("spam")')
872
873    def test_bad_id(self):
874        with self.assertRaises(TypeError):
875            interpreters.run_string('spam', 'print("spam")')
876
877    def test_bad_script(self):
878        with self.assertRaises(TypeError):
879            interpreters.run_string(self.id, 10)
880
881    def test_bytes_for_script(self):
882        with self.assertRaises(TypeError):
883            interpreters.run_string(self.id, b'print("spam")')
884
885    @contextlib.contextmanager
886    def assert_run_failed(self, exctype, msg=None):
887        with self.assertRaises(interpreters.RunFailedError) as caught:
888            yield
889        if msg is None:
890            self.assertEqual(str(caught.exception).split(':')[0],
891                             str(exctype))
892        else:
893            self.assertEqual(str(caught.exception),
894                             "{}: {}".format(exctype, msg))
895
896    def test_invalid_syntax(self):
897        with self.assert_run_failed(SyntaxError):
898            # missing close paren
899            interpreters.run_string(self.id, 'print("spam"')
900
901    def test_failure(self):
902        with self.assert_run_failed(Exception, 'spam'):
903            interpreters.run_string(self.id, 'raise Exception("spam")')
904
905    def test_SystemExit(self):
906        with self.assert_run_failed(SystemExit, '42'):
907            interpreters.run_string(self.id, 'raise SystemExit(42)')
908
909    def test_sys_exit(self):
910        with self.assert_run_failed(SystemExit):
911            interpreters.run_string(self.id, dedent("""
912                import sys
913                sys.exit()
914                """))
915
916        with self.assert_run_failed(SystemExit, '42'):
917            interpreters.run_string(self.id, dedent("""
918                import sys
919                sys.exit(42)
920                """))
921
922    def test_with_shared(self):
923        r, w = os.pipe()
924
925        shared = {
926                'spam': b'ham',
927                'eggs': b'-1',
928                'cheddar': None,
929                }
930        script = dedent(f"""
931            eggs = int(eggs)
932            spam = 42
933            result = spam + eggs
934
935            ns = dict(vars())
936            del ns['__builtins__']
937            import pickle
938            with open({w}, 'wb') as chan:
939                pickle.dump(ns, chan)
940            """)
941        interpreters.run_string(self.id, script, shared)
942        with open(r, 'rb') as chan:
943            ns = pickle.load(chan)
944
945        self.assertEqual(ns['spam'], 42)
946        self.assertEqual(ns['eggs'], -1)
947        self.assertEqual(ns['result'], 41)
948        self.assertIsNone(ns['cheddar'])
949
950    def test_shared_overwrites(self):
951        interpreters.run_string(self.id, dedent("""
952            spam = 'eggs'
953            ns1 = dict(vars())
954            del ns1['__builtins__']
955            """))
956
957        shared = {'spam': b'ham'}
958        script = dedent(f"""
959            ns2 = dict(vars())
960            del ns2['__builtins__']
961        """)
962        interpreters.run_string(self.id, script, shared)
963
964        r, w = os.pipe()
965        script = dedent(f"""
966            ns = dict(vars())
967            del ns['__builtins__']
968            import pickle
969            with open({w}, 'wb') as chan:
970                pickle.dump(ns, chan)
971            """)
972        interpreters.run_string(self.id, script)
973        with open(r, 'rb') as chan:
974            ns = pickle.load(chan)
975
976        self.assertEqual(ns['ns1']['spam'], 'eggs')
977        self.assertEqual(ns['ns2']['spam'], b'ham')
978        self.assertEqual(ns['spam'], b'ham')
979
980    def test_shared_overwrites_default_vars(self):
981        r, w = os.pipe()
982
983        shared = {'__name__': b'not __main__'}
984        script = dedent(f"""
985            spam = 42
986
987            ns = dict(vars())
988            del ns['__builtins__']
989            import pickle
990            with open({w}, 'wb') as chan:
991                pickle.dump(ns, chan)
992            """)
993        interpreters.run_string(self.id, script, shared)
994        with open(r, 'rb') as chan:
995            ns = pickle.load(chan)
996
997        self.assertEqual(ns['__name__'], b'not __main__')
998
999    def test_main_reused(self):
1000        r, w = os.pipe()
1001        interpreters.run_string(self.id, dedent(f"""
1002            spam = True
1003
1004            ns = dict(vars())
1005            del ns['__builtins__']
1006            import pickle
1007            with open({w}, 'wb') as chan:
1008                pickle.dump(ns, chan)
1009            del ns, pickle, chan
1010            """))
1011        with open(r, 'rb') as chan:
1012            ns1 = pickle.load(chan)
1013
1014        r, w = os.pipe()
1015        interpreters.run_string(self.id, dedent(f"""
1016            eggs = False
1017
1018            ns = dict(vars())
1019            del ns['__builtins__']
1020            import pickle
1021            with open({w}, 'wb') as chan:
1022                pickle.dump(ns, chan)
1023            """))
1024        with open(r, 'rb') as chan:
1025            ns2 = pickle.load(chan)
1026
1027        self.assertIn('spam', ns1)
1028        self.assertNotIn('eggs', ns1)
1029        self.assertIn('eggs', ns2)
1030        self.assertIn('spam', ns2)
1031
1032    def test_execution_namespace_is_main(self):
1033        r, w = os.pipe()
1034
1035        script = dedent(f"""
1036            spam = 42
1037
1038            ns = dict(vars())
1039            ns['__builtins__'] = str(ns['__builtins__'])
1040            import pickle
1041            with open({w}, 'wb') as chan:
1042                pickle.dump(ns, chan)
1043            """)
1044        interpreters.run_string(self.id, script)
1045        with open(r, 'rb') as chan:
1046            ns = pickle.load(chan)
1047
1048        ns.pop('__builtins__')
1049        ns.pop('__loader__')
1050        self.assertEqual(ns, {
1051            '__name__': '__main__',
1052            '__annotations__': {},
1053            '__doc__': None,
1054            '__package__': None,
1055            '__spec__': None,
1056            'spam': 42,
1057            })
1058
1059    # XXX Fix this test!
1060    @unittest.skip('blocking forever')
1061    def test_still_running_at_exit(self):
1062        script = dedent(f"""
1063        from textwrap import dedent
1064        import threading
1065        import _xxsubinterpreters as _interpreters
1066        id = _interpreters.create()
1067        def f():
1068            _interpreters.run_string(id, dedent('''
1069                import time
1070                # Give plenty of time for the main interpreter to finish.
1071                time.sleep(1_000_000)
1072                '''))
1073
1074        t = threading.Thread(target=f)
1075        t.start()
1076        """)
1077        with support.temp_dir() as dirname:
1078            filename = script_helper.make_script(dirname, 'interp', script)
1079            with script_helper.spawn_python(filename) as proc:
1080                retcode = proc.wait()
1081
1082        self.assertEqual(retcode, 0)
1083
1084
1085##################################
1086# channel tests
1087
class ChannelIDTests(TestBase):
    """Tests for the ChannelID objects returned by _channel_id()."""

    def test_default_kwargs(self):
        channel = interpreters._channel_id(10, force=True)

        self.assertEqual(int(channel), 10)
        self.assertEqual(channel.end, 'both')

    def test_with_kwargs(self):
        # (keyword arguments, expected .end)
        cases = [
            (dict(send=True), 'send'),
            (dict(send=True, recv=False), 'send'),
            (dict(recv=True), 'recv'),
            (dict(recv=True, send=False), 'recv'),
            (dict(send=True, recv=True), 'both'),
            ]
        for kwargs, expected_end in cases:
            channel = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(channel.end, expected_end)

    def test_coerce_id(self):
        # Any object implementing __index__ is accepted as the ID.
        class IndexableStr(str):
            def __index__(self):
                return 10

        channel = interpreters._channel_id(IndexableStr(), force=True)
        self.assertEqual(int(channel), 10)

    def test_bad_id(self):
        # (expected exception, rejected ID)
        rejected = [
            (TypeError, object()),
            (TypeError, 10.0),
            (TypeError, '10'),
            (TypeError, b'10'),
            (ValueError, -1),
            (OverflowError, 2**64),
            ]
        for exc_type, bad_value in rejected:
            self.assertRaises(exc_type, interpreters._channel_id, bad_value)

    def test_bad_kwargs(self):
        with self.assertRaises(ValueError):
            interpreters._channel_id(10, send=False, recv=False)

    def test_does_not_exist(self):
        # Without force=True the ID must refer to an existing channel.
        existing = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(int(existing) + 1)  # unforced

    def test_str(self):
        channel = interpreters._channel_id(10, force=True)
        self.assertEqual(str(channel), '10')

    def test_repr(self):
        # (keyword arguments, expected repr)
        cases = [
            (dict(), 'ChannelID(10)'),
            (dict(send=True), 'ChannelID(10, send=True)'),
            (dict(recv=True), 'ChannelID(10, recv=True)'),
            (dict(send=True, recv=True), 'ChannelID(10)'),
            ]
        for kwargs, expected_repr in cases:
            channel = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(repr(channel), expected_repr)

    def test_equality(self):
        first = interpreters.channel_create()
        alias = interpreters._channel_id(int(first))
        unrelated = interpreters.channel_create()
        number = int(first)

        self.assertTrue(first == first)
        self.assertTrue(first == alias)
        self.assertTrue(first == number)
        self.assertTrue(number == first)
        self.assertTrue(first == float(number))
        self.assertTrue(float(number) == first)
        self.assertFalse(first == float(number) + 0.1)
        self.assertFalse(first == str(number))
        self.assertFalse(first == 2**1000)
        self.assertFalse(first == float('inf'))
        self.assertFalse(first == 'spam')
        self.assertFalse(first == unrelated)

        self.assertFalse(first != first)
        self.assertFalse(first != alias)
        self.assertTrue(first != unrelated)
1175
1176
class ChannelTests(TestBase):
    """Channel creation, send/recv across interpreters/threads, and close()."""

    def test_create_cid(self):
        cid = interpreters.channel_create()
        self.assertIsInstance(cid, interpreters.ChannelID)

    def test_sequential_ids(self):
        # New channels get consecutive IDs.
        before = interpreters.channel_list_all()
        id1 = interpreters.channel_create()
        id2 = interpreters.channel_create()
        id3 = interpreters.channel_create()
        after = interpreters.channel_list_all()

        self.assertEqual(id2, int(id1) + 1)
        self.assertEqual(id3, int(id2) + 1)
        self.assertEqual(set(after) - set(before), {id1, id2, id3})

    def test_ids_global(self):
        # Channel IDs keep counting across interpreters rather than
        # restarting in each one.
        id1 = interpreters.create()
        out = _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(cid)
            """))
        cid1 = int(out.strip())

        id2 = interpreters.create()
        out = _run_output(id2, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(cid)
            """))
        cid2 = int(out.strip())

        # Both IDs were parsed from printed output, so they are plain ints.
        self.assertEqual(cid2, cid1 + 1)

    ####################

    def test_send_recv_main(self):
        cid = interpreters.channel_create()
        orig = b'spam'
        interpreters.channel_send(cid, orig)
        obj = interpreters.channel_recv(cid)

        # recv returns an equal copy, not the very object that was sent.
        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_same_interpreter(self):
        # A failing assert inside the subinterpreter makes run_string()
        # raise, which fails this test.
        id1 = interpreters.create()
        _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            orig = b'spam'
            _interpreters.channel_send(cid, orig)
            obj = _interpreters.channel_recv(cid)
            assert obj is not orig
            assert obj == orig
            """))

    def test_send_recv_different_interpreters(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        _run_output(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_recv_different_threads(self):
        cid = interpreters.channel_create()

        def f():
            # Poll until the main thread's item arrives, then echo it back.
            while True:
                try:
                    obj = interpreters.channel_recv(cid)
                    break
                except interpreters.ChannelEmptyError:
                    time.sleep(0.1)
            interpreters.channel_send(cid, obj)
        t = threading.Thread(target=f)
        t.start()

        interpreters.channel_send(cid, b'spam')
        t.join()
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_recv_different_interpreters_and_threads(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        out = None

        def f():
            # The subinterpreter polls for b'spam', then replies b'eggs'.
            nonlocal out
            out = _run_output(id1, dedent(f"""
                import time
                import _xxsubinterpreters as _interpreters
                while True:
                    try:
                        obj = _interpreters.channel_recv({cid})
                        break
                    except _interpreters.ChannelEmptyError:
                        time.sleep(0.1)
                assert(obj == b'spam')
                _interpreters.channel_send({cid}, b'eggs')
                """))
        t = threading.Thread(target=f)
        t.start()

        interpreters.channel_send(cid, b'spam')
        t.join()
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'eggs')

    def test_send_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_send(10, b'spam')

    def test_recv_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_recv(10)

    def test_recv_empty(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelEmptyError):
            interpreters.channel_recv(cid)

    def test_run_string_arg_unresolved(self):
        # A send-end ChannelID passed through ``shared`` keeps its end.
        cid = interpreters.channel_create()
        interp = interpreters.create()

        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(cid.end)
            _interpreters.channel_send(cid, b'spam')
            """),
            dict(cid=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')

    # XXX For now there is no high-level channel into which the
    # sent channel ID can be converted...
    # Note: this test caused crashes on some buildbots (bpo-33615).
    @unittest.skip('disabled until high-level channels exist')
    def test_run_string_arg_resolved(self):
        cid = interpreters.channel_create()
        cid = interpreters._channel_id(cid, _resolve=True)
        interp = interpreters.create()

        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(chan.id.end)
            _interpreters.channel_send(chan.id, b'spam')
            """),
            dict(chan=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')

    # close

    def test_close_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_multiple_users(self):
        # After main closes the channel, both the sending and the
        # receiving interpreter must see ChannelClosedError on their end.
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        interpreters.run_string(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_recv({cid})
            """))
        interpreters.channel_close(cid)
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id1, dedent(f"""
                _interpreters.channel_send({cid}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))
        # id2 was the receiver, so check the recv end (not send again).
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id2, dedent(f"""
                _interpreters.channel_recv({cid})
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))

    def test_close_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_empty(self):
        tests = [
            (False, False),
            (True, False),
            (False, True),
            (True, True),
            ]
        for send, recv in tests:
            with self.subTest((send, recv)):
                cid = interpreters.channel_create()
                interpreters.channel_send(cid, b'spam')
                interpreters.channel_recv(cid)
                interpreters.channel_close(cid, send=send, recv=recv)

                with self.assertRaises(interpreters.ChannelClosedError):
                    interpreters.channel_send(cid, b'eggs')
                with self.assertRaises(interpreters.ChannelClosedError):
                    interpreters.channel_recv(cid)

    def test_close_defaults_with_unused_items(self):
        # A refused close leaves the channel fully usable.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')

    def test_close_recv_with_unused_items_unforced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid, recv=True)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid, recv=True)

    def test_close_send_with_unused_items_unforced(self):
        # Closing only the send end still lets queued items be received.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_both_with_unused_items_unforced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')

        with self.assertRaises(interpreters.ChannelNotEmptyError):
            interpreters.channel_close(cid, recv=True, send=True)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'eggs')
        interpreters.channel_recv(cid)
        interpreters.channel_recv(cid)
        # Now drained, closing both ends (the case under test) succeeds.
        interpreters.channel_close(cid, recv=True, send=True)

    def test_close_recv_with_unused_items_forced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, recv=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_send_with_unused_items_forced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_both_with_unused_items_forced(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid, send=True, recv=True, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_by_unassociated_interp(self):
        # An interpreter that never used the channel can force-close it.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({cid}, force=True)
            """))
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid, force=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
1524
1525
class ChannelReleaseTests(TestBase):

    # Tests for channel_release(): an interpreter giving up its end(s) of a
    # channel, and when that leaves the channel closed for everyone.
    # XXX Add more test coverage a la the tests for close().

    """
    - main / interp / other
    - run in: current thread / new thread / other thread / different threads
    - end / opposite
    - force / no force
    - used / not used  (associated / not associated)
    - empty / emptied / never emptied / partly emptied
    - closed / not closed
    - released / not released
    - creator (interp) / other
    - associated interpreter not running
    - associated interpreter destroyed
    """

    # The two string literals below are planning notes for the test matrix;
    # they are no-op expressions, not docstrings.
    """
    use
    pre-release
    release
    after
    check
    """

    """
    release in:         main, interp1
    creator:            same, other (incl. interp2)

    use:                None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    pre-release:        None,send,recv,both in None,same,other(incl. interp2),same+other(incl. interp2),all
    pre-release forced: None,send,recv,both in None,same,other(incl. interp2),same+other(incl. interp2),all

    release:            same
    release forced:     same

    use after:          None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    release after:      None,send,recv,send/recv in None,same,other(incl. interp2),same+other(incl. interp2),all
    check released:     send/recv for same/other(incl. interp2)
    check closed:       send/recv for same/other(incl. interp2)
    """

    def test_single_user(self):
        # When the only user releases both ends, both ends report closed.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_multiple_users(self):
        # id1 sends, id2 receives and releases, then id1 releases too.
        # Only the received value is checked here.
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({cid}, b'spam')
            """))
        out = _run_output(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({cid})
            _interpreters.channel_release({cid})
            print(repr(obj))
            """))
        # id1's __main__ still holds _interpreters from the first run above.
        interpreters.run_string(id1, dedent(f"""
            _interpreters.channel_release({cid})
            """))

        self.assertEqual(out.strip(), "b'spam'")

    def test_no_kwargs(self):
        # Releasing with no keyword arguments behaves like releasing both
        # ends for a single user (both ends report closed afterwards).
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_multiple_times(self):
        # A second release of an already-released channel is an error.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_release(cid, send=True, recv=True)

    def test_with_unused_items(self):
        # Unlike an unforced close, release succeeds with items still queued;
        # those items are then unreachable.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_release(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_release(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_by_unassociated_interp(self):
        # A release by an interpreter that never used the channel does not
        # affect main: main can still receive, and only main's own release
        # closes the channel.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_release({cid})
            """))
        obj = interpreters.channel_recv(cid)
        interpreters.channel_release(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        self.assertEqual(obj, b'spam')

    def test_close_if_unassociated(self):
        # The only interpreter that used the channel sends and then releases
        # it; this test expects main to see the channel as closed.
        # XXX Something's not right with this test...
        cid = interpreters.channel_create()
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_send({cid}, b'spam')
            _interpreters.channel_release({cid})
            """))

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_partially(self):
        # Releasing only the send end still lets queued items be received.
        # XXX Is partial close too weird/confusing?
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, None)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_release(cid, send=True)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_used_multiple_times_by_single_user(self):
        # Releasing both ends closes the channel even with items still queued.
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_release(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
1689
1690
class ChannelCloseFixture(namedtuple('ChannelCloseFixture',
                                     'end interp other extra creator')):
    """One channel-close scenario: the interpreters involved plus state.

    Fields:
      end     -- the channel end under test, 'send' or 'recv'.
      interp  -- the primary interpreter (also addressed as 'same').
      other   -- a second interpreter.
      extra   -- a third interpreter.
      creator -- which interpreter creates the channel; a falsy value
                 becomes 'same'.

    Unless QUICK is set, interp/other/extra are normalized with
    Interpreter.from_raw().  The channel itself is created lazily on first
    access to ``cid``.
    """

    # Set this to True to avoid creating interpreters, e.g. when
    # scanning through test permutations without running them.
    QUICK = False

    def __new__(cls, end, interp, other, extra, creator):
        assert end in ('send', 'recv')
        if cls.QUICK:
            known = {}
        else:
            interp = Interpreter.from_raw(interp)
            other = Interpreter.from_raw(other)
            extra = Interpreter.from_raw(extra)
            known = {
                interp.name: interp,
                other.name: other,
                extra.name: extra,
                }
        if not creator:
            creator = 'same'
        self = super().__new__(cls, end, interp, other, extra, creator)
        self._prepped = set()          # ids of interpreters already prepped
        self._state = ChannelState()   # current expected channel state
        self._known = known            # interpreter name -> Interpreter
        return self

    @property
    def state(self):
        return self._state

    @property
    def cid(self):
        """The channel under test, created (once) by ``creator``."""
        try:
            return self._cid
        except AttributeError:
            creator = self._get_interpreter(self.creator)
            self._cid = self._new_channel(creator)
            return self._cid

    def get_interpreter(self, interp):
        """Resolve *interp* (a role or a name) and make sure it is prepped."""
        interp = self._get_interpreter(interp)
        self._prep_interpreter(interp)
        return interp

    def expect_closed_error(self, end=None):
        """Return True if using *end* (default: self.end) should fail as closed.

        A recv is still expected to work when only the send side is closed.
        """
        if end is None:
            end = self.end
        if end == 'recv' and self.state.closed == 'send':
            return False
        return bool(self.state.closed)

    def prep_interpreter(self, interp):
        self._prep_interpreter(interp)

    def record_action(self, action, result):
        # Only the resulting state is kept; *action* is currently unused.
        self._state = result

    def clean_up(self):
        clean_up_interpreters()
        clean_up_channels()

    # internal methods

    def _new_channel(self, creator):
        """Create a channel in *creator* and return its ID.

        Caching is left to the ``cid`` property; this method does not modify
        the fixture.
        """
        if creator.name == 'main':
            return interpreters.channel_create()
        else:
            # A temporary channel carries the new channel's ID back to us.
            ch = interpreters.channel_create()
            run_interp(creator.id, f"""
                import _xxsubinterpreters
                cid = _xxsubinterpreters.channel_create()
                # We purposefully send back an int to avoid tying the
                # channel to the other interpreter.
                _xxsubinterpreters.channel_send({ch}, int(cid))
                del _xxsubinterpreters
                """)
            return interpreters.channel_recv(ch)

    def _get_interpreter(self, interp):
        """Map a role ('same'/'interp', 'other', 'extra') or a name to an
        interpreter, creating and remembering an Interpreter for unknown
        names."""
        if interp in ('same', 'interp'):
            return self.interp
        elif interp == 'other':
            return self.other
        elif interp == 'extra':
            return self.extra
        else:
            name = interp
            try:
                interp = self._known[name]
            except KeyError:
                interp = self._known[name] = Interpreter(name)
            return interp

    def _prep_interpreter(self, interp):
        """Bind helper names and ``cid`` in *interp*'s __main__, once.

        The interpreter is recorded as prepped only after the setup script
        succeeds, so a failed attempt is retried on the next call.
        """
        if interp.id in self._prepped:
            return
        if interp.name == 'main':
            self._prepped.add(interp.id)
            return
        run_interp(interp.id, f"""
            import _xxsubinterpreters as interpreters
            import test.test__xxsubinterpreters as helpers
            ChannelState = helpers.ChannelState
            try:
                cid
            except NameError:
                cid = interpreters._channel_id({self.cid})
            """)
        self._prepped.add(interp.id)
1802
1803
@unittest.skip('these tests take several hours to run')
class ExhaustiveChannelTests(TestBase):

    """Exhaustively check closing a channel across many permutations.

    - main / interp / other
    - run in: current thread / new thread / other thread / different threads
    - end / opposite
    - force / no force
    - used / not used  (associated / not associated)
    - empty / emptied / never emptied / partly emptied
    - closed / not closed
    - released / not released
    - creator (interp) / other
    - associated interpreter not running
    - associated interpreter destroyed

    - close after unbound
    """

    """
    use
    pre-close
    close
    after
    check
    """

    """
    close in:         main, interp1
    creator:          same, other, extra

    use:              None,send,recv,send/recv in None,same,other,same+other,all
    pre-close:        None,send,recv in None,same,other,same+other,all
    pre-close forced: None,send,recv in None,same,other,same+other,all

    close:            same
    close forced:     same

    use after:        None,send,recv,send/recv in None,same,other,extra,same+other,all
    close after:      None,send,recv,send/recv in None,same,other,extra,same+other,all
    check closed:     send/recv for same/other(incl. interp2)
    """

    # The (interp1, interp2) role pairs that use/close actions are spread over.
    _INTERP_PAIRS = (
        ('same', 'other'),
        ('other', 'extra'),
        )

    def iter_action_sets(self):
        """Yield each list of ChannelActions to run before the final close.

        Covers a never-used channel, channels that are only pre-closed
        (optionally followed by a post-close use), and used channels
        (optionally followed by a pre-close and then a post-close use).
        """
        # - used / not used  (associated / not associated)
        # - empty / emptied / never emptied / partly emptied
        # - closed / not closed
        # - released / not released

        # never used
        yield []

        # only pre-closed (and possibly used after)
        yield from self._iter_close_and_post_action_sets([])

        # used
        for interps in self._INTERP_PAIRS:
            for useactions in self._iter_use_action_sets(*interps):
                yield useactions
                yield from self._iter_close_and_post_action_sets(useactions)

    def _iter_close_and_post_action_sets(self, prefix):
        """Yield *prefix* + each close set, then each of those + each post-close set."""
        for interps in self._INTERP_PAIRS:
            for closeactions in self._iter_close_action_sets(*interps):
                actions = prefix + closeactions
                yield actions
                for postactions in self._iter_post_close_action_sets():
                    yield actions + postactions

    def _iter_use_action_sets(self, interp1, interp2):
        """Yield lists of 'use' actions spread across *interp1* and *interp2*."""
        interps = (interp1, interp2)

        # only recv end used
        yield [
            ChannelAction('use', 'recv', interp1),
            ]
        yield [
            ChannelAction('use', 'recv', interp2),
            ]
        yield [
            ChannelAction('use', 'recv', interp1),
            ChannelAction('use', 'recv', interp2),
            ]

        # never emptied
        yield [
            ChannelAction('use', 'send', interp1),
            ]
        yield [
            ChannelAction('use', 'send', interp2),
            ]
        yield [
            ChannelAction('use', 'send', interp1),
            ChannelAction('use', 'send', interp2),
            ]

        # partially emptied: two sends, one recv
        for send1, send2, recv1 in itertools.product(interps, repeat=3):
            yield [
                ChannelAction('use', 'send', send1),
                ChannelAction('use', 'send', send2),
                ChannelAction('use', 'recv', recv1),
                ]

        # fully emptied: two sends, two recvs
        for send1, send2, recv1, recv2 in itertools.product(interps, repeat=4):
            yield [
                ChannelAction('use', 'send', send1),
                ChannelAction('use', 'send', send2),
                ChannelAction('use', 'recv', recv1),
                ChannelAction('use', 'recv', recv2),
                ]

    def _iter_close_action_sets(self, interp1, interp2):
        """Yield lists of (force-)close actions run in *interp1*/*interp2*."""
        ends = ('recv', 'send')
        interps = (interp1, interp2)

        # a single close of one end, forced first, then unforced
        for op, interp, end in itertools.product(
                ('force-close', 'close'), interps, ends):
            yield [
                ChannelAction(op, end, interp),
                ]

        # both ends closed, each possibly forced and from either interpreter
        ops = ('close', 'force-close')
        for recvop, sendop, recv, send in itertools.product(
                ops, ops, interps, interps):
            yield [
                ChannelAction(recvop, 'recv', recv),
                ChannelAction(sendop, 'send', send),
                ]

    def _iter_post_close_action_sets(self):
        """Yield single 'use' actions (recv, then send) for each interpreter role."""
        for interp in ('same', 'extra', 'other'):
            yield [
                ChannelAction('use', 'recv', interp),
                ]
            yield [
                ChannelAction('use', 'send', interp),
                ]

    def run_actions(self, fix, actions):
        """Run each of *actions*, in order, against the fixture *fix*."""
        for action in actions:
            self.run_action(fix, action)

    def run_action(self, fix, action, *, hideclosed=True):
        """Run *action* against *fix*'s channel and record the resulting state.

        The action runs in the interpreter it names, resolved against the
        fixture's interp/other/extra.  The module-level ``run_action()``
        helper does the work.  From a subinterpreter, the resulting state is
        sent back over a temporary channel as a one-byte pending count
        (so it must be < 256) and a closed flag (non-empty bytes = closed).
        """
        end = action.resolve_end(fix.end)
        interp = action.resolve_interp(fix.interp, fix.other, fix.extra)
        fix.prep_interpreter(interp)
        if interp.name == 'main':
            # Bare name: the module-level helper, not this method.
            result = run_action(
                fix.cid,
                action.action,
                end,
                fix.state,
                hideclosed=hideclosed,
                )
            fix.record_action(action, result)
        else:
            _cid = interpreters.channel_create()
            run_interp(interp.id, f"""
                result = helpers.run_action(
                    {fix.cid},
                    {repr(action.action)},
                    {repr(end)},
                    {repr(fix.state)},
                    hideclosed={hideclosed},
                    )
                interpreters.channel_send({_cid}, result.pending.to_bytes(1, 'little'))
                interpreters.channel_send({_cid}, b'X' if result.closed else b'')
                """)
            result = ChannelState(
                pending=int.from_bytes(interpreters.channel_recv(_cid), 'little'),
                closed=bool(interpreters.channel_recv(_cid)),
                )
            fix.record_action(action, result)

    def iter_fixtures(self):
        """Yield a ChannelCloseFixture for each interpreter layout, creator and end."""
        # XXX threads?
        # (Named so as not to shadow the ``interpreters`` module.)
        interp_sets = [
            ('main', 'interp', 'extra'),
            ('interp', 'main', 'extra'),
            ('interp1', 'interp2', 'extra'),
            ('interp1', 'interp2', 'main'),
        ]
        for interp, other, extra in interp_sets:
            # NOTE(review): 'creator' is not one of the fixture's roles, so it
            # is presumably resolved as an ad-hoc interpreter name; the notes
            # above list 'extra' instead -- confirm which is intended.
            for creator in ('same', 'other', 'creator'):
                for end in ('send', 'recv'):
                    yield ChannelCloseFixture(end, interp, other, extra, creator)

    def _close(self, fix, *, force):
        """Close *fix*'s end of the channel in 'same', expecting an error if already closed."""
        op = 'force-close' if force else 'close'
        close = ChannelAction(op, fix.end, 'same')
        if not fix.expect_closed_error():
            self.run_action(fix, close, hideclosed=False)
        else:
            with self.assertRaises(interpreters.ChannelClosedError):
                self.run_action(fix, close, hideclosed=False)

    def _assert_closed_in_interp(self, fix, interp=None):
        """Assert every channel operation on *fix*'s channel raises a closed error.

        With *interp* None or the main interpreter the checks run here;
        otherwise each one runs inside *interp*.
        """
        if interp is None or interp.name == 'main':
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_recv(fix.cid)
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_send(fix.cid, b'spam')
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_close(fix.cid)
            with self.assertRaises(interpreters.ChannelClosedError):
                interpreters.channel_close(fix.cid, force=True)
        else:
            # One run per operation so each check is independent.
            for stmt in (
                    'interpreters.channel_recv(cid)',
                    "interpreters.channel_send(cid, b'spam')",
                    'interpreters.channel_close(cid)',
                    'interpreters.channel_close(cid, force=True)',
                    ):
                run_interp(interp.id, f"""
                    with helpers.expect_channel_closed():
                        {stmt}
                    """)

    def _assert_closed(self, fix):
        """Drain any pending items, then check the channel is closed everywhere."""
        self.assertTrue(fix.state.closed)

        for _ in range(fix.state.pending):
            interpreters.channel_recv(fix.cid)
        self._assert_closed_in_interp(fix)

        for interp in ('same', 'other'):
            interp = fix.get_interpreter(interp)
            if interp.name == 'main':
                continue
            self._assert_closed_in_interp(fix, interp)

        interp = fix.get_interpreter('fresh')
        self._assert_closed_in_interp(fix, interp)

    def _iter_close_tests(self, verbose=False, *, limit=1000):
        """Yield ``(i, fixture, actions)`` for each close-test permutation.

        Progress goes to stdout as tests are generated: one dot per test,
        or a description of each test when *verbose* is true.  Generation
        stops after *limit* tests (default 1000); ``limit=None`` means no cap.
        """
        i = 0
        for actions in self.iter_action_sets():
            print()
            for fix in self.iter_fixtures():
                i += 1
                if limit is not None and i > limit:
                    return
                if verbose:
                    if (i - 1) % 6 == 0:
                        print()
                    print(i, fix, '({} actions)'.format(len(actions)))
                else:
                    if (i - 1) % 6 == 0:
                        print(' ', end='')
                    print('.', end=''); sys.stdout.flush()
                yield i, fix, actions
            if verbose:
                print('---')
        print()

    # This is useful for scanning through the possible tests.
    def _skim_close_tests(self):
        """Generate every permutation in quick mode without running any.

        ``ChannelCloseFixture.QUICK`` is restored afterwards so later
        fixtures are not silently left in quick mode.
        """
        previous = getattr(ChannelCloseFixture, 'QUICK', False)
        ChannelCloseFixture.QUICK = True
        try:
            for _ in self._iter_close_tests():
                pass
        finally:
            ChannelCloseFixture.QUICK = previous

    def _run_close_tests(self, *, force):
        """Run every close permutation, closing with or without *force*."""
        for i, fix, actions in self._iter_close_tests():
            try:
                with self.subTest('{} {}  {}'.format(i, fix, actions)):
                    fix.prep_interpreter(fix.interp)
                    self.run_actions(fix, actions)

                    self._close(fix, force=force)

                    self._assert_closed(fix)
            finally:
                # XXX Things slow down if we have too many interpreters.
                fix.clean_up()

    def test_close(self):
        self._run_close_tests(force=False)

    def test_force_close(self):
        self._run_close_tests(force=True)
2122
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
2125