"""Tests for asyncio/sslproto.py."""

import logging
import socket
import sys  # NOTE(review): not referenced anywhere in this module.
import unittest
import weakref
from unittest import mock
try:
    import ssl
except ImportError:
    # Builds without the ssl module: every TestCase class below carries a
    # skipIf(ssl is None) decorator, so they are all skipped.
    ssl = None

import asyncio
from asyncio import log
from asyncio import protocols
from asyncio import sslproto
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests


def tearDownModule():
    # Reset the global event loop policy (passing None restores the
    # default policy) so nothing configured while running this module
    # carries over to other test modules.
    asyncio.set_event_loop_policy(None)


@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Unit tests for sslproto.SSLProtocol.

    No real network I/O happens here: the underlying transport is a
    mock.Mock() and asyncio.sslproto._SSLPipe is patched with a mock for
    the duration of SSLProtocol.connection_made() (see the
    connection_made() helper of this class).
    """

    def setUp(self):
        super().setUp()
        # Every test gets its own fresh loop.  set_event_loop() is the
        # test_utils.TestCase helper (presumably it also arranges for the
        # loop to be closed at teardown -- not visible from this file).
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def ssl_protocol(self, *, waiter=None, proto=None):
        """Create an SSLProtocol on self.loop for use in a test.

        *waiter* is passed straight through to SSLProtocol.  *proto* is the
        application protocol; a fresh asyncio.Protocol() is used when it is
        None.  The handshake timeout is fixed at 0.1 seconds.  The
        app-facing transport is checked to report *proto* and is closed
        automatically at test cleanup.  Returns the SSLProtocol instance.
        """
        sslcontext = test_utils.dummy_ssl_context()
        if proto is None:  # app protocol
            proto = asyncio.Protocol()
        ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
                                         ssl_handshake_timeout=0.1)
        self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
        self.addCleanup(ssl_proto._app_transport.close)
        return ssl_proto

    def connection_made(self, ssl_proto, *, do_handshake=None):
        """Call ssl_proto.connection_made() with a mock transport and pipe.

        asyncio.sslproto._SSLPipe is patched while connection_made() runs,
        so the protocol picks up a mock pipe whose shutdown() returns b''.
        When *do_handshake* is truthy it becomes the side effect of the
        pipe's do_handshake(); otherwise do_handshake() returns [].
        Returns the mock transport so tests can inspect calls made on it.
        """
        transport = mock.Mock()
        sslpipe = mock.Mock()
        sslpipe.shutdown.return_value = b''
        if do_handshake:
            sslpipe.do_handshake.side_effect = do_handshake
        else:
            def mock_handshake(callback):
                # Never invokes *callback*; presumably this leaves the
                # SSLProtocol handshake pending (SSLProtocol internals are
                # not visible from this file).
                return []
            sslpipe.do_handshake.side_effect = mock_handshake
        with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
            ssl_proto.connection_made(transport)
        return transport

    def test_handshake_timeout_zero(self):
        # A zero ssl_handshake_timeout must be rejected by the constructor.
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=0)

    def test_handshake_timeout_negative(self):
        # A negative ssl_handshake_timeout must be rejected as well.
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=-10)

    def test_eof_received_waiter(self):
        # eof_received() on a just-connected protocol must fail the waiter
        # with ConnectionResetError.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(ssl_proto)
        ssl_proto.eof_received()
        # Let the loop process any callbacks scheduled above before checking.
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # From issue #363.
        # _fatal_error() generates a NameError if sslproto.py
        # does not import base_events.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        # Temporarily turn off error logging so as not to spoil test output.
        log_level = log.logger.getEffectiveLevel()
        log.logger.setLevel(logging.FATAL)
        try:
            ssl_proto._fatal_error(None)
        finally:
            # Restore error logging.
            log.logger.setLevel(log_level)

    def test_connection_lost(self):
        # From issue #472.
        # `yield from waiter` used to hang forever if connection_lost() was
        # called; the waiter must instead receive the connection_lost()
        # exception.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(ssl_proto)
        ssl_proto.connection_lost(ConnectionAbortedError)
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)

    def test_close_during_handshake(self):
        # bpo-29743 Closing transport during handshake process leaks socket
        # Closing the app transport while the (mocked) handshake is pending
        # must abort() the underlying transport.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)

        transport = self.connection_made(ssl_proto)
        test_utils.run_briefly(self.loop)

        ssl_proto._app_transport.close()
        self.assertTrue(transport.abort.called)

    def test_get_extra_info_on_closed_connection(self):
        # _get_extra_info() returns the default (None unless one is given)
        # before connection_made() and after connection_lost(), and a
        # non-None value while a transport is attached.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))
        default = object()
        self.assertIs(ssl_proto._get_extra_info('socket', default), default)
        self.connection_made(ssl_proto)
        self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
        ssl_proto.connection_lost(None)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        # set_protocol() on the app transport must update both
        # get_protocol() and SSLProtocol._app_protocol.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        new_app_proto = asyncio.Protocol()
        ssl_proto._app_transport.set_protocol(new_app_proto)
        self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
        self.assertIs(ssl_proto._app_protocol, new_app_proto)

    def test_data_received_after_closing(self):
        # Incoming data after the app transport was closed must not raise.
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport

        transp.close()

        # should not raise
        self.assertIsNone(ssl_proto.data_received(b'data'))

    def test_write_after_closing(self):
        # write() on an already-closed app transport must not raise.
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport
        transp.close()

        # should not raise
        self.assertIsNone(transp.write(b'data'))


##############################################################################
# Start TLS Tests
##############################################################################


class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Functional TLS tests shared by every event loop implementation.

    This is a mixin, not a TestCase: concrete subclasses combine it with
    unittest.TestCase and override new_loop() to choose the loop under
    test.  Peers talk over real sockets obtained from the tcp_server() /
    tcp_client() helpers of FunctionalTestCaseMixin; sock.recv_all() and
    sock.start_tls() are methods of the socket wrapper those helpers pass
    in (presumably defined in test_asyncio.functional, not here).
    """

    # Size in bytes of the HELLO_MSG payload most tests exchange.
    PAYLOAD_SIZE = 1024 * 100
    # Default timeout in seconds for sockets, helper servers/clients and
    # waits below (a few client waits use a hard-coded 10 instead).
    TIMEOUT = 60

    def new_loop(self):
        # Concrete subclasses return the event loop under test.
        raise NotImplementedError

    def test_buf_feed_data(self):
        # Exercises protocols._feed_data_to_buffered_proto() directly (no
        # loop or sockets involved) with buffers smaller than, equal to and
        # larger than the data, handing out either the bytearray itself or
        # a memoryview over it.  A zero-sized buffer must raise.

        class Proto(asyncio.BufferedProtocol):
            # Minimal BufferedProtocol: every chunk written into its
            # fixed-size buffer is appended to self.data.

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        for usemv in [False, True]:
            proto = Proto(1, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(4, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(100, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        # Upgrade a plain client connection that uses a regular
        # (non-buffered) Protocol to TLS with loop.start_tls(), exchange
        # data in both directions, then check the client SSLContext is
        # no longer referenced.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server side: read the plaintext hello, upgrade to TLS, send
            # b'O', read the hello again over TLS, then shut down.
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # The first parameter is named `proto` so that `self` below
                # still refers to the enclosing TestCase.
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            # Plaintext hello first, then upgrade the same connection.
            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=10))

        # No garbage is left if SSL is closed uncleanly
        # Rebinding the name also replaces the value in the closure cell
        # that client() reads, so afterwards only the weakref remains here;
        # the check relies on reference counting freeing it immediately.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())

    def test_create_connection_memory_leak(self):
        # TLS from the start via create_connection(ssl=...).  The protocol
        # keeps a reference to its transport (proto.tr); that must not keep
        # the client SSLContext alive once the connection is done.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            # Server side: TLS immediately, send b'O', read the hello,
            # then shut down.
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # `proto` is the protocol; `self` is the enclosing TestCase.
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=10))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        # (Same weakref trick as above: rebinding drops the strong reference
        # held through client()'s closure cell.)
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())

    def test_start_tls_client_buf_proto_1(self):
        # start_tls() with a BufferedProtocol, then swap the TLS transport
        # over to a regular Protocol with set_protocol() mid-stream.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # Total connection_made() calls across both client protocols.
        client_con_made_calls = 0

        def serve(sock):
            # Server side: plaintext hello, TLS upgrade, then two rounds of
            # "send one byte, read the hello" (b'O' then b'2').
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            # Buffered protocol with a 1-byte buffer, so each
            # buffer_updated() call can carry at most one byte.
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                self.on_data.set_result(bytes(self.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                # NOTE(review): con_made_cnt is never read in this class.
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            # Switch the TLS transport to a regular Protocol for the rest.
            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        # The server never performs its side of the TLS handshake, so a
        # start_tls() wrapped in a 0.5s wait_for() must raise TimeoutError.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        # Resolved by serve() once it has read the plaintext hello.
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                # serve() presumably runs outside the loop thread, hence
                # call_soon_threadsafe() rather than set_result() directly.
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                # Keep reading without ever answering the handshake;
                # ConnectionAbortedError when the client gives up is
                # expected.
                data = sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # `proto` is the protocol; `self` is the enclosing TestCase.
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            # wait_for() cancels the pending start_tls() when it times out.
            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=10))

    def test_start_tls_server_1(self):
        # Server-side start_tls(): the server protocol greets the client in
        # plaintext, upgrades its own connection with server_side=True,
        # waits for the client's hello over TLS and sends ANSWER back.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        # What the client read back; set from client(), which runs inside
        # the tcp_client() helper, and checked at the end of run_main().
        answer = None

        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            # Records the transport and incoming data; resolves the given
            # futures on connect, on receiving a full hello, and on
            # connection loss.
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                # Swap in the TLS transport returned by start_tls().
                self.transport = tr

            def data_received(self, data):
                self.data += data
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            # The client sends nothing before the TLS upgrade.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            # The same protocol instance is handed to every connection;
            # only one client connects in this test.
            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        # start_tls() argument validation: a non-SSLContext sslcontext is a
        # TypeError, and so is an unsupported transport (None) even when
        # the context is valid.
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        # Here the server never answers, and the outer wait_for() (0.5s)
        # fires long before ssl_handshake_timeout (10s) would.
        client_sslctx = test_utils.simple_client_sslcontext()

        # Collect everything reported to the loop's exception handler.
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        # Set when the server's read is interrupted by ConnectionAbortedError.
        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=10.0),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The 10s handshake timeout should be cancelled to free related
        # objects without really waiting for 10s
        # (Rebinding drops the strong reference held via client()'s
        # closure cell; only the weakref remains.)
        client_sslctx = weakref.ref(client_sslctx)
        self.assertIsNone(client_sslctx())

    def test_create_connection_ssl_slow_handshake(self):
        # With no server-side handshake, ssl_handshake_timeout=1.0 must
        # abort the connection with ConnectionAbortedError ("SSL handshake
        # ... is taking longer ..."), and nothing may reach the loop's
        # exception handler.
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            # Passing loop= to open_connection() is expected to emit a
            # DeprecationWarning.
            with self.assertWarns(DeprecationWarning):
                reader, writer = await asyncio.open_connection(
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    loop=self.loop,
                    ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        # A client with certificate verification enabled
        # (disable_verify=False) must fail the handshake with
        # SSLCertVerificationError against the test server's certificate.
        # Whatever the loop would report to its exception handler is
        # discarded.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            # The server-side handshake is expected to fail as well.
            # NOTE(review): ssl.SSLError is a subclass of OSError, so the
            # first except clause is subsumed by the second.
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except ssl.SSLError:
                pass
            except OSError:
                pass
            finally:
                sock.close()

        async def client(addr):
            with self.assertWarns(DeprecationWarning):
                reader, writer = await asyncio.open_connection(
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    loop=self.loop,
                    ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        # After a valid TLS exchange the server writes raw plaintext on a
        # dup() of the underlying socket, bypassing TLS; the client's next
        # read must fail with ssl.SSLError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # A duplicate of the raw socket, taken before the TLS upgrade.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            with self.assertWarns(DeprecationWarning):
                reader, writer = await asyncio.open_connection(
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    loop=self.loop)

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')


@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS suite on asyncio.SelectorEventLoop."""

    def new_loop(self):
        return asyncio.SelectorEventLoop()


@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS suite on asyncio.ProactorEventLoop.

    Skipped wherever asyncio has no ProactorEventLoop ("Windows only").
    """

    def new_loop(self):
        return asyncio.ProactorEventLoop()


if __name__ == '__main__':
    unittest.main()