"""Tests for the high-level ``test.support.interpreters`` module.

The low-level ``_xxsubinterpreters`` module is used directly only where a
test needs its raw ID types or an independent view of interpreter state.
"""

import contextlib
import os
import threading
from textwrap import dedent
import unittest
import time

import _xxsubinterpreters as _interpreters
from test.support import interpreters


def _captured_script(script):
    """Wrap *script* so that its stdout is sent through a fresh OS pipe.

    Return ``(wrapped, rfile)``: *wrapped* is source code that opens the
    pipe's write end and runs *script* under ``redirect_stdout``; *rfile*
    is the pipe's read end, opened as a UTF-8 text file.  The caller is
    responsible for closing *rfile*.
    """
    r, w = os.pipe()
    # Re-indent every continuation line by the same 16 columns as the
    # "{indented}" placeholder below, so the whole script stays inside the
    # redirect_stdout block after dedent() strips the common margin.
    indented = script.replace('\n', '\n                ')
    wrapped = dedent(f"""
        import contextlib
        with open({w}, 'w', encoding='utf-8') as spipe:
            with contextlib.redirect_stdout(spipe):
                {indented}
        """)
    return wrapped, open(r, encoding='utf-8')


def clean_up_interpreters():
    """Close every interpreter except the main one (ID 0), best-effort."""
    for interp in interpreters.list_all():
        if interp.id == 0:  # main
            continue
        try:
            interp.close()
        except RuntimeError:
            pass  # already destroyed


def _run_output(interp, request, channels=None):
    """Run *request* in *interp* and return everything it printed."""
    script, rpipe = _captured_script(request)
    with rpipe:
        interp.run(script, channels=channels)
        return rpipe.read()


@contextlib.contextmanager
def _running(interp):
    """Keep *interp* running a blocking script for the ``with`` body.

    A background thread runs a script in *interp* that blocks reading a
    pipe.  On exit the pipe's write end is written to and closed, which lets
    the script finish, and the thread is joined.
    """
    r, w = os.pipe()
    def run():
        interp.run(dedent(f"""
            # wait for "signal"
            with open({r}, encoding='utf-8') as rpipe:
                rpipe.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    try:
        yield
    finally:
        # Signal even if the body raised; otherwise the script would stay
        # blocked on the pipe and the thread would never be joined.
        with open(w, 'w', encoding='utf-8') as spipe:
            spipe.write('done')
        t.join()


class TestBase(unittest.TestCase):

    def tearDown(self):
        clean_up_interpreters()


class CreateTests(TestBase):

    def test_in_main(self):
        interp = interpreters.create()
        self.assertIsInstance(interp, interpreters.Interpreter)
        self.assertIn(interp, interpreters.list_all())

    def test_in_thread(self):
        lock = threading.Lock()
        interp = None
        def f():
            nonlocal interp
            interp = interpreters.create()
            lock.acquire()
            lock.release()
        t = threading.Thread(target=f)
        with lock:
            t.start()
        t.join()
        self.assertIn(interp, interpreters.list_all())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            interp = interpreters.create()
            print(interp.id)
            """))
        interp2 = interpreters.Interpreter(int(out))
        self.assertEqual(interpreters.list_all(), [main, interp, interp2])

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        interp_lst = [interpreters.create() for _ in range(3)]
        # Now destroy them.
        for interp in interp_lst:
            interp.close()
        # Finally, create another.
        interp = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {interp})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        # Now destroy 2 of them.
        interp1.close()
        interp2.close()
        # Finally, create another.
        interp = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {interp3, interp})


