1import asyncio 2import asyncio.sslproto 3import gc 4import os 5import select 6import socket 7import unittest.mock 8import ssl 9import sys 10import threading 11import time 12import weakref 13 14from OpenSSL import SSL as openssl_ssl 15from uvloop import _testbase as tb 16 17 18SSL_HANDSHAKE_TIMEOUT = 15.0 19 20 21class MyBaseProto(asyncio.Protocol): 22 connected = None 23 done = None 24 25 def __init__(self, loop=None): 26 self.transport = None 27 self.state = 'INITIAL' 28 self.nbytes = 0 29 if loop is not None: 30 self.connected = asyncio.Future(loop=loop) 31 self.done = asyncio.Future(loop=loop) 32 33 def connection_made(self, transport): 34 self.transport = transport 35 assert self.state == 'INITIAL', self.state 36 self.state = 'CONNECTED' 37 if self.connected: 38 self.connected.set_result(None) 39 40 def data_received(self, data): 41 assert self.state == 'CONNECTED', self.state 42 self.nbytes += len(data) 43 44 def eof_received(self): 45 assert self.state == 'CONNECTED', self.state 46 self.state = 'EOF' 47 48 def connection_lost(self, exc): 49 assert self.state in ('CONNECTED', 'EOF'), self.state 50 self.state = 'CLOSED' 51 if self.done: 52 self.done.set_result(None) 53 54 55class _TestTCP: 56 def test_create_server_1(self): 57 CNT = 0 # number of clients that were successful 58 TOTAL_CNT = 25 # total number of clients that test will create 59 TIMEOUT = 5.0 # timeout for this test 60 61 A_DATA = b'A' * 1024 * 1024 62 B_DATA = b'B' * 1024 * 1024 63 64 async def handle_client(reader, writer): 65 nonlocal CNT 66 67 data = await reader.readexactly(len(A_DATA)) 68 self.assertEqual(data, A_DATA) 69 writer.write(b'OK') 70 71 data = await reader.readexactly(len(B_DATA)) 72 self.assertEqual(data, B_DATA) 73 writer.writelines([b'S', b'P']) 74 writer.write(bytearray(b'A')) 75 writer.write(memoryview(b'M')) 76 77 if self.implementation == 'uvloop': 78 tr = writer.transport 79 sock = tr.get_extra_info('socket') 80 self.assertTrue( 81 sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)) 82 83 await writer.drain() 84 writer.close() 85 86 CNT += 1 87 88 async def test_client(addr): 89 sock = socket.socket() 90 with sock: 91 sock.setblocking(False) 92 await self.loop.sock_connect(sock, addr) 93 94 await self.loop.sock_sendall(sock, A_DATA) 95 96 buf = b'' 97 while len(buf) != 2: 98 buf += await self.loop.sock_recv(sock, 1) 99 self.assertEqual(buf, b'OK') 100 101 await self.loop.sock_sendall(sock, B_DATA) 102 103 buf = b'' 104 while len(buf) != 4: 105 buf += await self.loop.sock_recv(sock, 1) 106 self.assertEqual(buf, b'SPAM') 107 108 self.assertEqual(sock.fileno(), -1) 109 self.assertEqual(sock._io_refs, 0) 110 self.assertTrue(sock._closed) 111 112 async def start_server(): 113 nonlocal CNT 114 CNT = 0 115 116 srv = await asyncio.start_server( 117 handle_client, 118 ('127.0.0.1', 'localhost'), 0, 119 family=socket.AF_INET) 120 121 srv_socks = srv.sockets 122 self.assertTrue(srv_socks) 123 self.assertTrue(srv.is_serving()) 124 125 addr = srv_socks[0].getsockname() 126 127 tasks = [] 128 for _ in range(TOTAL_CNT): 129 tasks.append(test_client(addr)) 130 131 await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) 132 133 self.loop.call_soon(srv.close) 134 await srv.wait_closed() 135 136 # Check that the server cleaned-up proxy-sockets 137 for srv_sock in srv_socks: 138 self.assertEqual(srv_sock.fileno(), -1) 139 140 self.assertFalse(srv.is_serving()) 141 142 async def start_server_sock(): 143 nonlocal CNT 144 CNT = 0 145 146 sock = socket.socket() 147 sock.bind(('127.0.0.1', 0)) 148 addr = sock.getsockname() 149 150 srv = await asyncio.start_server( 151 handle_client, 152 None, None, 153 family=socket.AF_INET, 154 sock=sock) 155 156 self.assertIs(srv.get_loop(), self.loop) 157 158 srv_socks = srv.sockets 159 self.assertTrue(srv_socks) 160 self.assertTrue(srv.is_serving()) 161 162 tasks = [] 163 for _ in range(TOTAL_CNT): 164 tasks.append(test_client(addr)) 165 166 await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) 167 168 srv.close() 169 await srv.wait_closed() 170 171 # Check that the server cleaned-up proxy-sockets 172 for srv_sock in srv_socks: 173 self.assertEqual(srv_sock.fileno(), -1) 174 175 self.assertFalse(srv.is_serving()) 176 177 self.loop.run_until_complete(start_server()) 178 self.assertEqual(CNT, TOTAL_CNT) 179 180 self.loop.run_until_complete(start_server_sock()) 181 self.assertEqual(CNT, TOTAL_CNT) 182 183 def test_create_server_2(self): 184 with self.assertRaisesRegex(ValueError, 'nor sock were specified'): 185 self.loop.run_until_complete(self.loop.create_server(object)) 186 187 def test_create_server_3(self): 188 ''' check ephemeral port can be used ''' 189 190 async def start_server_ephemeral_ports(): 191 192 for port_sentinel in [0, None]: 193 srv = await self.loop.create_server( 194 asyncio.Protocol, 195 '127.0.0.1', port_sentinel, 196 family=socket.AF_INET) 197 198 srv_socks = srv.sockets 199 self.assertTrue(srv_socks) 200 self.assertTrue(srv.is_serving()) 201 202 host, port = srv_socks[0].getsockname() 203 self.assertNotEqual(0, port) 204 205 self.loop.call_soon(srv.close) 206 await srv.wait_closed() 207 208 # Check that the server cleaned-up proxy-sockets 209 for srv_sock in srv_socks: 210 self.assertEqual(srv_sock.fileno(), -1) 211 212 self.assertFalse(srv.is_serving()) 213 214 self.loop.run_until_complete(start_server_ephemeral_ports()) 215 216 def test_create_server_4(self): 217 sock = socket.socket() 218 sock.bind(('127.0.0.1', 0)) 219 220 with sock: 221 addr = sock.getsockname() 222 223 with self.assertRaisesRegex(OSError, 224 r"error while attempting.*\('127.*: " 225 r"address already in use"): 226 227 self.loop.run_until_complete( 228 self.loop.create_server(object, *addr)) 229 230 def test_create_server_5(self): 231 # Test that create_server sets the TCP_IPV6ONLY flag, 232 # so it can bind to ipv4 and ipv6 addresses 233 # simultaneously. 234 235 port = tb.find_free_port() 236 237 async def runner(): 238 srv = await self.loop.create_server( 239 asyncio.Protocol, 240 None, port) 241 242 srv.close() 243 await srv.wait_closed() 244 245 self.loop.run_until_complete(runner()) 246 247 def test_create_server_6(self): 248 if not hasattr(socket, 'SO_REUSEPORT'): 249 raise unittest.SkipTest( 250 'The system does not support SO_REUSEPORT') 251 252 port = tb.find_free_port() 253 254 async def runner(): 255 srv1 = await self.loop.create_server( 256 asyncio.Protocol, 257 None, port, 258 reuse_port=True) 259 260 srv2 = await self.loop.create_server( 261 asyncio.Protocol, 262 None, port, 263 reuse_port=True) 264 265 srv1.close() 266 srv2.close() 267 268 await srv1.wait_closed() 269 await srv2.wait_closed() 270 271 self.loop.run_until_complete(runner()) 272 273 def test_create_server_7(self): 274 # Test that create_server() stores a hard ref to the server object 275 # somewhere in the loop. In asyncio it so happens that 276 # loop.sock_accept() has a reference to the server object so it 277 # never gets GCed. 278 279 class Proto(asyncio.Protocol): 280 def connection_made(self, tr): 281 self.tr = tr 282 self.tr.write(b'hello') 283 284 async def test(): 285 port = tb.find_free_port() 286 srv = await self.loop.create_server(Proto, '127.0.0.1', port) 287 wsrv = weakref.ref(srv) 288 del srv 289 290 gc.collect() 291 gc.collect() 292 gc.collect() 293 294 s = socket.socket(socket.AF_INET) 295 with s: 296 s.setblocking(False) 297 await self.loop.sock_connect(s, ('127.0.0.1', port)) 298 d = await self.loop.sock_recv(s, 100) 299 self.assertEqual(d, b'hello') 300 301 srv = wsrv() 302 srv.close() 303 await srv.wait_closed() 304 del srv 305 306 # Let all transports shutdown. 307 await asyncio.sleep(0.1) 308 309 gc.collect() 310 gc.collect() 311 gc.collect() 312 313 self.assertIsNone(wsrv()) 314 315 self.loop.run_until_complete(test()) 316 317 def test_create_server_8(self): 318 with self.assertRaisesRegex( 319 ValueError, 'ssl_handshake_timeout is only meaningful'): 320 self.loop.run_until_complete( 321 self.loop.create_server( 322 lambda: None, host='::', port=0, ssl_handshake_timeout=10)) 323 324 def test_create_server_9(self): 325 async def handle_client(reader, writer): 326 pass 327 328 async def start_server(): 329 srv = await asyncio.start_server( 330 handle_client, 331 '127.0.0.1', 0, 332 family=socket.AF_INET, 333 start_serving=False) 334 335 await srv.start_serving() 336 self.assertTrue(srv.is_serving()) 337 338 # call start_serving again 339 await srv.start_serving() 340 self.assertTrue(srv.is_serving()) 341 342 srv.close() 343 await srv.wait_closed() 344 self.assertFalse(srv.is_serving()) 345 346 self.loop.run_until_complete(start_server()) 347 348 def test_create_server_10(self): 349 async def handle_client(reader, writer): 350 pass 351 352 async def start_server(): 353 srv = await asyncio.start_server( 354 handle_client, 355 '127.0.0.1', 0, 356 family=socket.AF_INET, 357 start_serving=False) 358 359 async with srv: 360 fut = asyncio.ensure_future(srv.serve_forever()) 361 await asyncio.sleep(0) 362 self.assertTrue(srv.is_serving()) 363 364 fut.cancel() 365 with self.assertRaises(asyncio.CancelledError): 366 await fut 367 self.assertFalse(srv.is_serving()) 368 369 self.loop.run_until_complete(start_server()) 370 371 def test_create_connection_open_con_addr(self): 372 async def client(addr): 373 reader, writer = await asyncio.open_connection(*addr) 374 375 writer.write(b'AAAA') 376 self.assertEqual(await reader.readexactly(2), b'OK') 377 378 re = r'(a bytes-like object)|(must be byte-ish)' 379 with self.assertRaisesRegex(TypeError, re): 380 writer.write('AAAA') 381 382 writer.write(b'BBBB') 383 self.assertEqual(await reader.readexactly(4), b'SPAM') 384 385 if self.implementation == 'uvloop': 386 tr = writer.transport 387 sock = tr.get_extra_info('socket') 388 self.assertTrue( 389 sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)) 390 391 writer.close() 392 await self.wait_closed(writer) 393 394 self._test_create_connection_1(client) 395 396 def test_create_connection_open_con_sock(self): 397 async def client(addr): 398 sock = socket.socket() 399 sock.connect(addr) 400 reader, writer = await asyncio.open_connection(sock=sock) 401 402 writer.write(b'AAAA') 403 self.assertEqual(await reader.readexactly(2), b'OK') 404 405 writer.write(b'BBBB') 406 self.assertEqual(await reader.readexactly(4), b'SPAM') 407 408 if self.implementation == 'uvloop': 409 tr = writer.transport 410 sock = tr.get_extra_info('socket') 411 self.assertTrue( 412 sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)) 413 414 writer.close() 415 await self.wait_closed(writer) 416 417 self._test_create_connection_1(client) 418 419 def _test_create_connection_1(self, client): 420 CNT = 0 421 TOTAL_CNT = 100 422 423 def server(sock): 424 data = sock.recv_all(4) 425 self.assertEqual(data, b'AAAA') 426 sock.send(b'OK') 427 428 data = sock.recv_all(4) 429 self.assertEqual(data, b'BBBB') 430 sock.send(b'SPAM') 431 432 async def client_wrapper(addr): 433 await client(addr) 434 nonlocal CNT 435 CNT += 1 436 437 def run(coro): 438 nonlocal CNT 439 CNT = 0 440 441 with self.tcp_server(server, 442 max_clients=TOTAL_CNT, 443 backlog=TOTAL_CNT) as srv: 444 tasks = [] 445 for _ in range(TOTAL_CNT): 446 tasks.append(coro(srv.addr)) 447 448 self.loop.run_until_complete(asyncio.gather(*tasks)) 449 450 self.assertEqual(CNT, TOTAL_CNT) 451 452 run(client_wrapper) 453 454 def test_create_connection_2(self): 455 sock = socket.socket() 456 with sock: 457 sock.bind(('127.0.0.1', 0)) 458 addr = sock.getsockname() 459 460 async def client(): 461 reader, writer = await asyncio.open_connection(*addr) 462 writer.close() 463 await self.wait_closed(writer) 464 465 async def runner(): 466 with self.assertRaises(ConnectionRefusedError): 467 await client() 468 469 self.loop.run_until_complete(runner()) 470 471 def test_create_connection_3(self): 472 CNT = 0 473 TOTAL_CNT = 100 474 475 def server(sock): 476 data = sock.recv_all(4) 477 self.assertEqual(data, b'AAAA') 478 sock.close() 479 480 async def client(addr): 481 reader, writer = await asyncio.open_connection(*addr) 482 483 writer.write(b'AAAA') 484 485 with self.assertRaises(asyncio.IncompleteReadError): 486 await reader.readexactly(10) 487 488 writer.close() 489 await self.wait_closed(writer) 490 491 nonlocal CNT 492 CNT += 1 493 494 def run(coro): 495 nonlocal CNT 496 CNT = 0 497 498 with self.tcp_server(server, 499 max_clients=TOTAL_CNT, 500 backlog=TOTAL_CNT) as srv: 501 tasks = [] 502 for _ in range(TOTAL_CNT): 503 tasks.append(coro(srv.addr)) 504 505 self.loop.run_until_complete(asyncio.gather(*tasks)) 506 507 self.assertEqual(CNT, TOTAL_CNT) 508 509 run(client) 510 511 def test_create_connection_4(self): 512 sock = socket.socket() 513 sock.close() 514 515 async def client(): 516 reader, writer = await asyncio.open_connection(sock=sock) 517 writer.close() 518 await self.wait_closed(writer) 519 520 async def runner(): 521 with self.assertRaisesRegex(OSError, 'Bad file'): 522 await client() 523 524 self.loop.run_until_complete(runner()) 525 526 def test_create_connection_5(self): 527 def server(sock): 528 try: 529 data = sock.recv_all(4) 530 except ConnectionError: 531 return 532 self.assertEqual(data, b'AAAA') 533 sock.send(b'OK') 534 535 async def client(addr): 536 fut = asyncio.ensure_future( 537 self.loop.create_connection(asyncio.Protocol, *addr)) 538 await asyncio.sleep(0) 539 fut.cancel() 540 with self.assertRaises(asyncio.CancelledError): 541 await fut 542 543 with self.tcp_server(server, 544 max_clients=1, 545 backlog=1) as srv: 546 self.loop.run_until_complete(client(srv.addr)) 547 548 def test_create_connection_6(self): 549 with self.assertRaisesRegex( 550 ValueError, 'ssl_handshake_timeout is only meaningful'): 551 self.loop.run_until_complete( 552 self.loop.create_connection( 553 lambda: None, host='::', port=0, ssl_handshake_timeout=10)) 554 555 def test_transport_shutdown(self): 556 CNT = 0 # number of clients that were successful 557 TOTAL_CNT = 100 # total number of clients that test will create 558 TIMEOUT = 5.0 # timeout for this test 559 560 async def handle_client(reader, writer): 561 nonlocal CNT 562 563 data = await reader.readexactly(4) 564 self.assertEqual(data, b'AAAA') 565 566 writer.write(b'OK') 567 writer.write_eof() 568 writer.write_eof() 569 570 await writer.drain() 571 writer.close() 572 573 CNT += 1 574 575 async def test_client(addr): 576 reader, writer = await asyncio.open_connection(*addr) 577 578 writer.write(b'AAAA') 579 data = await reader.readexactly(2) 580 self.assertEqual(data, b'OK') 581 582 writer.close() 583 await self.wait_closed(writer) 584 585 async def start_server(): 586 nonlocal CNT 587 CNT = 0 588 589 srv = await asyncio.start_server( 590 handle_client, 591 '127.0.0.1', 0, 592 family=socket.AF_INET) 593 594 srv_socks = srv.sockets 595 self.assertTrue(srv_socks) 596 597 addr = srv_socks[0].getsockname() 598 599 tasks = [] 600 for _ in range(TOTAL_CNT): 601 tasks.append(test_client(addr)) 602 603 await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) 604 605 srv.close() 606 await srv.wait_closed() 607 608 self.loop.run_until_complete(start_server()) 609 self.assertEqual(CNT, TOTAL_CNT) 610 611 def test_tcp_handle_exception_in_connection_made(self): 612 # Test that if connection_made raises an exception, 613 # 'create_connection' still returns. 614 615 # Silence error logging 616 self.loop.set_exception_handler(lambda *args: None) 617 618 fut = asyncio.Future() 619 connection_lost_called = asyncio.Future() 620 621 async def server(reader, writer): 622 try: 623 await reader.read() 624 finally: 625 writer.close() 626 627 class Proto(asyncio.Protocol): 628 def connection_made(self, tr): 629 1 / 0 630 631 def connection_lost(self, exc): 632 connection_lost_called.set_result(exc) 633 634 srv = self.loop.run_until_complete(asyncio.start_server( 635 server, 636 '127.0.0.1', 0, 637 family=socket.AF_INET)) 638 639 async def runner(): 640 tr, pr = await asyncio.wait_for( 641 self.loop.create_connection( 642 Proto, *srv.sockets[0].getsockname()), 643 timeout=1.0) 644 fut.set_result(None) 645 tr.close() 646 647 self.loop.run_until_complete(runner()) 648 srv.close() 649 self.loop.run_until_complete(srv.wait_closed()) 650 self.loop.run_until_complete(fut) 651 652 self.assertIsNone( 653 self.loop.run_until_complete(connection_lost_called)) 654 655 def test_context_run_segfault(self): 656 is_new = False 657 done = self.loop.create_future() 658 659 def server(sock): 660 sock.sendall(b'hello') 661 662 class Protocol(asyncio.Protocol): 663 def __init__(self): 664 self.transport = None 665 666 def connection_made(self, transport): 667 self.transport = transport 668 669 def data_received(self, data): 670 try: 671 self = weakref.ref(self) 672 nonlocal is_new 673 if is_new: 674 done.set_result(data) 675 else: 676 is_new = True 677 new_proto = Protocol() 678 self().transport.set_protocol(new_proto) 679 new_proto.connection_made(self().transport) 680 new_proto.data_received(data) 681 except Exception as e: 682 done.set_exception(e) 683 684 async def test(addr): 685 await self.loop.create_connection(Protocol, *addr) 686 data = await done 687 self.assertEqual(data, b'hello') 688 689 with self.tcp_server(server) as srv: 690 self.loop.run_until_complete(test(srv.addr)) 691 692 693class Test_UV_TCP(_TestTCP, tb.UVTestCase): 694 695 def test_create_server_buffered_1(self): 696 SIZE = 123123 697 eof = False 698 fut = asyncio.Future() 699 700 class Proto(asyncio.BaseProtocol): 701 def connection_made(self, tr): 702 self.tr = tr 703 self.recvd = b'' 704 self.data = bytearray(50) 705 self.buf = memoryview(self.data) 706 707 def get_buffer(self, sizehint): 708 return self.buf 709 710 def buffer_updated(self, nbytes): 711 self.recvd += self.buf[:nbytes] 712 if self.recvd == b'a' * SIZE: 713 self.tr.write(b'hello') 714 715 def eof_received(self): 716 nonlocal eof 717 eof = True 718 719 def connection_lost(self, exc): 720 fut.set_result(exc) 721 722 async def test(): 723 port = tb.find_free_port() 724 srv = await self.loop.create_server(Proto, '127.0.0.1', port) 725 726 s = socket.socket(socket.AF_INET) 727 with s: 728 s.setblocking(False) 729 await self.loop.sock_connect(s, ('127.0.0.1', port)) 730 await self.loop.sock_sendall(s, b'a' * SIZE) 731 d = await self.loop.sock_recv(s, 100) 732 self.assertEqual(d, b'hello') 733 734 srv.close() 735 await srv.wait_closed() 736 737 self.loop.run_until_complete(test()) 738 self.loop.run_until_complete(fut) 739 self.assertTrue(eof) 740 self.assertIsNone(fut.result()) 741 742 def test_create_server_buffered_2(self): 743 class ProtoExc(asyncio.BaseProtocol): 744 def __init__(self): 745 self._lost_exc = None 746 747 def get_buffer(self, sizehint): 748 1 / 0 749 750 def buffer_updated(self, nbytes): 751 pass 752 753 def connection_lost(self, exc): 754 self._lost_exc = exc 755 756 def eof_received(self): 757 pass 758 759 class ProtoZeroBuf1(asyncio.BaseProtocol): 760 def __init__(self): 761 self._lost_exc = None 762 763 def get_buffer(self, sizehint): 764 return bytearray(0) 765 766 def buffer_updated(self, nbytes): 767 pass 768 769 def connection_lost(self, exc): 770 self._lost_exc = exc 771 772 def eof_received(self): 773 pass 774 775 class ProtoZeroBuf2(asyncio.BaseProtocol): 776 def __init__(self): 777 self._lost_exc = None 778 779 def get_buffer(self, sizehint): 780 return memoryview(bytearray(0)) 781 782 def buffer_updated(self, nbytes): 783 pass 784 785 def connection_lost(self, exc): 786 self._lost_exc = exc 787 788 def eof_received(self): 789 pass 790 791 class ProtoUpdatedError(asyncio.BaseProtocol): 792 def __init__(self): 793 self._lost_exc = None 794 795 def get_buffer(self, sizehint): 796 return memoryview(bytearray(100)) 797 798 def buffer_updated(self, nbytes): 799 raise RuntimeError('oups') 800 801 def connection_lost(self, exc): 802 self._lost_exc = exc 803 804 def eof_received(self): 805 pass 806 807 async def test(proto_factory, exc_type, exc_re): 808 port = tb.find_free_port() 809 proto = proto_factory() 810 srv = await self.loop.create_server( 811 lambda: proto, '127.0.0.1', port) 812 813 try: 814 s = socket.socket(socket.AF_INET) 815 with s: 816 s.setblocking(False) 817 await self.loop.sock_connect(s, ('127.0.0.1', port)) 818 await self.loop.sock_sendall(s, b'a') 819 d = await self.loop.sock_recv(s, 100) 820 if not d: 821 raise ConnectionResetError 822 except ConnectionResetError: 823 pass 824 else: 825 self.fail("server didn't abort the connection") 826 return 827 finally: 828 srv.close() 829 await srv.wait_closed() 830 831 if proto._lost_exc is None: 832 self.fail("connection_lost() was not called") 833 return 834 835 with self.assertRaisesRegex(exc_type, exc_re): 836 raise proto._lost_exc 837 838 self.loop.set_exception_handler(lambda loop, ctx: None) 839 840 self.loop.run_until_complete( 841 test(ProtoExc, RuntimeError, 'unhandled error .* get_buffer')) 842 843 self.loop.run_until_complete( 844 test(ProtoZeroBuf1, RuntimeError, 'unhandled error .* get_buffer')) 845 846 self.loop.run_until_complete( 847 test(ProtoZeroBuf2, RuntimeError, 'unhandled error .* get_buffer')) 848 849 self.loop.run_until_complete( 850 test(ProtoUpdatedError, RuntimeError, r'^oups$')) 851 852 def test_transport_get_extra_info(self): 853 # This tests is only for uvloop. asyncio should pass it 854 # too in Python 3.6. 855 856 fut = asyncio.Future() 857 858 async def handle_client(reader, writer): 859 with self.assertRaises(asyncio.IncompleteReadError): 860 await reader.readexactly(4) 861 writer.close() 862 863 # Previously, when we used socket.fromfd to create a socket 864 # for UVTransports (to make get_extra_info() work), a duplicate 865 # of the socket was created, preventing UVTransport from being 866 # properly closed. 867 # This test ensures that server handle will receive an EOF 868 # and finish the request. 869 fut.set_result(None) 870 871 async def test_client(addr): 872 t, p = await self.loop.create_connection( 873 lambda: asyncio.Protocol(), *addr) 874 875 if hasattr(t, 'get_protocol'): 876 p2 = asyncio.Protocol() 877 self.assertIs(t.get_protocol(), p) 878 t.set_protocol(p2) 879 self.assertIs(t.get_protocol(), p2) 880 t.set_protocol(p) 881 882 self.assertFalse(t._paused) 883 self.assertTrue(t.is_reading()) 884 t.pause_reading() 885 t.pause_reading() # Check that it's OK to call it 2nd time. 886 self.assertTrue(t._paused) 887 self.assertFalse(t.is_reading()) 888 t.resume_reading() 889 t.resume_reading() # Check that it's OK to call it 2nd time. 890 self.assertFalse(t._paused) 891 self.assertTrue(t.is_reading()) 892 893 sock = t.get_extra_info('socket') 894 self.assertIs(sock, t.get_extra_info('socket')) 895 sockname = sock.getsockname() 896 peername = sock.getpeername() 897 898 with self.assertRaisesRegex(RuntimeError, 'is used by transport'): 899 self.loop.add_writer(sock.fileno(), lambda: None) 900 with self.assertRaisesRegex(RuntimeError, 'is used by transport'): 901 self.loop.remove_writer(sock.fileno()) 902 with self.assertRaisesRegex(RuntimeError, 'is used by transport'): 903 self.loop.add_reader(sock.fileno(), lambda: None) 904 with self.assertRaisesRegex(RuntimeError, 'is used by transport'): 905 self.loop.remove_reader(sock.fileno()) 906 907 self.assertEqual(t.get_extra_info('sockname'), 908 sockname) 909 self.assertEqual(t.get_extra_info('peername'), 910 peername) 911 912 t.write(b'OK') # We want server to fail. 913 914 self.assertFalse(t._closing) 915 t.abort() 916 self.assertTrue(t._closing) 917 918 self.assertFalse(t.is_reading()) 919 # Check that pause_reading and resume_reading don't raise 920 # errors if called after the transport is closed. 921 t.pause_reading() 922 t.resume_reading() 923 924 await fut 925 926 # Test that peername and sockname are available after 927 # the transport is closed. 928 self.assertEqual(t.get_extra_info('peername'), 929 peername) 930 self.assertEqual(t.get_extra_info('sockname'), 931 sockname) 932 933 async def start_server(): 934 srv = await asyncio.start_server( 935 handle_client, 936 '127.0.0.1', 0, 937 family=socket.AF_INET) 938 939 addr = srv.sockets[0].getsockname() 940 await test_client(addr) 941 942 srv.close() 943 await srv.wait_closed() 944 945 self.loop.run_until_complete(start_server()) 946 947 def test_create_server_float_backlog(self): 948 # asyncio spits out a warning we cannot suppress 949 950 async def runner(bl): 951 await self.loop.create_server( 952 asyncio.Protocol, 953 None, 0, backlog=bl) 954 955 for bl in (1.1, '1'): 956 with self.subTest(backlog=bl): 957 with self.assertRaisesRegex(TypeError, 'integer'): 958 self.loop.run_until_complete(runner(bl)) 959 960 def test_many_small_writes(self): 961 N = 10000 962 TOTAL = 0 963 964 fut = self.loop.create_future() 965 966 async def server(reader, writer): 967 nonlocal TOTAL 968 while True: 969 d = await reader.read(10000) 970 if not d: 971 break 972 TOTAL += len(d) 973 fut.set_result(True) 974 writer.close() 975 976 async def run(): 977 srv = await asyncio.start_server( 978 server, 979 '127.0.0.1', 0, 980 family=socket.AF_INET) 981 982 addr = srv.sockets[0].getsockname() 983 r, w = await asyncio.open_connection(*addr) 984 985 DATA = b'x' * 102400 986 987 # Test _StreamWriteContext with short sequences of writes 988 w.write(DATA) 989 await w.drain() 990 for _ in range(3): 991 w.write(DATA) 992 await w.drain() 993 for _ in range(10): 994 w.write(DATA) 995 await w.drain() 996 997 for _ in range(N): 998 w.write(DATA) 999 1000 try: 1001 w.write('a') 1002 except TypeError: 1003 pass 1004 1005 await w.drain() 1006 for _ in range(N): 1007 w.write(DATA) 1008 await w.drain() 1009 1010 w.close() 1011 await fut 1012 await self.wait_closed(w) 1013 1014 srv.close() 1015 await srv.wait_closed() 1016 1017 self.assertEqual(TOTAL, N * 2 * len(DATA) + 14 * len(DATA)) 1018 1019 self.loop.run_until_complete(run()) 1020 1021 @unittest.skipIf(sys.version_info[:3] >= (3, 8, 0), 1022 "3.8 has a different method of GCing unclosed streams") 1023 def test_tcp_handle_unclosed_gc(self): 1024 fut = self.loop.create_future() 1025 1026 async def server(reader, writer): 1027 writer.transport.abort() 1028 fut.set_result(True) 1029 1030 async def run(): 1031 addr = srv.sockets[0].getsockname() 1032 await asyncio.open_connection(*addr) 1033 await fut 1034 srv.close() 1035 await srv.wait_closed() 1036 1037 srv = self.loop.run_until_complete(asyncio.start_server( 1038 server, 1039 '127.0.0.1', 0, 1040 family=socket.AF_INET)) 1041 1042 if self.loop.get_debug(): 1043 rx = r'unclosed resource <TCP.*; ' \ 1044 r'object created at(.|\n)*test_tcp_handle_unclosed_gc' 1045 else: 1046 rx = r'unclosed resource <TCP.*' 1047 1048 with self.assertWarnsRegex(ResourceWarning, rx): 1049 self.loop.create_task(run()) 1050 self.loop.run_until_complete(srv.wait_closed()) 1051 self.loop.run_until_complete(asyncio.sleep(0.1)) 1052 1053 srv = None 1054 gc.collect() 1055 gc.collect() 1056 gc.collect() 1057 1058 self.loop.run_until_complete(asyncio.sleep(0.1)) 1059 1060 # Since one TCPTransport handle wasn't closed correctly, 1061 # we need to disable this check: 1062 self.skip_unclosed_handles_check() 1063 1064 def test_tcp_handle_abort_in_connection_made(self): 1065 async def server(reader, writer): 1066 try: 1067 await reader.read() 1068 finally: 1069 writer.close() 1070 1071 class Proto(asyncio.Protocol): 1072 def connection_made(self, tr): 1073 tr.abort() 1074 1075 srv = self.loop.run_until_complete(asyncio.start_server( 1076 server, 1077 '127.0.0.1', 0, 1078 family=socket.AF_INET)) 1079 1080 async def runner(): 1081 tr, pr = await asyncio.wait_for( 1082 self.loop.create_connection( 1083 Proto, *srv.sockets[0].getsockname()), 1084 timeout=1.0) 1085 1086 # Asyncio would return a closed socket, which we 1087 # can't do: the transport was aborted, hence there 1088 # is no FD to attach a socket to (to make 1089 # get_extra_info() work). 1090 self.assertIsNone(tr.get_extra_info('socket')) 1091 tr.close() 1092 1093 self.loop.run_until_complete(runner()) 1094 srv.close() 1095 self.loop.run_until_complete(srv.wait_closed()) 1096 1097 def test_connect_accepted_socket_ssl_args(self): 1098 with self.assertRaisesRegex( 1099 ValueError, 'ssl_handshake_timeout is only meaningful'): 1100 with socket.socket() as s: 1101 self.loop.run_until_complete( 1102 self.loop.connect_accepted_socket( 1103 (lambda: None), 1104 s, 1105 ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT 1106 ) 1107 ) 1108 1109 def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): 1110 loop = self.loop 1111 1112 class MyProto(MyBaseProto): 1113 1114 def connection_lost(self, exc): 1115 super().connection_lost(exc) 1116 loop.call_soon(loop.stop) 1117 1118 def data_received(self, data): 1119 super().data_received(data) 1120 self.transport.write(expected_response) 1121 1122 lsock = socket.socket(socket.AF_INET) 1123 lsock.bind(('127.0.0.1', 0)) 1124 lsock.listen(1) 1125 addr = lsock.getsockname() 1126 1127 message = b'test data' 1128 response = None 1129 expected_response = b'roger' 1130 1131 def client(): 1132 nonlocal response 1133 try: 1134 csock = socket.socket(socket.AF_INET) 1135 if client_ssl is not None: 1136 csock = client_ssl.wrap_socket(csock) 1137 csock.connect(addr) 1138 csock.sendall(message) 1139 response = csock.recv(99) 1140 csock.close() 1141 except Exception as exc: 1142 print( 1143 "Failure in client thread in test_connect_accepted_socket", 1144 exc) 1145 1146 thread = threading.Thread(target=client, daemon=True) 1147 thread.start() 1148 1149 conn, _ = lsock.accept() 1150 proto = MyProto(loop=loop) 1151 proto.loop = loop 1152 1153 extras = {} 1154 if server_ssl: 1155 extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT) 1156 1157 f = loop.create_task( 1158 loop.connect_accepted_socket( 1159 (lambda: proto), conn, ssl=server_ssl, 1160 **extras)) 1161 loop.run_forever() 1162 conn.close() 1163 lsock.close() 1164 1165 thread.join(1) 1166 self.assertFalse(thread.is_alive()) 1167 self.assertEqual(proto.state, 'CLOSED') 1168 self.assertEqual(proto.nbytes, len(message)) 1169 self.assertEqual(response, expected_response) 1170 tr, _ = f.result() 1171 1172 if server_ssl: 1173 self.assertIn('SSL', tr.__class__.__name__) 1174 1175 tr.close() 1176 # let it close 1177 self.loop.run_until_complete(asyncio.sleep(0.1)) 1178 1179 @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets') 1180 def test_create_connection_wrong_sock(self): 1181 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 1182 with sock: 1183 coro = self.loop.create_connection(MyBaseProto, sock=sock) 1184 with self.assertRaisesRegex(ValueError, 1185 'A Stream Socket was expected'): 1186 self.loop.run_until_complete(coro) 1187 1188 @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets') 1189 def test_create_server_wrong_sock(self): 1190 sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 1191 with sock: 1192 coro = self.loop.create_server(MyBaseProto, sock=sock) 1193 with self.assertRaisesRegex(ValueError, 1194 'A Stream Socket was expected'): 1195 self.loop.run_until_complete(coro) 1196 1197 @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), 1198 'no socket.SOCK_NONBLOCK (linux only)') 1199 def test_create_server_stream_bittype(self): 1200 sock = socket.socket( 1201 socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK) 1202 with sock: 1203 coro = self.loop.create_server(lambda: None, sock=sock) 1204 srv = self.loop.run_until_complete(coro) 1205 srv.close() 1206 self.loop.run_until_complete(srv.wait_closed()) 1207 1208 def test_flowcontrol_mixin_set_write_limits(self): 1209 async def client(addr): 1210 paused = False 1211 1212 class Protocol(asyncio.Protocol): 1213 def pause_writing(self): 1214 nonlocal paused 1215 paused = True 1216 1217 def resume_writing(self): 1218 nonlocal paused 1219 paused = False 1220 1221 t, p = await self.loop.create_connection(Protocol, *addr) 1222 1223 t.write(b'q' * 512) 1224 self.assertEqual(t.get_write_buffer_size(), 512) 1225 1226 t.set_write_buffer_limits(low=16385) 1227 self.assertFalse(paused) 1228 self.assertEqual(t.get_write_buffer_limits(), (16385, 65540)) 1229 1230 with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): 1231 t.set_write_buffer_limits(high=0, low=1) 1232 1233 t.set_write_buffer_limits(high=1024, low=128) 1234 self.assertFalse(paused) 1235 self.assertEqual(t.get_write_buffer_limits(), (128, 1024)) 1236 1237 t.set_write_buffer_limits(high=256, low=128) 1238 self.assertTrue(paused) 1239 self.assertEqual(t.get_write_buffer_limits(), (128, 256)) 1240 1241 t.close() 1242 1243 with self.tcp_server(lambda sock: sock.recv_all(1), 1244 max_clients=1, 1245 backlog=1) as srv: 1246 self.loop.run_until_complete(client(srv.addr)) 1247 1248 1249class Test_AIO_TCP(_TestTCP, tb.AIOTestCase): 1250 pass 1251 1252 1253class _TestSSL(tb.SSLTestCase): 1254 1255 ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem') 1256 ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem') 1257 1258 PAYLOAD_SIZE = 1024 * 100 1259 TIMEOUT = 60 1260 1261 def test_create_server_ssl_1(self): 1262 CNT = 0 # number of clients that were successful 1263 TOTAL_CNT = 25 # total number of clients that test will create 1264 TIMEOUT = 10.0 # timeout for this test 1265 1266 A_DATA = b'A' * 1024 * 1024 1267 B_DATA = b'B' * 1024 * 1024 1268 1269 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 1270 client_sslctx = self._create_client_ssl_context() 1271 1272 clients = [] 1273 1274 async def handle_client(reader, writer): 1275 nonlocal CNT 1276 1277 data = await reader.readexactly(len(A_DATA)) 1278 self.assertEqual(data, A_DATA) 1279 writer.write(b'OK') 1280 1281 data = await reader.readexactly(len(B_DATA)) 1282 self.assertEqual(data, B_DATA) 1283 writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) 1284 1285 await writer.drain() 1286 writer.close() 1287 1288 CNT += 1 1289 1290 async def test_client(addr): 1291 fut = asyncio.Future() 1292 1293 def prog(sock): 1294 try: 1295 sock.starttls(client_sslctx) 1296 sock.connect(addr) 1297 sock.send(A_DATA) 1298 1299 data = sock.recv_all(2) 1300 self.assertEqual(data, b'OK') 1301 1302 sock.send(B_DATA) 1303 data = sock.recv_all(4) 1304 self.assertEqual(data, b'SPAM') 1305 1306 sock.close() 1307 1308 except Exception as ex: 1309 self.loop.call_soon_threadsafe(fut.set_exception, ex) 1310 else: 1311 self.loop.call_soon_threadsafe(fut.set_result, None) 1312 1313 client = self.tcp_client(prog) 1314 client.start() 1315 clients.append(client) 1316 1317 await fut 1318 1319 async def start_server(): 1320 extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT) 1321 1322 srv = await asyncio.start_server( 1323 handle_client, 1324 '127.0.0.1', 0, 1325 family=socket.AF_INET, 1326 ssl=sslctx, 1327 **extras) 1328 1329 try: 1330 srv_socks = srv.sockets 1331 self.assertTrue(srv_socks) 1332 1333 addr = srv_socks[0].getsockname() 1334 1335 tasks = [] 1336 for _ in range(TOTAL_CNT): 1337 tasks.append(test_client(addr)) 1338 1339 await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) 1340 1341 finally: 1342 self.loop.call_soon(srv.close) 1343 await srv.wait_closed() 1344 1345 with self._silence_eof_received_warning(): 1346 self.loop.run_until_complete(start_server()) 1347 1348 self.assertEqual(CNT, TOTAL_CNT) 1349 1350 for client in clients: 1351 client.stop() 1352 1353 def test_create_connection_ssl_1(self): 1354 if self.implementation == 'asyncio': 1355 # Don't crash on asyncio errors 1356 self.loop.set_exception_handler(None) 1357 1358 CNT = 0 1359 TOTAL_CNT = 25 1360 1361 A_DATA = b'A' * 1024 * 1024 1362 B_DATA = b'B' * 1024 * 1024 1363 1364 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 1365 client_sslctx = self._create_client_ssl_context() 1366 1367 def server(sock): 1368 sock.starttls( 1369 sslctx, 1370 server_side=True) 1371 1372 data = sock.recv_all(len(A_DATA)) 1373 self.assertEqual(data, A_DATA) 1374 sock.send(b'OK') 1375 1376 data = sock.recv_all(len(B_DATA)) 1377 self.assertEqual(data, B_DATA) 1378 sock.send(b'SPAM') 1379 1380 sock.close() 1381 1382 async def client(addr): 1383 extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT) 1384 1385 reader, writer = await asyncio.open_connection( 1386 *addr, 1387 ssl=client_sslctx, 1388 server_hostname='', 1389 **extras) 1390 1391 writer.write(A_DATA) 1392 self.assertEqual(await reader.readexactly(2), b'OK') 1393 1394 writer.write(B_DATA) 1395 self.assertEqual(await reader.readexactly(4), b'SPAM') 1396 1397 nonlocal CNT 1398 CNT += 1 1399 1400 writer.close() 1401 await self.wait_closed(writer) 1402 1403 async def client_sock(addr): 1404 sock = socket.socket() 1405 sock.connect(addr) 1406 reader, writer = await asyncio.open_connection( 1407 sock=sock, 1408 ssl=client_sslctx, 1409 server_hostname='') 1410 1411 writer.write(A_DATA) 1412 self.assertEqual(await reader.readexactly(2), b'OK') 1413 1414 writer.write(B_DATA) 1415 self.assertEqual(await reader.readexactly(4), b'SPAM') 1416 1417 nonlocal CNT 1418 CNT += 1 1419 1420 writer.close() 1421 await self.wait_closed(writer) 1422 sock.close() 1423 1424 def run(coro): 1425 nonlocal CNT 1426 CNT = 0 1427 1428 with self.tcp_server(server, 1429 max_clients=TOTAL_CNT, 1430 backlog=TOTAL_CNT) as srv: 1431 tasks = [] 1432 for _ in range(TOTAL_CNT): 1433 tasks.append(coro(srv.addr)) 1434 1435 self.loop.run_until_complete(asyncio.gather(*tasks)) 1436 1437 self.assertEqual(CNT, TOTAL_CNT) 1438 1439 with self._silence_eof_received_warning(): 1440 run(client) 1441 1442 with self._silence_eof_received_warning(): 1443 run(client_sock) 1444 1445 def test_create_connection_ssl_slow_handshake(self): 1446 if self.implementation == 'asyncio': 1447 raise unittest.SkipTest() 1448 1449 client_sslctx = self._create_client_ssl_context() 1450 1451 # silence error logger 1452 self.loop.set_exception_handler(lambda *args: None) 1453 1454 def server(sock): 1455 try: 1456 sock.recv_all(1024 * 1024) 1457 except ConnectionAbortedError: 1458 pass 1459 finally: 1460 sock.close() 1461 1462 async def client(addr): 1463 reader, writer = await asyncio.open_connection( 1464 *addr, 1465 ssl=client_sslctx, 1466 server_hostname='', 1467 ssl_handshake_timeout=1.0) 1468 writer.close() 1469 await self.wait_closed(writer) 1470 1471 with self.tcp_server(server, 1472 max_clients=1, 1473 backlog=1) as srv: 1474 1475 with self.assertRaisesRegex( 1476 ConnectionAbortedError, 1477 r'SSL handshake.*is taking longer'): 1478 1479 self.loop.run_until_complete(client(srv.addr)) 1480 1481 def test_create_connection_ssl_failed_certificate(self): 1482 if self.implementation == 'asyncio': 1483 raise unittest.SkipTest() 1484 1485 # silence error logger 1486 self.loop.set_exception_handler(lambda *args: None) 1487 1488 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 1489 client_sslctx = self._create_client_ssl_context(disable_verify=False) 1490 1491 def server(sock): 1492 try: 1493 sock.starttls( 1494 sslctx, 1495 server_side=True) 1496 sock.connect() 1497 except (ssl.SSLError, OSError): 1498 pass 1499 finally: 1500 sock.close() 1501 1502 async def client(addr): 1503 reader, writer = await asyncio.open_connection( 1504 *addr, 1505 ssl=client_sslctx, 1506 server_hostname='', 1507 ssl_handshake_timeout=1.0) 1508 writer.close() 1509 await self.wait_closed(writer) 1510 1511 with self.tcp_server(server, 1512 max_clients=1, 1513 backlog=1) as srv: 1514 1515 with self.assertRaises(ssl.SSLCertVerificationError): 1516 self.loop.run_until_complete(client(srv.addr)) 1517 1518 def test_start_tls_wrong_args(self): 1519 if self.implementation == 'asyncio': 1520 raise unittest.SkipTest() 1521 1522 async def main(): 1523 with self.assertRaisesRegex(TypeError, 'SSLContext, got'): 1524 await self.loop.start_tls(None, None, None) 1525 1526 sslctx = self._create_server_ssl_context( 1527 self.ONLYCERT, self.ONLYKEY) 1528 with self.assertRaisesRegex(TypeError, 'is not supported'): 1529 await self.loop.start_tls(None, None, sslctx) 1530 1531 self.loop.run_until_complete(main()) 1532 1533 def test_ssl_handshake_timeout(self): 1534 if self.implementation == 'asyncio': 1535 raise unittest.SkipTest() 1536 1537 # bpo-29970: Check that a connection is aborted if handshake is not 1538 # completed in timeout period, instead of remaining open indefinitely 1539 client_sslctx = self._create_client_ssl_context() 1540 1541 # silence error logger 1542 messages = [] 1543 self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) 1544 1545 server_side_aborted = False 1546 1547 def server(sock): 1548 nonlocal server_side_aborted 1549 try: 1550 sock.recv_all(1024 * 1024) 1551 except ConnectionAbortedError: 1552 server_side_aborted = True 1553 finally: 1554 sock.close() 1555 1556 async def client(addr): 1557 await asyncio.wait_for( 1558 self.loop.create_connection( 1559 asyncio.Protocol, 1560 *addr, 1561 ssl=client_sslctx, 1562 server_hostname='', 1563 ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT 1564 ), 1565 0.5 1566 ) 1567 1568 with self.tcp_server(server, 1569 max_clients=1, 1570 backlog=1) as srv: 1571 1572 with self.assertRaises(asyncio.TimeoutError): 1573 self.loop.run_until_complete(client(srv.addr)) 1574 1575 self.assertTrue(server_side_aborted) 1576 1577 # Python issue #23197: cancelling a handshake must not raise an 1578 # exception or log an error, even if the handshake failed 1579 self.assertEqual(messages, []) 1580 1581 def test_ssl_handshake_connection_lost(self): 1582 # #246: make sure that no connection_lost() is called before 1583 # connection_made() is called first 1584 1585 client_sslctx = self._create_client_ssl_context() 1586 1587 # silence error logger 1588 self.loop.set_exception_handler(lambda loop, ctx: None) 1589 1590 connection_made_called = False 1591 connection_lost_called = False 1592 1593 def server(sock): 1594 sock.recv(1024) 1595 # break the connection during handshake 1596 sock.close() 1597 1598 class ClientProto(asyncio.Protocol): 1599 def connection_made(self, transport): 1600 nonlocal connection_made_called 1601 connection_made_called = True 1602 1603 def connection_lost(self, exc): 1604 nonlocal connection_lost_called 1605 connection_lost_called = True 1606 1607 async def client(addr): 1608 await self.loop.create_connection( 1609 ClientProto, 1610 *addr, 1611 ssl=client_sslctx, 1612 server_hostname=''), 1613 1614 with self.tcp_server(server, 1615 max_clients=1, 1616 backlog=1) as srv: 1617 1618 with self.assertRaises(ConnectionResetError): 1619 self.loop.run_until_complete(client(srv.addr)) 1620 1621 if connection_lost_called: 1622 if connection_made_called: 1623 self.fail("unexpected call to connection_lost()") 1624 else: 1625 self.fail("unexpected call to connection_lost() without" 1626 "calling connection_made()") 1627 elif connection_made_called: 1628 self.fail("unexpected call to connection_made()") 1629 1630 def test_ssl_connect_accepted_socket(self): 1631 if hasattr(ssl, 'PROTOCOL_TLS'): 1632 proto = ssl.PROTOCOL_TLS 1633 else: 1634 proto = ssl.PROTOCOL_SSLv23 1635 server_context = ssl.SSLContext(proto) 1636 server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY) 1637 if hasattr(server_context, 'check_hostname'): 1638 server_context.check_hostname = False 1639 server_context.verify_mode = ssl.CERT_NONE 1640 1641 client_context = ssl.SSLContext(proto) 1642 if hasattr(server_context, 'check_hostname'): 1643 client_context.check_hostname = False 1644 client_context.verify_mode = ssl.CERT_NONE 1645 1646 Test_UV_TCP.test_connect_accepted_socket( 1647 self, server_context, client_context) 1648 1649 def test_start_tls_client_corrupted_ssl(self): 1650 if self.implementation == 'asyncio': 1651 raise unittest.SkipTest() 1652 1653 self.loop.set_exception_handler(lambda loop, ctx: None) 1654 1655 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 1656 client_sslctx = self._create_client_ssl_context() 1657 1658 def server(sock): 1659 orig_sock = sock.dup() 1660 try: 1661 sock.starttls( 1662 sslctx, 1663 server_side=True) 1664 sock.sendall(b'A\n') 1665 sock.recv_all(1) 1666 orig_sock.send(b'please corrupt the SSL connection') 1667 except ssl.SSLError: 1668 pass 1669 finally: 1670 sock.close() 1671 orig_sock.close() 1672 1673 async def client(addr): 1674 reader, writer = await asyncio.open_connection( 1675 *addr, 1676 ssl=client_sslctx, 1677 server_hostname='') 1678 1679 self.assertEqual(await reader.readline(), b'A\n') 1680 writer.write(b'B') 1681 with self.assertRaises(ssl.SSLError): 1682 await reader.readline() 1683 writer.close() 1684 try: 1685 await self.wait_closed(writer) 1686 except ssl.SSLError: 1687 pass 1688 return 'OK' 1689 1690 with self.tcp_server(server, 1691 max_clients=1, 1692 backlog=1) as srv: 1693 1694 res = self.loop.run_until_complete(client(srv.addr)) 1695 1696 self.assertEqual(res, 'OK') 1697 1698 def test_start_tls_client_reg_proto_1(self): 1699 if self.implementation == 'asyncio': 1700 raise unittest.SkipTest() 1701 1702 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 1703 1704 server_context = self._create_server_ssl_context( 1705 self.ONLYCERT, self.ONLYKEY) 1706 client_context = self._create_client_ssl_context() 1707 1708 def serve(sock): 1709 sock.settimeout(self.TIMEOUT) 1710 1711 data = sock.recv_all(len(HELLO_MSG)) 1712 self.assertEqual(len(data), len(HELLO_MSG)) 1713 1714 sock.starttls(server_context, server_side=True) 1715 1716 sock.sendall(b'O') 1717 data = sock.recv_all(len(HELLO_MSG)) 1718 self.assertEqual(len(data), len(HELLO_MSG)) 1719 1720 sock.unwrap() 1721 sock.close() 1722 1723 class ClientProto(asyncio.Protocol): 1724 def __init__(self, on_data, on_eof): 1725 self.on_data = on_data 1726 self.on_eof = on_eof 1727 self.con_made_cnt = 0 1728 1729 def connection_made(proto, tr): 1730 proto.con_made_cnt += 1 1731 # Ensure connection_made gets called only once. 1732 self.assertEqual(proto.con_made_cnt, 1) 1733 1734 def data_received(self, data): 1735 self.on_data.set_result(data) 1736 1737 def eof_received(self): 1738 self.on_eof.set_result(True) 1739 1740 async def client(addr): 1741 await asyncio.sleep(0.5) 1742 1743 on_data = self.loop.create_future() 1744 on_eof = self.loop.create_future() 1745 1746 tr, proto = await self.loop.create_connection( 1747 lambda: ClientProto(on_data, on_eof), *addr) 1748 1749 tr.write(HELLO_MSG) 1750 new_tr = await self.loop.start_tls(tr, proto, client_context) 1751 1752 self.assertEqual(await on_data, b'O') 1753 new_tr.write(HELLO_MSG) 1754 await on_eof 1755 1756 new_tr.close() 1757 1758 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: 1759 self.loop.run_until_complete( 1760 asyncio.wait_for(client(srv.addr), timeout=10)) 1761 1762 def test_create_connection_memory_leak(self): 1763 if self.implementation == 'asyncio': 1764 raise unittest.SkipTest() 1765 1766 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 1767 1768 server_context = self._create_server_ssl_context( 1769 self.ONLYCERT, self.ONLYKEY) 1770 client_context = self._create_client_ssl_context() 1771 1772 def serve(sock): 1773 sock.settimeout(self.TIMEOUT) 1774 1775 sock.starttls(server_context, server_side=True) 1776 1777 sock.sendall(b'O') 1778 data = sock.recv_all(len(HELLO_MSG)) 1779 self.assertEqual(len(data), len(HELLO_MSG)) 1780 1781 sock.unwrap() 1782 sock.close() 1783 1784 class ClientProto(asyncio.Protocol): 1785 def __init__(self, on_data, on_eof): 1786 self.on_data = on_data 1787 self.on_eof = on_eof 1788 self.con_made_cnt = 0 1789 1790 def connection_made(proto, tr): 1791 # XXX: We assume user stores the transport in protocol 1792 proto.tr = tr 1793 proto.con_made_cnt += 1 1794 # Ensure connection_made gets called only once. 1795 self.assertEqual(proto.con_made_cnt, 1) 1796 1797 def data_received(self, data): 1798 self.on_data.set_result(data) 1799 1800 def eof_received(self): 1801 self.on_eof.set_result(True) 1802 1803 async def client(addr): 1804 await asyncio.sleep(0.5) 1805 1806 on_data = self.loop.create_future() 1807 on_eof = self.loop.create_future() 1808 1809 tr, proto = await self.loop.create_connection( 1810 lambda: ClientProto(on_data, on_eof), *addr, 1811 ssl=client_context) 1812 1813 self.assertEqual(await on_data, b'O') 1814 tr.write(HELLO_MSG) 1815 await on_eof 1816 1817 tr.close() 1818 1819 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: 1820 self.loop.run_until_complete( 1821 asyncio.wait_for(client(srv.addr), timeout=10)) 1822 1823 # No garbage is left for SSL client from loop.create_connection, even 1824 # if user stores the SSLTransport in corresponding protocol instance 1825 client_context = weakref.ref(client_context) 1826 self.assertIsNone(client_context()) 1827 1828 def test_start_tls_client_buf_proto_1(self): 1829 if self.implementation == 'asyncio': 1830 raise unittest.SkipTest() 1831 1832 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 1833 1834 server_context = self._create_server_ssl_context( 1835 self.ONLYCERT, self.ONLYKEY) 1836 client_context = self._create_client_ssl_context() 1837 1838 client_con_made_calls = 0 1839 1840 def serve(sock): 1841 sock.settimeout(self.TIMEOUT) 1842 1843 data = sock.recv_all(len(HELLO_MSG)) 1844 self.assertEqual(len(data), len(HELLO_MSG)) 1845 1846 sock.starttls(server_context, server_side=True) 1847 1848 sock.sendall(b'O') 1849 data = sock.recv_all(len(HELLO_MSG)) 1850 self.assertEqual(len(data), len(HELLO_MSG)) 1851 1852 sock.sendall(b'2') 1853 data = sock.recv_all(len(HELLO_MSG)) 1854 self.assertEqual(len(data), len(HELLO_MSG)) 1855 1856 sock.unwrap() 1857 sock.close() 1858 1859 class ClientProtoFirst(asyncio.BaseProtocol): 1860 def __init__(self, on_data): 1861 self.on_data = on_data 1862 self.buf = bytearray(1) 1863 1864 def connection_made(self, tr): 1865 nonlocal client_con_made_calls 1866 client_con_made_calls += 1 1867 1868 def get_buffer(self, sizehint): 1869 return self.buf 1870 1871 def buffer_updated(self, nsize): 1872 assert nsize == 1 1873 self.on_data.set_result(bytes(self.buf[:nsize])) 1874 1875 def eof_received(self): 1876 pass 1877 1878 class ClientProtoSecond(asyncio.Protocol): 1879 def __init__(self, on_data, on_eof): 1880 self.on_data = on_data 1881 self.on_eof = on_eof 1882 self.con_made_cnt = 0 1883 1884 def connection_made(self, tr): 1885 nonlocal client_con_made_calls 1886 client_con_made_calls += 1 1887 1888 def data_received(self, data): 1889 self.on_data.set_result(data) 1890 1891 def eof_received(self): 1892 self.on_eof.set_result(True) 1893 1894 async def client(addr): 1895 await asyncio.sleep(0.5) 1896 1897 on_data1 = self.loop.create_future() 1898 on_data2 = self.loop.create_future() 1899 on_eof = self.loop.create_future() 1900 1901 tr, proto = await self.loop.create_connection( 1902 lambda: ClientProtoFirst(on_data1), *addr) 1903 1904 tr.write(HELLO_MSG) 1905 new_tr = await self.loop.start_tls(tr, proto, client_context) 1906 1907 self.assertEqual(await on_data1, b'O') 1908 new_tr.write(HELLO_MSG) 1909 1910 new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) 1911 self.assertEqual(await on_data2, b'2') 1912 new_tr.write(HELLO_MSG) 1913 await on_eof 1914 1915 new_tr.close() 1916 1917 # connection_made() should be called only once -- when 1918 # we establish connection for the first time. Start TLS 1919 # doesn't call connection_made() on application protocols. 1920 self.assertEqual(client_con_made_calls, 1) 1921 1922 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: 1923 self.loop.run_until_complete( 1924 asyncio.wait_for(client(srv.addr), 1925 timeout=self.TIMEOUT)) 1926 1927 def test_start_tls_slow_client_cancel(self): 1928 if self.implementation == 'asyncio': 1929 raise unittest.SkipTest() 1930 1931 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 1932 1933 client_context = self._create_client_ssl_context() 1934 server_waits_on_handshake = self.loop.create_future() 1935 1936 def serve(sock): 1937 sock.settimeout(self.TIMEOUT) 1938 1939 data = sock.recv_all(len(HELLO_MSG)) 1940 self.assertEqual(len(data), len(HELLO_MSG)) 1941 1942 try: 1943 self.loop.call_soon_threadsafe( 1944 server_waits_on_handshake.set_result, None) 1945 data = sock.recv_all(1024 * 1024) 1946 except ConnectionAbortedError: 1947 pass 1948 finally: 1949 sock.close() 1950 1951 class ClientProto(asyncio.Protocol): 1952 def __init__(self, on_data, on_eof): 1953 self.on_data = on_data 1954 self.on_eof = on_eof 1955 self.con_made_cnt = 0 1956 1957 def connection_made(proto, tr): 1958 proto.con_made_cnt += 1 1959 # Ensure connection_made gets called only once. 1960 self.assertEqual(proto.con_made_cnt, 1) 1961 1962 def data_received(self, data): 1963 self.on_data.set_result(data) 1964 1965 def eof_received(self): 1966 self.on_eof.set_result(True) 1967 1968 async def client(addr): 1969 await asyncio.sleep(0.5) 1970 1971 on_data = self.loop.create_future() 1972 on_eof = self.loop.create_future() 1973 1974 tr, proto = await self.loop.create_connection( 1975 lambda: ClientProto(on_data, on_eof), *addr) 1976 1977 tr.write(HELLO_MSG) 1978 1979 await server_waits_on_handshake 1980 1981 with self.assertRaises(asyncio.TimeoutError): 1982 await asyncio.wait_for( 1983 self.loop.start_tls(tr, proto, client_context), 1984 0.5) 1985 1986 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: 1987 self.loop.run_until_complete( 1988 asyncio.wait_for(client(srv.addr), timeout=10)) 1989 1990 def test_start_tls_server_1(self): 1991 if self.implementation == 'asyncio': 1992 raise unittest.SkipTest() 1993 1994 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 1995 1996 server_context = self._create_server_ssl_context( 1997 self.ONLYCERT, self.ONLYKEY) 1998 client_context = self._create_client_ssl_context() 1999 2000 def client(sock, addr): 2001 sock.settimeout(self.TIMEOUT) 2002 2003 sock.connect(addr) 2004 data = sock.recv_all(len(HELLO_MSG)) 2005 self.assertEqual(len(data), len(HELLO_MSG)) 2006 2007 sock.starttls(client_context) 2008 sock.sendall(HELLO_MSG) 2009 2010 sock.unwrap() 2011 sock.close() 2012 2013 class ServerProto(asyncio.Protocol): 2014 def __init__(self, on_con, on_eof, on_con_lost): 2015 self.on_con = on_con 2016 self.on_eof = on_eof 2017 self.on_con_lost = on_con_lost 2018 self.data = b'' 2019 2020 def connection_made(self, tr): 2021 self.on_con.set_result(tr) 2022 2023 def data_received(self, data): 2024 self.data += data 2025 2026 def eof_received(self): 2027 self.on_eof.set_result(1) 2028 2029 def connection_lost(self, exc): 2030 if exc is None: 2031 self.on_con_lost.set_result(None) 2032 else: 2033 self.on_con_lost.set_exception(exc) 2034 2035 async def main(proto, on_con, on_eof, on_con_lost): 2036 tr = await on_con 2037 tr.write(HELLO_MSG) 2038 2039 self.assertEqual(proto.data, b'') 2040 2041 new_tr = await self.loop.start_tls( 2042 tr, proto, server_context, 2043 server_side=True, 2044 ssl_handshake_timeout=self.TIMEOUT) 2045 2046 await on_eof 2047 await on_con_lost 2048 self.assertEqual(proto.data, HELLO_MSG) 2049 new_tr.close() 2050 2051 async def run_main(): 2052 on_con = self.loop.create_future() 2053 on_eof = self.loop.create_future() 2054 on_con_lost = self.loop.create_future() 2055 proto = ServerProto(on_con, on_eof, on_con_lost) 2056 2057 server = await self.loop.create_server( 2058 lambda: proto, '127.0.0.1', 0) 2059 addr = server.sockets[0].getsockname() 2060 2061 with self.tcp_client(lambda sock: client(sock, addr), 2062 timeout=self.TIMEOUT): 2063 await asyncio.wait_for( 2064 main(proto, on_con, on_eof, on_con_lost), 2065 timeout=self.TIMEOUT) 2066 2067 server.close() 2068 await server.wait_closed() 2069 2070 self.loop.run_until_complete(run_main()) 2071 2072 def test_create_server_ssl_over_ssl(self): 2073 if self.implementation == 'asyncio': 2074 raise unittest.SkipTest('asyncio does not support SSL over SSL') 2075 2076 CNT = 0 # number of clients that were successful 2077 TOTAL_CNT = 25 # total number of clients that test will create 2078 TIMEOUT = 20.0 # timeout for this test 2079 2080 A_DATA = b'A' * 1024 * 1024 2081 B_DATA = b'B' * 1024 * 1024 2082 2083 sslctx_1 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 2084 client_sslctx_1 = self._create_client_ssl_context() 2085 sslctx_2 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 2086 client_sslctx_2 = self._create_client_ssl_context() 2087 2088 clients = [] 2089 2090 async def handle_client(reader, writer): 2091 nonlocal CNT 2092 2093 data = await reader.readexactly(len(A_DATA)) 2094 self.assertEqual(data, A_DATA) 2095 writer.write(b'OK') 2096 2097 data = await reader.readexactly(len(B_DATA)) 2098 self.assertEqual(data, B_DATA) 2099 writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) 2100 2101 await writer.drain() 2102 writer.close() 2103 2104 CNT += 1 2105 2106 class ServerProtocol(asyncio.StreamReaderProtocol): 2107 def connection_made(self, transport): 2108 super_ = super() 2109 transport.pause_reading() 2110 fut = self._loop.create_task(self._loop.start_tls( 2111 transport, self, sslctx_2, server_side=True)) 2112 2113 def cb(_): 2114 try: 2115 tr = fut.result() 2116 except Exception as ex: 2117 super_.connection_lost(ex) 2118 else: 2119 super_.connection_made(tr) 2120 fut.add_done_callback(cb) 2121 2122 def server_protocol_factory(): 2123 reader = asyncio.StreamReader() 2124 protocol = ServerProtocol(reader, handle_client) 2125 return protocol 2126 2127 async def test_client(addr): 2128 fut = asyncio.Future() 2129 2130 def prog(sock): 2131 try: 2132 sock.connect(addr) 2133 sock.starttls(client_sslctx_1) 2134 2135 # because wrap_socket() doesn't work correctly on 2136 # SSLSocket, we have to do the 2nd level SSL manually 2137 incoming = ssl.MemoryBIO() 2138 outgoing = ssl.MemoryBIO() 2139 sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) 2140 2141 def do(func, *args): 2142 while True: 2143 try: 2144 rv = func(*args) 2145 break 2146 except ssl.SSLWantReadError: 2147 if outgoing.pending: 2148 sock.send(outgoing.read()) 2149 incoming.write(sock.recv(65536)) 2150 if outgoing.pending: 2151 sock.send(outgoing.read()) 2152 return rv 2153 2154 do(sslobj.do_handshake) 2155 2156 do(sslobj.write, A_DATA) 2157 data = do(sslobj.read, 2) 2158 self.assertEqual(data, b'OK') 2159 2160 do(sslobj.write, B_DATA) 2161 data = b'' 2162 while True: 2163 chunk = do(sslobj.read, 4) 2164 if not chunk: 2165 break 2166 data += chunk 2167 self.assertEqual(data, b'SPAM') 2168 2169 do(sslobj.unwrap) 2170 sock.close() 2171 2172 except Exception as ex: 2173 self.loop.call_soon_threadsafe(fut.set_exception, ex) 2174 sock.close() 2175 else: 2176 self.loop.call_soon_threadsafe(fut.set_result, None) 2177 2178 client = self.tcp_client(prog) 2179 client.start() 2180 clients.append(client) 2181 2182 await fut 2183 2184 async def start_server(): 2185 extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT) 2186 2187 srv = await self.loop.create_server( 2188 server_protocol_factory, 2189 '127.0.0.1', 0, 2190 family=socket.AF_INET, 2191 ssl=sslctx_1, 2192 **extras) 2193 2194 try: 2195 srv_socks = srv.sockets 2196 self.assertTrue(srv_socks) 2197 2198 addr = srv_socks[0].getsockname() 2199 2200 tasks = [] 2201 for _ in range(TOTAL_CNT): 2202 tasks.append(test_client(addr)) 2203 2204 await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) 2205 2206 finally: 2207 self.loop.call_soon(srv.close) 2208 await srv.wait_closed() 2209 2210 with self._silence_eof_received_warning(): 2211 self.loop.run_until_complete(start_server()) 2212 2213 self.assertEqual(CNT, TOTAL_CNT) 2214 2215 for client in clients: 2216 client.stop() 2217 2218 def test_renegotiation(self): 2219 if self.implementation == 'asyncio': 2220 raise unittest.SkipTest('asyncio does not support renegotiation') 2221 2222 CNT = 0 2223 TOTAL_CNT = 25 2224 2225 A_DATA = b'A' * 1024 * 1024 2226 B_DATA = b'B' * 1024 * 1024 2227 2228 sslctx = openssl_ssl.Context(openssl_ssl.TLSv1_2_METHOD) 2229 if hasattr(openssl_ssl, 'OP_NO_SSLV2'): 2230 sslctx.set_options(openssl_ssl.OP_NO_SSLV2) 2231 sslctx.use_privatekey_file(self.ONLYKEY) 2232 sslctx.use_certificate_chain_file(self.ONLYCERT) 2233 client_sslctx = self._create_client_ssl_context() 2234 if hasattr(ssl, 'OP_NO_TLSv1_3'): 2235 client_sslctx.options |= ssl.OP_NO_TLSv1_3 2236 2237 def server(sock): 2238 conn = openssl_ssl.Connection(sslctx, sock) 2239 conn.set_accept_state() 2240 2241 data = b'' 2242 while len(data) < len(A_DATA): 2243 try: 2244 chunk = conn.recv(len(A_DATA) - len(data)) 2245 if not chunk: 2246 break 2247 data += chunk 2248 except openssl_ssl.WantReadError: 2249 pass 2250 self.assertEqual(data, A_DATA) 2251 conn.renegotiate() 2252 if conn.renegotiate_pending(): 2253 conn.send(b'OK') 2254 else: 2255 conn.send(b'ER') 2256 2257 data = b'' 2258 while len(data) < len(B_DATA): 2259 try: 2260 chunk = conn.recv(len(B_DATA) - len(data)) 2261 if not chunk: 2262 break 2263 data += chunk 2264 except openssl_ssl.WantReadError: 2265 pass 2266 self.assertEqual(data, B_DATA) 2267 if conn.renegotiate_pending(): 2268 conn.send(b'ERRO') 2269 else: 2270 conn.send(b'SPAM') 2271 2272 conn.shutdown() 2273 2274 async def client(addr): 2275 extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT) 2276 2277 reader, writer = await asyncio.open_connection( 2278 *addr, 2279 ssl=client_sslctx, 2280 server_hostname='', 2281 **extras) 2282 2283 writer.write(A_DATA) 2284 self.assertEqual(await reader.readexactly(2), b'OK') 2285 2286 writer.write(B_DATA) 2287 self.assertEqual(await reader.readexactly(4), b'SPAM') 2288 2289 nonlocal CNT 2290 CNT += 1 2291 2292 writer.close() 2293 await self.wait_closed(writer) 2294 2295 async def client_sock(addr): 2296 sock = socket.socket() 2297 sock.connect(addr) 2298 reader, writer = await asyncio.open_connection( 2299 sock=sock, 2300 ssl=client_sslctx, 2301 server_hostname='') 2302 2303 writer.write(A_DATA) 2304 self.assertEqual(await reader.readexactly(2), b'OK') 2305 2306 writer.write(B_DATA) 2307 self.assertEqual(await reader.readexactly(4), b'SPAM') 2308 2309 nonlocal CNT 2310 CNT += 1 2311 2312 writer.close() 2313 await self.wait_closed(writer) 2314 sock.close() 2315 2316 def run(coro): 2317 nonlocal CNT 2318 CNT = 0 2319 2320 with self.tcp_server(server, 2321 max_clients=TOTAL_CNT, 2322 backlog=TOTAL_CNT) as srv: 2323 tasks = [] 2324 for _ in range(TOTAL_CNT): 2325 tasks.append(coro(srv.addr)) 2326 2327 self.loop.run_until_complete( 2328 asyncio.gather(*tasks)) 2329 2330 self.assertEqual(CNT, TOTAL_CNT) 2331 2332 with self._silence_eof_received_warning(): 2333 run(client) 2334 2335 with self._silence_eof_received_warning(): 2336 run(client_sock) 2337 2338 def test_shutdown_timeout(self): 2339 if self.implementation == 'asyncio': 2340 raise unittest.SkipTest() 2341 2342 CNT = 0 # number of clients that were successful 2343 TOTAL_CNT = 25 # total number of clients that test will create 2344 TIMEOUT = 10.0 # timeout for this test 2345 2346 A_DATA = b'A' * 1024 * 1024 2347 2348 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 2349 client_sslctx = self._create_client_ssl_context() 2350 2351 clients = [] 2352 2353 async def handle_client(reader, writer): 2354 nonlocal CNT 2355 2356 data = await reader.readexactly(len(A_DATA)) 2357 self.assertEqual(data, A_DATA) 2358 writer.write(b'OK') 2359 await writer.drain() 2360 writer.close() 2361 with self.assertRaisesRegex(asyncio.TimeoutError, 2362 'SSL shutdown timed out'): 2363 await reader.read() 2364 CNT += 1 2365 2366 async def test_client(addr): 2367 fut = asyncio.Future() 2368 2369 def prog(sock): 2370 try: 2371 sock.starttls(client_sslctx) 2372 sock.connect(addr) 2373 sock.send(A_DATA) 2374 2375 data = sock.recv_all(2) 2376 self.assertEqual(data, b'OK') 2377 2378 data = sock.recv(1024) 2379 self.assertEqual(data, b'') 2380 2381 fd = sock.detach() 2382 try: 2383 select.select([fd], [], [], 3) 2384 finally: 2385 os.close(fd) 2386 2387 except Exception as ex: 2388 self.loop.call_soon_threadsafe(fut.set_exception, ex) 2389 else: 2390 self.loop.call_soon_threadsafe(fut.set_result, None) 2391 2392 client = self.tcp_client(prog) 2393 client.start() 2394 clients.append(client) 2395 2396 await fut 2397 2398 async def start_server(): 2399 extras = {'ssl_handshake_timeout': SSL_HANDSHAKE_TIMEOUT} 2400 if self.implementation != 'asyncio': # or self.PY38 2401 extras['ssl_shutdown_timeout'] = 0.5 2402 2403 srv = await asyncio.start_server( 2404 handle_client, 2405 '127.0.0.1', 0, 2406 family=socket.AF_INET, 2407 ssl=sslctx, 2408 **extras) 2409 2410 try: 2411 srv_socks = srv.sockets 2412 self.assertTrue(srv_socks) 2413 2414 addr = srv_socks[0].getsockname() 2415 2416 tasks = [] 2417 for _ in range(TOTAL_CNT): 2418 tasks.append(test_client(addr)) 2419 2420 await asyncio.wait_for( 2421 asyncio.gather(*tasks), 2422 TIMEOUT) 2423 2424 finally: 2425 self.loop.call_soon(srv.close) 2426 await srv.wait_closed() 2427 2428 with self._silence_eof_received_warning(): 2429 self.loop.run_until_complete(start_server()) 2430 2431 self.assertEqual(CNT, TOTAL_CNT) 2432 2433 for client in clients: 2434 client.stop() 2435 2436 def test_shutdown_cleanly(self): 2437 if self.implementation == 'asyncio': 2438 raise unittest.SkipTest() 2439 2440 CNT = 0 2441 TOTAL_CNT = 25 2442 2443 A_DATA = b'A' * 1024 * 1024 2444 2445 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 2446 client_sslctx = self._create_client_ssl_context() 2447 2448 def server(sock): 2449 sock.starttls( 2450 sslctx, 2451 server_side=True) 2452 2453 data = sock.recv_all(len(A_DATA)) 2454 self.assertEqual(data, A_DATA) 2455 sock.send(b'OK') 2456 2457 sock.unwrap() 2458 2459 sock.close() 2460 2461 async def client(addr): 2462 extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT) 2463 2464 reader, writer = await asyncio.open_connection( 2465 *addr, 2466 ssl=client_sslctx, 2467 server_hostname='', 2468 **extras) 2469 2470 writer.write(A_DATA) 2471 self.assertEqual(await reader.readexactly(2), b'OK') 2472 2473 self.assertEqual(await reader.read(), b'') 2474 2475 nonlocal CNT 2476 CNT += 1 2477 2478 writer.close() 2479 await self.wait_closed(writer) 2480 2481 def run(coro): 2482 nonlocal CNT 2483 CNT = 0 2484 2485 with self.tcp_server(server, 2486 max_clients=TOTAL_CNT, 2487 backlog=TOTAL_CNT) as srv: 2488 tasks = [] 2489 for _ in range(TOTAL_CNT): 2490 tasks.append(coro(srv.addr)) 2491 2492 self.loop.run_until_complete( 2493 asyncio.gather(*tasks)) 2494 2495 self.assertEqual(CNT, TOTAL_CNT) 2496 2497 with self._silence_eof_received_warning(): 2498 run(client) 2499 2500 def test_write_to_closed_transport(self): 2501 if self.implementation == 'asyncio': 2502 raise unittest.SkipTest() 2503 2504 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 2505 client_sslctx = self._create_client_ssl_context() 2506 future = None 2507 2508 def server(sock): 2509 sock.starttls(sslctx, server_side=True) 2510 sock.shutdown(socket.SHUT_RDWR) 2511 sock.close() 2512 2513 def unwrap_server(sock): 2514 sock.starttls(sslctx, server_side=True) 2515 while True: 2516 try: 2517 sock.unwrap() 2518 break 2519 except ssl.SSLError as ex: 2520 # Since OpenSSL 1.1.1, it raises "application data after 2521 # close notify" 2522 # Python < 3.8: 2523 if ex.reason == 'KRB5_S_INIT': 2524 break 2525 # Python >= 3.8: 2526 if ex.reason == 'APPLICATION_DATA_AFTER_CLOSE_NOTIFY': 2527 break 2528 raise ex 2529 except OSError as ex: 2530 # OpenSSL < 1.1.1 2531 if ex.errno != 0: 2532 raise 2533 sock.close() 2534 2535 async def client(addr): 2536 nonlocal future 2537 future = self.loop.create_future() 2538 2539 reader, writer = await asyncio.open_connection( 2540 *addr, 2541 ssl=client_sslctx, 2542 server_hostname='') 2543 writer.write(b'I AM WRITING NOWHERE1' * 100) 2544 2545 try: 2546 data = await reader.read() 2547 self.assertEqual(data, b'') 2548 except (ConnectionResetError, BrokenPipeError): 2549 pass 2550 2551 for i in range(25): 2552 writer.write(b'I AM WRITING NOWHERE2' * 100) 2553 2554 self.assertEqual( 2555 writer.transport.get_write_buffer_size(), 0) 2556 2557 await future 2558 2559 writer.close() 2560 await self.wait_closed(writer) 2561 2562 def run(meth): 2563 def wrapper(sock): 2564 try: 2565 meth(sock) 2566 except Exception as ex: 2567 self.loop.call_soon_threadsafe(future.set_exception, ex) 2568 else: 2569 self.loop.call_soon_threadsafe(future.set_result, None) 2570 return wrapper 2571 2572 with self._silence_eof_received_warning(): 2573 with self.tcp_server(run(server)) as srv: 2574 self.loop.run_until_complete(client(srv.addr)) 2575 2576 with self.tcp_server(run(unwrap_server)) as srv: 2577 self.loop.run_until_complete(client(srv.addr)) 2578 2579 def test_flush_before_shutdown(self): 2580 if self.implementation == 'asyncio': 2581 raise unittest.SkipTest() 2582 2583 CHUNK = 1024 * 128 2584 SIZE = 32 2585 2586 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 2587 sslctx_openssl = openssl_ssl.Context(openssl_ssl.TLSv1_2_METHOD) 2588 if hasattr(openssl_ssl, 'OP_NO_SSLV2'): 2589 sslctx_openssl.set_options(openssl_ssl.OP_NO_SSLV2) 2590 sslctx_openssl.use_privatekey_file(self.ONLYKEY) 2591 sslctx_openssl.use_certificate_chain_file(self.ONLYCERT) 2592 client_sslctx = self._create_client_ssl_context() 2593 if hasattr(ssl, 'OP_NO_TLSv1_3'): 2594 client_sslctx.options |= ssl.OP_NO_TLSv1_3 2595 2596 future = None 2597 2598 def server(sock): 2599 sock.starttls(sslctx, server_side=True) 2600 self.assertEqual(sock.recv_all(4), b'ping') 2601 sock.send(b'pong') 2602 time.sleep(0.5) # hopefully stuck the TCP buffer 2603 data = sock.recv_all(CHUNK * SIZE) 2604 self.assertEqual(len(data), CHUNK * SIZE) 2605 sock.close() 2606 2607 def run(meth): 2608 def wrapper(sock): 2609 try: 2610 meth(sock) 2611 except Exception as ex: 2612 self.loop.call_soon_threadsafe(future.set_exception, ex) 2613 else: 2614 self.loop.call_soon_threadsafe(future.set_result, None) 2615 return wrapper 2616 2617 async def client(addr): 2618 nonlocal future 2619 future = self.loop.create_future() 2620 reader, writer = await asyncio.open_connection( 2621 *addr, 2622 ssl=client_sslctx, 2623 server_hostname='') 2624 sslprotocol = writer.get_extra_info('uvloop.sslproto') 2625 writer.write(b'ping') 2626 data = await reader.readexactly(4) 2627 self.assertEqual(data, b'pong') 2628 2629 sslprotocol.pause_writing() 2630 for _ in range(SIZE): 2631 writer.write(b'x' * CHUNK) 2632 2633 writer.close() 2634 sslprotocol.resume_writing() 2635 2636 await self.wait_closed(writer) 2637 try: 2638 data = await reader.read() 2639 self.assertEqual(data, b'') 2640 except ConnectionResetError: 2641 pass 2642 await future 2643 2644 with self.tcp_server(run(server)) as srv: 2645 self.loop.run_until_complete(client(srv.addr)) 2646 2647 def test_remote_shutdown_receives_trailing_data(self): 2648 CHUNK = 1024 * 16 2649 SIZE = 8 2650 count = 0 2651 2652 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 2653 client_sslctx = self._create_client_ssl_context() 2654 future = None 2655 filled = threading.Lock() 2656 eof_received = threading.Lock() 2657 2658 def server(sock): 2659 incoming = ssl.MemoryBIO() 2660 outgoing = ssl.MemoryBIO() 2661 sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) 2662 2663 while True: 2664 try: 2665 sslobj.do_handshake() 2666 except ssl.SSLWantReadError: 2667 if outgoing.pending: 2668 sock.send(outgoing.read()) 2669 incoming.write(sock.recv(16384)) 2670 else: 2671 if outgoing.pending: 2672 sock.send(outgoing.read()) 2673 break 2674 2675 while True: 2676 try: 2677 data = sslobj.read(4) 2678 except ssl.SSLWantReadError: 2679 incoming.write(sock.recv(16384)) 2680 else: 2681 break 2682 2683 self.assertEqual(data, b'ping') 2684 sslobj.write(b'pong') 2685 sock.send(outgoing.read()) 2686 2687 data_len = 0 2688 2689 with filled: 2690 # trigger peer's resume_writing() 2691 incoming.write(sock.recv(65536 * 4)) 2692 while True: 2693 try: 2694 chunk = len(sslobj.read(16384)) 2695 data_len += chunk 2696 except ssl.SSLWantReadError: 2697 break 2698 2699 # send close_notify but don't wait for response 2700 with self.assertRaises(ssl.SSLWantReadError): 2701 sslobj.unwrap() 2702 sock.send(outgoing.read()) 2703 2704 with eof_received: 2705 # should receive all data 2706 while True: 2707 try: 2708 chunk = len(sslobj.read(16384)) 2709 data_len += chunk 2710 except ssl.SSLWantReadError: 2711 incoming.write(sock.recv(16384)) 2712 if not incoming.pending: 2713 # EOF received 2714 break 2715 except ssl.SSLZeroReturnError: 2716 break 2717 2718 self.assertEqual(data_len, CHUNK * count) 2719 2720 if self.implementation == 'uvloop': 2721 # Verify that close_notify is received. asyncio is currently 2722 # not guaranteed to send close_notify before dropping off 2723 sslobj.unwrap() 2724 2725 sock.close() 2726 2727 async def client(addr): 2728 nonlocal future, count 2729 future = self.loop.create_future() 2730 2731 with eof_received: 2732 with filled: 2733 reader, writer = await asyncio.open_connection( 2734 *addr, 2735 ssl=client_sslctx, 2736 server_hostname='') 2737 writer.write(b'ping') 2738 data = await reader.readexactly(4) 2739 self.assertEqual(data, b'pong') 2740 2741 count = 0 2742 try: 2743 while True: 2744 writer.write(b'x' * CHUNK) 2745 count += 1 2746 await asyncio.wait_for( 2747 asyncio.ensure_future(writer.drain()), 0.5) 2748 except asyncio.TimeoutError: 2749 # fill write backlog in a hacky way for uvloop 2750 if self.implementation == 'uvloop': 2751 for _ in range(SIZE): 2752 writer.transport._test__append_write_backlog( 2753 b'x' * CHUNK) 2754 count += 1 2755 2756 data = await reader.read() 2757 self.assertEqual(data, b'') 2758 2759 await future 2760 2761 writer.close() 2762 await self.wait_closed(writer) 2763 2764 def run(meth): 2765 def wrapper(sock): 2766 try: 2767 meth(sock) 2768 except Exception as ex: 2769 self.loop.call_soon_threadsafe(future.set_exception, ex) 2770 else: 2771 self.loop.call_soon_threadsafe(future.set_result, None) 2772 return wrapper 2773 2774 with self.tcp_server(run(server)) as srv: 2775 self.loop.run_until_complete(client(srv.addr)) 2776 2777 def test_connect_timeout_warning(self): 2778 s = socket.socket(socket.AF_INET) 2779 s.bind(('127.0.0.1', 0)) 2780 addr = s.getsockname() 2781 2782 async def test(): 2783 try: 2784 await asyncio.wait_for( 2785 self.loop.create_connection(asyncio.Protocol, 2786 *addr, ssl=True), 2787 0.1) 2788 except (ConnectionRefusedError, asyncio.TimeoutError): 2789 pass 2790 else: 2791 self.fail('TimeoutError is not raised') 2792 2793 with s: 2794 try: 2795 with self.assertWarns(ResourceWarning) as cm: 2796 self.loop.run_until_complete(test()) 2797 gc.collect() 2798 gc.collect() 2799 gc.collect() 2800 except AssertionError as e: 2801 self.assertEqual(str(e), 'ResourceWarning not triggered') 2802 else: 2803 self.fail('Unexpected ResourceWarning: {}'.format(cm.warning)) 2804 2805 def test_handshake_timeout_handler_leak(self): 2806 if self.implementation == 'asyncio': 2807 # Okay this turns out to be an issue for asyncio.sslproto too 2808 raise unittest.SkipTest() 2809 2810 s = socket.socket(socket.AF_INET) 2811 s.bind(('127.0.0.1', 0)) 2812 s.listen(1) 2813 addr = s.getsockname() 2814 2815 async def test(ctx): 2816 try: 2817 await asyncio.wait_for( 2818 self.loop.create_connection(asyncio.Protocol, *addr, 2819 ssl=ctx), 2820 0.1) 2821 except (ConnectionRefusedError, asyncio.TimeoutError): 2822 pass 2823 else: 2824 self.fail('TimeoutError is not raised') 2825 2826 with s: 2827 ctx = ssl.create_default_context() 2828 self.loop.run_until_complete(test(ctx)) 2829 ctx = weakref.ref(ctx) 2830 2831 # SSLProtocol should be DECREF to 0 2832 self.assertIsNone(ctx()) 2833 2834 def test_shutdown_timeout_handler_leak(self): 2835 loop = self.loop 2836 2837 def server(sock): 2838 sslctx = self._create_server_ssl_context(self.ONLYCERT, 2839 self.ONLYKEY) 2840 sock = sslctx.wrap_socket(sock, server_side=True) 2841 sock.recv(32) 2842 sock.close() 2843 2844 class Protocol(asyncio.Protocol): 2845 def __init__(self): 2846 self.fut = asyncio.Future(loop=loop) 2847 2848 def connection_lost(self, exc): 2849 self.fut.set_result(None) 2850 2851 async def client(addr, ctx): 2852 tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) 2853 tr.close() 2854 await pr.fut 2855 2856 with self.tcp_server(server) as srv: 2857 ctx = self._create_client_ssl_context() 2858 loop.run_until_complete(client(srv.addr, ctx)) 2859 ctx = weakref.ref(ctx) 2860 2861 if self.implementation == 'asyncio': 2862 # asyncio has no shutdown timeout, but it ends up with a circular 2863 # reference loop - not ideal (introduces gc glitches), but at least 2864 # not leaking 2865 gc.collect() 2866 gc.collect() 2867 gc.collect() 2868 2869 # SSLProtocol should be DECREF to 0 2870 self.assertIsNone(ctx()) 2871 2872 def test_shutdown_timeout_handler_not_set(self): 2873 if self.implementation == 'asyncio': 2874 # asyncio doesn't call SSL eof_received() so we can't run this test 2875 raise unittest.SkipTest() 2876 2877 loop = self.loop 2878 extra = None 2879 2880 def server(sock): 2881 sslctx = self._create_server_ssl_context(self.ONLYCERT, 2882 self.ONLYKEY) 2883 sock = sslctx.wrap_socket(sock, server_side=True) 2884 sock.send(b'hello') 2885 assert sock.recv(1024) == b'world' 2886 sock.send(b'extra bytes') 2887 # sending EOF here 2888 sock.shutdown(socket.SHUT_WR) 2889 # make sure we have enough time to reproduce the issue 2890 self.assertEqual(sock.recv(1024), b'') 2891 sock.close() 2892 2893 class Protocol(asyncio.Protocol): 2894 def __init__(self): 2895 self.fut = asyncio.Future(loop=loop) 2896 self.transport = None 2897 2898 def connection_made(self, transport): 2899 self.transport = transport 2900 2901 def data_received(self, data): 2902 if data == b'hello': 2903 self.transport.write(b'world') 2904 # pause reading would make incoming data stay in the sslobj 2905 self.transport.pause_reading() 2906 else: 2907 nonlocal extra 2908 extra = data 2909 2910 def connection_lost(self, exc): 2911 if exc is None: 2912 self.fut.set_result(None) 2913 else: 2914 self.fut.set_exception(exc) 2915 2916 def eof_received(self): 2917 self.transport.resume_reading() 2918 2919 async def client(addr): 2920 ctx = self._create_client_ssl_context() 2921 tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) 2922 await pr.fut 2923 tr.close() 2924 # extra data received after transport.close() should be ignored 2925 self.assertIsNone(extra) 2926 2927 with self.tcp_server(server) as srv: 2928 loop.run_until_complete(client(srv.addr)) 2929 2930 def test_shutdown_while_pause_reading(self): 2931 if self.implementation == 'asyncio': 2932 raise unittest.SkipTest() 2933 2934 loop = self.loop 2935 conn_made = loop.create_future() 2936 eof_recvd = loop.create_future() 2937 conn_lost = loop.create_future() 2938 data_recv = False 2939 2940 def server(sock): 2941 sslctx = self._create_server_ssl_context(self.ONLYCERT, 2942 self.ONLYKEY) 2943 incoming = ssl.MemoryBIO() 2944 outgoing = ssl.MemoryBIO() 2945 sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) 2946 2947 while True: 2948 try: 2949 sslobj.do_handshake() 2950 sslobj.write(b'trailing data') 2951 break 2952 except ssl.SSLWantReadError: 2953 if outgoing.pending: 2954 sock.send(outgoing.read()) 2955 incoming.write(sock.recv(16384)) 2956 if outgoing.pending: 2957 sock.send(outgoing.read()) 2958 2959 while True: 2960 try: 2961 self.assertEqual(sslobj.read(), b'') # close_notify 2962 break 2963 except ssl.SSLWantReadError: 2964 incoming.write(sock.recv(16384)) 2965 2966 while True: 2967 try: 2968 sslobj.unwrap() 2969 except ssl.SSLWantReadError: 2970 if outgoing.pending: 2971 sock.send(outgoing.read()) 2972 incoming.write(sock.recv(16384)) 2973 else: 2974 if outgoing.pending: 2975 sock.send(outgoing.read()) 2976 break 2977 2978 self.assertEqual(sock.recv(16384), b'') # socket closed 2979 2980 class Protocol(asyncio.Protocol): 2981 def connection_made(self, transport): 2982 conn_made.set_result(None) 2983 2984 def data_received(self, data): 2985 nonlocal data_recv 2986 data_recv = True 2987 2988 def eof_received(self): 2989 eof_recvd.set_result(None) 2990 2991 def connection_lost(self, exc): 2992 if exc is None: 2993 conn_lost.set_result(None) 2994 else: 2995 conn_lost.set_exception(exc) 2996 2997 async def client(addr): 2998 ctx = self._create_client_ssl_context() 2999 tr, _ = await loop.create_connection(Protocol, *addr, ssl=ctx) 3000 await conn_made 3001 self.assertFalse(data_recv) 3002 3003 tr.pause_reading() 3004 tr.close() 3005 3006 await asyncio.wait_for(eof_recvd, 10) 3007 await asyncio.wait_for(conn_lost, 10) 3008 3009 with self.tcp_server(server) as srv: 3010 loop.run_until_complete(client(srv.addr)) 3011 3012 def test_bpo_39951_discard_trailing_data(self): 3013 sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY) 3014 client_sslctx = self._create_client_ssl_context() 3015 future = None 3016 close_notify = threading.Lock() 3017 3018 def server(sock): 3019 incoming = ssl.MemoryBIO() 3020 outgoing = ssl.MemoryBIO() 3021 sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) 3022 3023 while True: 3024 try: 3025 sslobj.do_handshake() 3026 except ssl.SSLWantReadError: 3027 if outgoing.pending: 3028 sock.send(outgoing.read()) 3029 incoming.write(sock.recv(16384)) 3030 else: 3031 if outgoing.pending: 3032 sock.send(outgoing.read()) 3033 break 3034 3035 while True: 3036 try: 3037 data = sslobj.read(4) 3038 except ssl.SSLWantReadError: 3039 incoming.write(sock.recv(16384)) 3040 else: 3041 break 3042 3043 self.assertEqual(data, b'ping') 3044 sslobj.write(b'pong') 3045 sock.send(outgoing.read()) 3046 3047 with close_notify: 3048 sslobj.write(b'trailing') 3049 sock.send(outgoing.read()) 3050 time.sleep(0.5) # allow time for the client to receive 3051 3052 incoming.write(sock.recv(16384)) 3053 sslobj.unwrap() 3054 sock.send(outgoing.read()) 3055 sock.close() 3056 3057 async def client(addr): 3058 nonlocal future 3059 future = self.loop.create_future() 3060 3061 with close_notify: 3062 reader, writer = await asyncio.open_connection( 3063 *addr, 3064 ssl=client_sslctx, 3065 server_hostname='') 3066 writer.write(b'ping') 3067 data = await reader.readexactly(4) 3068 self.assertEqual(data, b'pong') 3069 3070 writer.close() 3071 3072 try: 3073 await self.wait_closed(writer) 3074 except ssl.SSLError as e: 3075 if self.implementation == 'asyncio' and \ 3076 'application data after close notify' in str(e): 3077 raise unittest.SkipTest('bpo-39951') 3078 raise 3079 await future 3080 3081 def run(meth): 3082 def wrapper(sock): 3083 try: 3084 meth(sock) 3085 except Exception as ex: 3086 self.loop.call_soon_threadsafe(future.set_exception, ex) 3087 else: 3088 self.loop.call_soon_threadsafe(future.set_result, None) 3089 return wrapper 3090 3091 with self.tcp_server(run(server)) as srv: 3092 self.loop.run_until_complete(client(srv.addr)) 3093 3094 def test_first_data_after_wakeup(self): 3095 if self.implementation == 'asyncio': 3096 raise unittest.SkipTest() 3097 3098 server_context = self._create_server_ssl_context( 3099 self.ONLYCERT, self.ONLYKEY) 3100 client_context = self._create_client_ssl_context() 3101 loop = self.loop 3102 this = self 3103 fut = self.loop.create_future() 3104 3105 def client(sock, addr): 3106 try: 3107 sock.connect(addr) 3108 3109 incoming = ssl.MemoryBIO() 3110 outgoing = ssl.MemoryBIO() 3111 sslobj = client_context.wrap_bio(incoming, outgoing) 3112 3113 # Do handshake manually so that we could collect the last piece 3114 while True: 3115 try: 3116 sslobj.do_handshake() 3117 break 3118 except ssl.SSLWantReadError: 3119 if outgoing.pending: 3120 sock.send(outgoing.read()) 3121 incoming.write(sock.recv(65536)) 3122 3123 # Send the first data together with the last handshake payload 3124 sslobj.write(b'hello') 3125 sock.send(outgoing.read()) 3126 3127 while True: 3128 try: 3129 incoming.write(sock.recv(65536)) 3130 self.assertEqual(sslobj.read(1024), b'hello') 3131 break 3132 except ssl.SSLWantReadError: 3133 pass 3134 3135 sock.close() 3136 3137 except Exception as ex: 3138 loop.call_soon_threadsafe(fut.set_exception, ex) 3139 sock.close() 3140 else: 3141 loop.call_soon_threadsafe(fut.set_result, None) 3142 3143 class EchoProto(asyncio.Protocol): 3144 def connection_made(self, tr): 3145 self.tr = tr 3146 # manually run the coroutine, in order to avoid accidental data 3147 coro = loop.start_tls( 3148 tr, self, server_context, 3149 server_side=True, 3150 ssl_handshake_timeout=this.TIMEOUT, 3151 ) 3152 waiter = coro.send(None) 3153 3154 def tls_started(_): 3155 try: 3156 coro.send(None) 3157 except StopIteration as e: 3158 # update self.tr to SSL transport as soon as we know it 3159 self.tr = e.value 3160 3161 waiter.add_done_callback(tls_started) 3162 3163 def data_received(self, data): 3164 # This is a dumb protocol that writes back whatever it receives 3165 # regardless of whether self.tr is SSL or not 3166 self.tr.write(data) 3167 3168 async def run_main(): 3169 proto = EchoProto() 3170 3171 server = await self.loop.create_server( 3172 lambda: proto, '127.0.0.1', 0) 3173 addr = server.sockets[0].getsockname() 3174 3175 with self.tcp_client(lambda sock: client(sock, addr), 3176 timeout=self.TIMEOUT): 3177 await asyncio.wait_for(fut, timeout=self.TIMEOUT) 3178 proto.tr.close() 3179 3180 server.close() 3181 await server.wait_closed() 3182 3183 self.loop.run_until_complete(run_main()) 3184 3185 3186class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase): 3187 pass 3188 3189 3190class Test_AIO_TCPSSL(_TestSSL, tb.AIOTestCase): 3191 pass 3192