1"""Tests for asyncio/sslproto.py."""
2
3import logging
4import socket
5import sys
6import unittest
7import weakref
8from unittest import mock
9try:
10    import ssl
11except ImportError:
12    ssl = None
13
14import asyncio
15from asyncio import log
16from asyncio import protocols
17from asyncio import sslproto
18from test.test_asyncio import utils as test_utils
19from test.test_asyncio import functional as func_tests
20
21
def tearDownModule():
    """Reset asyncio's global event loop policy after this module's tests."""
    # Per the asyncio docs, passing None restores the default policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
24
25
@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Unit tests for ``sslproto.SSLProtocol`` driven by mock transports."""

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def ssl_protocol(self, *, waiter=None, proto=None):
        """Create an ``SSLProtocol`` around *proto* with a dummy SSL context.

        When *proto* is None a plain ``asyncio.Protocol`` is used as the
        application protocol.  The handshake timeout is 0.1 seconds and the
        application-facing transport is closed during test cleanup.
        """
        context = test_utils.dummy_ssl_context()
        app_proto = proto if proto is not None else asyncio.Protocol()
        result = sslproto.SSLProtocol(self.loop, app_proto, context, waiter,
                                      ssl_handshake_timeout=0.1)
        app_transport = result._app_transport
        self.assertIs(app_transport.get_protocol(), app_proto)
        self.addCleanup(app_transport.close)
        return result

    def connection_made(self, ssl_proto, *, do_handshake=None):
        """Call ``ssl_proto.connection_made()`` with a mock transport.

        ``asyncio.sslproto._SSLPipe`` is patched to return a mock pipe whose
        ``shutdown()`` returns b'' and whose ``do_handshake()`` delegates to
        *do_handshake* when it is truthy, or otherwise returns an empty list.
        Returns the mock transport.
        """
        def no_op_handshake(callback):
            return []

        mock_transport = mock.Mock()
        fake_pipe = mock.Mock()
        fake_pipe.shutdown.return_value = b''
        fake_pipe.do_handshake.side_effect = (
            do_handshake if do_handshake else no_op_handshake)
        with mock.patch('asyncio.sslproto._SSLPipe', return_value=fake_pipe):
            ssl_proto.connection_made(mock_transport)
        return mock_transport

    def test_handshake_timeout_zero(self):
        ctx = test_utils.dummy_ssl_context()
        app = mock.Mock()
        fut = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app, ctx, fut,
                                 ssl_handshake_timeout=0)

    def test_handshake_timeout_negative(self):
        ctx = test_utils.dummy_ssl_context()
        app = mock.Mock()
        fut = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app, ctx, fut,
                                 ssl_handshake_timeout=-10)

    def test_eof_received_waiter(self):
        fut = self.loop.create_future()
        proto = self.ssl_protocol(waiter=fut)
        self.connection_made(proto)
        proto.eof_received()
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(fut.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # From issue #363.
        # _fatal_error() generates a NameError if sslproto.py
        # does not import base_events.
        proto = self.ssl_protocol(waiter=self.loop.create_future())
        # Silence error logging for the call so the expected error does not
        # clutter test output, then restore the previous level.
        saved_level = log.logger.getEffectiveLevel()
        log.logger.setLevel(logging.FATAL)
        try:
            proto._fatal_error(None)
        finally:
            log.logger.setLevel(saved_level)

    def test_connection_lost(self):
        # From issue #472.
        # Awaiting the waiter hung if connection_lost() was called.
        fut = self.loop.create_future()
        proto = self.ssl_protocol(waiter=fut)
        self.connection_made(proto)
        proto.connection_lost(ConnectionAbortedError)
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(fut.exception(), ConnectionAbortedError)

    def test_close_during_handshake(self):
        # bpo-29743: closing the transport during the handshake leaked
        # the socket.
        proto = self.ssl_protocol(waiter=self.loop.create_future())

        mock_transport = self.connection_made(proto)
        test_utils.run_briefly(self.loop)

        proto._app_transport.close()
        self.assertTrue(mock_transport.abort.called)

    def test_get_extra_info_on_closed_connection(self):
        proto = self.ssl_protocol(waiter=self.loop.create_future())
        self.assertIsNone(proto._get_extra_info('socket'))
        sentinel = object()
        self.assertIs(proto._get_extra_info('socket', sentinel), sentinel)
        self.connection_made(proto)
        self.assertIsNotNone(proto._get_extra_info('socket'))
        proto.connection_lost(None)
        self.assertIsNone(proto._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        proto = self.ssl_protocol(waiter=self.loop.create_future())
        replacement = asyncio.Protocol()
        proto._app_transport.set_protocol(replacement)
        self.assertIs(proto._app_transport.get_protocol(), replacement)
        self.assertIs(proto._app_protocol, replacement)

    def test_data_received_after_closing(self):
        proto = self.ssl_protocol()
        self.connection_made(proto)
        proto._app_transport.close()

        # should not raise
        self.assertIsNone(proto.data_received(b'data'))

    def test_write_after_closing(self):
        proto = self.ssl_protocol()
        self.connection_made(proto)
        app_transport = proto._app_transport
        app_transport.close()

        # should not raise
        self.assertIsNone(app_transport.write(b'data'))
155
156
157##############################################################################
158# Start TLS Tests
159##############################################################################
160
161
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Functional tests for ``loop.start_tls()`` and SSL connections.

    Concrete subclasses combine this mixin with ``unittest.TestCase`` and
    implement ``new_loop()`` to select the event loop under test.
    """

    PAYLOAD_SIZE = 1024 * 100
    # Timeout in seconds for blocking socket operations in the helper
    # threads and for the client/server coroutines run on the loop.
    TIMEOUT = 60

    def new_loop(self):
        """Return a new event loop; subclasses must override this."""
        raise NotImplementedError

    def test_buf_feed_data(self):
        # _feed_data_to_buffered_proto() must deliver all bytes through
        # get_buffer()/buffer_updated() regardless of buffer size, whether
        # the buffer is a bytearray or a memoryview.

        class Proto(asyncio.BufferedProtocol):

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        for usemv in [False, True]:
            proto = Proto(1, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(4, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(100, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            # A zero-length buffer cannot make progress and must be rejected.
            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # ``proto`` is the protocol instance; ``self`` is the test case
            # captured from the enclosing method.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT))

        # No garbage is left if SSL is closed uncleanly.  Rebinding the
        # name drops the strong reference held by the ``client`` closure.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())

    def test_create_connection_memory_leak(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # ``proto`` is the protocol instance; ``self`` is the test case.
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance.
        # Rebinding the name drops the closure's strong reference.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())

    def test_start_tls_client_buf_proto_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                self.on_data.set_result(bytes(self.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            # Swap in a regular (non-buffered) protocol mid-connection.
            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                # Tell the client we are now blocked without answering the
                # TLS handshake, then wait until the client gives up.
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # ``proto`` is the protocol instance; ``self`` is the test case.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            # The server never completes the handshake, so start_tls() must
            # be cancelled cleanly by the short wait_for() timeout.
            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT))

    def test_start_tls_server_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        answer = None

        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            # The outer 0.5s wait_for fires long before the 10s handshake
            # timeout, exercising cancellation of a pending handshake.
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=10.0),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The 10s handshake timeout should be cancelled to free related
        # objects without really waiting for 10s.  Rebinding the name drops
        # the ``client`` closure's strong reference to the context.
        client_sslctx = weakref.ref(client_sslctx)
        self.assertIsNone(client_sslctx())

    def test_create_connection_ssl_slow_handshake(self):
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            # open_connection() is expected to fail with the handshake
            # timeout error, so its result is never used.
            with self.assertWarns(DeprecationWarning):
                await asyncio.open_connection(
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    loop=self.loop,
                    ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except ssl.SSLError:
                pass
            except OSError:
                pass
            finally:
                sock.close()

        async def client(addr):
            # Certificate verification is expected to fail, so the
            # connection result is never used.
            with self.assertWarns(DeprecationWarning):
                await asyncio.open_connection(
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    loop=self.loop,
                    ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Keep a duplicate of the raw socket so plaintext can be injected
            # into the TLS stream after the handshake.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            with self.assertWarns(DeprecationWarning):
                reader, writer = await asyncio.open_connection(
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    loop=self.loop)

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
755
756
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS suite on a selector-based event loop."""

    def new_loop(self):
        """Return a new ``asyncio.SelectorEventLoop``."""
        loop_factory = asyncio.SelectorEventLoop
        return loop_factory()
762
763
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS suite on the proactor event loop when available."""

    def new_loop(self):
        """Return a new ``asyncio.ProactorEventLoop``."""
        loop_factory = asyncio.ProactorEventLoop
        return loop_factory()
770
771
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
774