class GetCurrentTests(TestBase):

    def test_main(self):
        main = interpreters.get_main()
        current = interpreters.get_current()
        self.assertEqual(current, main)

    def test_subinterpreter(self):
        # Compare Interpreter against Interpreter: the raw InterpreterID from
        # _interpreters.get_main() would never compare equal to an
        # Interpreter, which made this assertion vacuous.
        main = interpreters.get_main()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            cur = interpreters.get_current()
            print(cur.id)
            """))
        current = interpreters.Interpreter(int(out))
        self.assertNotEqual(current, main)


class ListAllTests(TestBase):

    def test_initial(self):
        interps = interpreters.list_all()
        self.assertEqual(1, len(interps))

    def test_after_creating(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()

        ids = [interp.id for interp in interpreters.list_all()]

        self.assertEqual(ids, [main.id, first.id, second.id])

    def test_after_destroying(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()
        first.close()

        ids = [interp.id for interp in interpreters.list_all()]

        self.assertEqual(ids, [main.id, second.id])


class TestInterpreterAttrs(TestBase):

    def test_id_type(self):
        main = interpreters.get_main()
        current = interpreters.get_current()
        interp = interpreters.create()
        self.assertIsInstance(main.id, _interpreters.InterpreterID)
        self.assertIsInstance(current.id, _interpreters.InterpreterID)
        self.assertIsInstance(interp.id, _interpreters.InterpreterID)

    def test_main_id(self):
        main = interpreters.get_main()
        self.assertEqual(main.id, 0)

    def test_custom_id(self):
        interp = interpreters.Interpreter(1)
        self.assertEqual(interp.id, 1)

        with self.assertRaises(TypeError):
            interpreters.Interpreter('1')

    def test_id_readonly(self):
        interp = interpreters.Interpreter(1)
        with self.assertRaises(AttributeError):
            interp.id = 2

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_main_isolated(self):
        main = interpreters.get_main()
        self.assertFalse(main.isolated)

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_subinterpreter_isolated_default(self):
        interp = interpreters.create()
        self.assertFalse(interp.isolated)

    def test_subinterpreter_isolated_explicit(self):
        interp1 = interpreters.create(isolated=True)
        interp2 = interpreters.create(isolated=False)
        self.assertTrue(interp1.isolated)
        self.assertFalse(interp2.isolated)

    @unittest.skip('not ready yet (see bpo-32604)')
    def test_custom_isolated_default(self):
        interp = interpreters.Interpreter(1)
        self.assertFalse(interp.isolated)

    def test_custom_isolated_explicit(self):
        interp1 = interpreters.Interpreter(1, isolated=True)
        interp2 = interpreters.Interpreter(1, isolated=False)
        self.assertTrue(interp1.isolated)
        self.assertFalse(interp2.isolated)

    def test_isolated_readonly(self):
        interp = interpreters.Interpreter(1)
        with self.assertRaises(AttributeError):
            interp.isolated = True

    def test_equality(self):
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        self.assertEqual(interp1, interp1)
        self.assertNotEqual(interp1, interp2)


class TestInterpreterIsRunning(TestBase):

    def test_main(self):
        main = interpreters.get_main()
        self.assertTrue(main.is_running())

    @unittest.skip('Fails on FreeBSD')
    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interp.is_running())

        with _running(interp):
            self.assertTrue(interp.is_running())
        self.assertFalse(interp.is_running())

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        out = _run_output(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp.id}):
                print(True)
            else:
                print(False)
            """))
        self.assertEqual(out.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        with self.assertRaises(RuntimeError):
            interp.is_running()

    def test_does_not_exist(self):
        interp = interpreters.Interpreter(1_000_000)
        with self.assertRaises(RuntimeError):
            interp.is_running()

    def test_bad_id(self):
        interp = interpreters.Interpreter(-1)
        with self.assertRaises(ValueError):
            interp.is_running()


class TestInterpreterClose(TestBase):

    def test_basic(self):
        main = interpreters.get_main()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp2, interp3})
        interp2.close()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp3})

    def test_all(self):
        before = set(interpreters.list_all())
        interps = set()
        for _ in range(3):
            interp = interpreters.create()
            interps.add(interp)
        self.assertEqual(set(interpreters.list_all()), before | interps)
        for interp in interps:
            interp.close()
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            main.close()

        # An assertion failing inside a worker thread does not fail the
        # test, so record the outcome there and check it in this thread.
        raised = []
        def f():
            try:
                main.close()
            except RuntimeError:
                raised.append(True)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(raised, [True])

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_does_not_exist(self):
        interp = interpreters.Interpreter(1_000_000)
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_bad_id(self):
        interp = interpreters.Interpreter(-1)
        with self.assertRaises(ValueError):
            interp.close()

    def test_from_current(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        out = _run_output(interp, dedent(f"""
            from test.support import interpreters
            interp = interpreters.Interpreter({int(interp.id)})
            try:
                interp.close()
            except RuntimeError:
                print('failed')
            """))
        self.assertEqual(out.strip(), 'failed')
        self.assertEqual(set(interpreters.list_all()), {main, interp})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         {main, interp1, interp2})
        interp1.run(dedent(f"""
            from test.support import interpreters
            interp2 = interpreters.Interpreter(int({interp2.id}))
            interp2.close()
            interp3 = interpreters.create()
            interp3.close()
            """))
        self.assertEqual(set(interpreters.list_all()), {main, interp1})

    def test_from_other_thread(self):
        interp = interpreters.create()
        # Exceptions raised in a worker thread are not propagated by
        # Thread.join(), so capture them and re-raise them here.
        errors = []
        def f():
            try:
                interp.close()
            except Exception as exc:
                errors.append(exc)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        if errors:
            raise errors[0]
        self.assertNotIn(interp, interpreters.list_all())

    @unittest.skip('Fails on FreeBSD')
    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interp.close()
            self.assertTrue(interp.is_running())


class TestInterpreterRun(TestBase):

    def test_success(self):
        interp = interpreters.create()
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            interp.run(script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        interp = interpreters.create()
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            def f():
                interp.run(script)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        interp = interpreters.create()
        import tempfile
        with tempfile.NamedTemporaryFile('w+', encoding='utf-8') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            script = dedent(f"""
                import os
                try:
                    os.fork()
                except RuntimeError:
                    with open('{file.name}', 'w', encoding='utf-8') as out:
                        out.write('{expected}')
                """)
            interp.run(script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    @unittest.skip('Fails on FreeBSD')
    def test_already_running(self):
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interp.run('print("spam")')

    def test_does_not_exist(self):
        interp = interpreters.Interpreter(1_000_000)
        with self.assertRaises(RuntimeError):
            interp.run('print("spam")')

    def test_bad_id(self):
        interp = interpreters.Interpreter(-1)
        with self.assertRaises(ValueError):
            interp.run('print("spam")')

    def test_bad_script(self):
        interp = interpreters.create()
        with self.assertRaises(TypeError):
            interp.run(10)

    def test_bytes_for_script(self):
        interp = interpreters.create()
        with self.assertRaises(TypeError):
            interp.run(b'print("spam")')

    # test_xxsubinterpreters covers the remaining Interpreter.run() behavior.


class TestIsShareable(TestBase):

    def test_default_shareables(self):
        shareables = [
            # singletons
            None,
            # builtin objects
            b'spam',
            'spam',
            10,
            -10,
        ]
        for obj in shareables:
            with self.subTest(obj):
                shareable = interpreters.is_shareable(obj)
                self.assertTrue(shareable)

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name
            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        not_shareables = [
            # singletons
            True,
            False,
            NotImplemented,
            ...,
            # builtin types and objects
            type,
            object,
            object(),
            Exception(),
            100.0,
            # user-defined types and objects
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        ]
        for obj in not_shareables:
            with self.subTest(repr(obj)):
                self.assertFalse(
                    interpreters.is_shareable(obj))


class TestChannels(TestBase):

    def test_create(self):
        r, s = interpreters.create_channel()
        self.assertIsInstance(r, interpreters.RecvChannel)
        self.assertIsInstance(s, interpreters.SendChannel)

    def test_list_all(self):
        self.assertEqual(interpreters.list_all_channels(), [])
        # Each element is the (recv, send) pair returned by create_channel().
        created = {interpreters.create_channel() for _ in range(3)}
        after = set(interpreters.list_all_channels())
        self.assertEqual(after, created)


class TestRecvChannelAttrs(TestBase):

    def test_id_type(self):
        rch, _ = interpreters.create_channel()
        self.assertIsInstance(rch.id, _interpreters.ChannelID)

    def test_custom_id(self):
        rch = interpreters.RecvChannel(1)
        self.assertEqual(rch.id, 1)

        with self.assertRaises(TypeError):
            interpreters.RecvChannel('1')

    def test_id_readonly(self):
        rch = interpreters.RecvChannel(1)
        with self.assertRaises(AttributeError):
            rch.id = 2

    def test_equality(self):
        ch1, _ = interpreters.create_channel()
        ch2, _ = interpreters.create_channel()
        self.assertEqual(ch1, ch1)
        self.assertNotEqual(ch1, ch2)


class TestSendChannelAttrs(TestBase):

    def test_id_type(self):
        _, sch = interpreters.create_channel()
        self.assertIsInstance(sch.id, _interpreters.ChannelID)

    def test_custom_id(self):
        sch = interpreters.SendChannel(1)
        self.assertEqual(sch.id, 1)

        with self.assertRaises(TypeError):
            interpreters.SendChannel('1')

    def test_id_readonly(self):
        sch = interpreters.SendChannel(1)
        with self.assertRaises(AttributeError):
            sch.id = 2

    def test_equality(self):
        _, ch1 = interpreters.create_channel()
        _, ch2 = interpreters.create_channel()
        self.assertEqual(ch1, ch1)
        self.assertNotEqual(ch1, ch2)


class TestSendRecv(TestBase):

    def test_send_recv_main(self):
        r, s = interpreters.create_channel()
        orig = b'spam'
        s.send_nowait(orig)
        obj = r.recv()

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_same_interpreter(self):
        interp = interpreters.create()
        interp.run(dedent("""
            from test.support import interpreters
            r, s = interpreters.create_channel()
            orig = b'spam'
            s.send_nowait(orig)
            obj = r.recv()
            assert obj == orig, 'expected: obj == orig'
            assert obj is not orig, 'expected: obj is not orig'
            """))

    @unittest.skip('broken (see BPO-...)')
    def test_send_recv_different_interpreters(self):
        r1, s1 = interpreters.create_channel()
        r2, s2 = interpreters.create_channel()
        orig1 = b'spam'
        s1.send_nowait(orig1)
        out = _run_output(
            interpreters.create(),
            dedent(f"""
                obj1 = r.recv()
                assert obj1 == b'spam', 'expected: obj1 == orig1'
                # When going to another interpreter we get a copy.
                assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1'
                orig2 = b'eggs'
                print(id(orig2))
                s.send_nowait(orig2)
                """),
            channels=dict(r=r1, s=s2),
        )
        obj2 = r2.recv()

        self.assertEqual(obj2, b'eggs')
        self.assertNotEqual(id(obj2), int(out))

    def test_send_recv_different_threads(self):
        r, s = interpreters.create_channel()

        def f():
            while True:
                try:
                    obj = r.recv()
                    break
                except interpreters.ChannelEmptyError:
                    time.sleep(0.1)
            s.send(obj)
        t = threading.Thread(target=f)
        t.start()

        orig = b'spam'
        s.send(orig)
        t.join()
        obj = r.recv()

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_nowait_main(self):
        r, s = interpreters.create_channel()
        orig = b'spam'
        s.send_nowait(orig)
        obj = r.recv_nowait()

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_nowait_main_with_default(self):
        r, _ = interpreters.create_channel()
        obj = r.recv_nowait(None)

        self.assertIsNone(obj)

    def test_send_recv_nowait_same_interpreter(self):
        interp = interpreters.create()
        interp.run(dedent("""
            from test.support import interpreters
            r, s = interpreters.create_channel()
            orig = b'spam'
            s.send_nowait(orig)
            obj = r.recv_nowait()
            assert obj == orig, 'expected: obj == orig'
            # Even within the same interpreter we get a copy.
            assert obj is not orig, 'expected: obj is not orig'
            """))

    @unittest.skip('broken (see BPO-...)')
    def test_send_recv_nowait_different_interpreters(self):
        r1, s1 = interpreters.create_channel()
        r2, s2 = interpreters.create_channel()
        orig1 = b'spam'
        s1.send_nowait(orig1)
        out = _run_output(
            interpreters.create(),
            dedent(f"""
                obj1 = r.recv_nowait()
                assert obj1 == b'spam', 'expected: obj1 == orig1'
                # When going to another interpreter we get a copy.
                assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1'
                orig2 = b'eggs'
                print(id(orig2))
                s.send_nowait(orig2)
                """),
            channels=dict(r=r1, s=s2),
        )
        obj2 = r2.recv_nowait()

        self.assertEqual(obj2, b'eggs')
        self.assertNotEqual(id(obj2), int(out))

    def test_recv_channel_does_not_exist(self):
        ch = interpreters.RecvChannel(1_000_000)
        with self.assertRaises(interpreters.ChannelNotFoundError):
            ch.recv()

    def test_send_channel_does_not_exist(self):
        ch = interpreters.SendChannel(1_000_000)
        with self.assertRaises(interpreters.ChannelNotFoundError):
            ch.send(b'spam')

    def test_recv_nowait_channel_does_not_exist(self):
        ch = interpreters.RecvChannel(1_000_000)
        with self.assertRaises(interpreters.ChannelNotFoundError):
            ch.recv_nowait()

    def test_send_nowait_channel_does_not_exist(self):
        ch = interpreters.SendChannel(1_000_000)
        with self.assertRaises(interpreters.ChannelNotFoundError):
            ch.send_nowait(b'spam')

    def test_recv_nowait_empty(self):
        ch, _ = interpreters.create_channel()
        with self.assertRaises(interpreters.ChannelEmptyError):
            ch.recv_nowait()

    def test_recv_nowait_default(self):
        default = object()
        rch, sch = interpreters.create_channel()
        obj1 = rch.recv_nowait(default)
        sch.send_nowait(None)
        sch.send_nowait(1)
        sch.send_nowait(b'spam')
        sch.send_nowait(b'eggs')
        obj2 = rch.recv_nowait(default)
        obj3 = rch.recv_nowait(default)
        obj4 = rch.recv_nowait()
        obj5 = rch.recv_nowait(default)
        obj6 = rch.recv_nowait(default)

        self.assertIs(obj1, default)
        self.assertIs(obj2, None)
        self.assertEqual(obj3, 1)
        self.assertEqual(obj4, b'spam')
        self.assertEqual(obj5, b'eggs')
        self.assertIs(obj6, default)