1import asyncio
2import asyncio.sslproto
3import gc
4import os
5import select
6import socket
7import unittest.mock
8import ssl
9import sys
10import threading
11import time
12import weakref
13
14from OpenSSL import SSL as openssl_ssl
15from uvloop import _testbase as tb
16
17
# Handshake timeout constant for SSL-related tests; presumably passed as
# ``ssl_handshake_timeout`` (seconds) by tests later in the file -- TODO confirm.
SSL_HANDSHAKE_TIMEOUT: float = 15.0
19
20
class MyBaseProto(asyncio.Protocol):
    """Protocol that tracks its lifecycle state and the byte count received.

    ``state`` moves INITIAL -> CONNECTED -> (EOF) -> CLOSED, and each
    callback asserts it is invoked in that order.  When constructed with a
    *loop*, ``connected`` and ``done`` are futures that are resolved in
    ``connection_made()`` and ``connection_lost()`` respectively; otherwise
    they stay ``None``.
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is not None:
            # loop.create_future() is the documented way to obtain a
            # future bound to a specific event loop.
            self.connected = loop.create_future()
            self.done = loop.create_future()

    def connection_made(self, transport):
        self.transport = transport
        assert self.state == 'INITIAL', self.state
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def data_received(self, data):
        assert self.state == 'CONNECTED', self.state
        self.nbytes += len(data)

    def eof_received(self):
        assert self.state == 'CONNECTED', self.state
        self.state = 'EOF'

    def connection_lost(self, exc):
        assert self.state in ('CONNECTED', 'EOF'), self.state
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)
53
54
class _TestTCP:
    """Implementation-agnostic TCP server/client tests.

    Mixed into concrete test cases such as ``Test_UV_TCP``, which supply
    ``self.loop``, ``self.implementation``, ``self.tcp_server()`` and
    ``self.wait_closed()``.  Checks that only apply to uvloop are guarded
    by ``self.implementation == 'uvloop'``.
    """

    def test_create_server_1(self):
        """Many concurrent clients talk to start_server()-created servers.

        Runs twice: once binding a tuple of hosts, once passing a pre-bound
        ``sock=``.  All clients must finish within ``TIMEOUT`` and the
        server's sockets must be closed after ``wait_closed()``.
        """
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = 5.0     # timeout for this test

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # Send b'SPAM' through writelines() and several buffer types.
            writer.writelines([b'S', b'P'])
            writer.write(bytearray(b'A'))
            writer.write(memoryview(b'M'))

            if self.implementation == 'uvloop':
                # uvloop must enable TCP_NODELAY on accepted connections.
                tr = writer.transport
                sock = tr.get_extra_info('socket')
                self.assertTrue(
                    sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY))

            await writer.drain()
            writer.close()

            CNT += 1

        async def recv_exactly(sock, size):
            # Read one byte per sock_recv() call.  An empty read means the
            # peer closed the connection; fail right away instead of
            # spinning on repeated empty reads.
            buf = b''
            while len(buf) < size:
                chunk = await self.loop.sock_recv(sock, 1)
                if not chunk:
                    self.fail(f'unexpected EOF after {buf!r}')
                buf += chunk
            return buf

        async def test_client(addr):
            sock = socket.socket()
            with sock:
                sock.setblocking(False)
                await self.loop.sock_connect(sock, addr)

                await self.loop.sock_sendall(sock, A_DATA)
                self.assertEqual(await recv_exactly(sock, 2), b'OK')

                await self.loop.sock_sendall(sock, B_DATA)
                self.assertEqual(await recv_exactly(sock, 4), b'SPAM')

            # Leaving the ``with`` block must have fully closed the socket.
            self.assertEqual(sock.fileno(), -1)
            self.assertEqual(sock._io_refs, 0)
            self.assertTrue(sock._closed)

        async def run_clients(addr):
            # Fan out TOTAL_CNT concurrent clients, bounded by TIMEOUT.
            clients = [test_client(addr) for _ in range(TOTAL_CNT)]
            await asyncio.wait_for(asyncio.gather(*clients), TIMEOUT)

        def assert_cleaned_up(srv, srv_socks):
            # Check that the server cleaned-up proxy-sockets and stopped.
            for srv_sock in srv_socks:
                self.assertEqual(srv_sock.fileno(), -1)
            self.assertFalse(srv.is_serving())

        async def start_server():
            nonlocal CNT
            CNT = 0

            srv = await asyncio.start_server(
                handle_client,
                ('127.0.0.1', 'localhost'), 0,
                family=socket.AF_INET)

            srv_socks = srv.sockets
            self.assertTrue(srv_socks)
            self.assertTrue(srv.is_serving())

            await run_clients(srv_socks[0].getsockname())

            # Close from a loop callback rather than inline.
            self.loop.call_soon(srv.close)
            await srv.wait_closed()

            assert_cleaned_up(srv, srv_socks)

        async def start_server_sock():
            nonlocal CNT
            CNT = 0

            sock = socket.socket()
            sock.bind(('127.0.0.1', 0))
            addr = sock.getsockname()

            srv = await asyncio.start_server(
                handle_client,
                None, None,
                family=socket.AF_INET,
                sock=sock)

            self.assertIs(srv.get_loop(), self.loop)

            srv_socks = srv.sockets
            self.assertTrue(srv_socks)
            self.assertTrue(srv.is_serving())

            await run_clients(addr)

            srv.close()
            await srv.wait_closed()

            assert_cleaned_up(srv, srv_socks)

        self.loop.run_until_complete(start_server())
        self.assertEqual(CNT, TOTAL_CNT)

        self.loop.run_until_complete(start_server_sock())
        self.assertEqual(CNT, TOTAL_CNT)

    def test_create_server_2(self):
        """create_server() without host/port/sock raises ValueError."""
        with self.assertRaisesRegex(ValueError, 'nor sock were specified'):
            self.loop.run_until_complete(self.loop.create_server(object))

    def test_create_server_3(self):
        """Check that an ephemeral port (0 or None) can be used."""

        async def start_server_ephemeral_ports():

            for port_sentinel in [0, None]:
                srv = await self.loop.create_server(
                    asyncio.Protocol,
                    '127.0.0.1', port_sentinel,
                    family=socket.AF_INET)

                srv_socks = srv.sockets
                self.assertTrue(srv_socks)
                self.assertTrue(srv.is_serving())

                _, port = srv_socks[0].getsockname()
                self.assertNotEqual(0, port)

                self.loop.call_soon(srv.close)
                await srv.wait_closed()

                # Check that the server cleaned-up proxy-sockets
                for srv_sock in srv_socks:
                    self.assertEqual(srv_sock.fileno(), -1)

                self.assertFalse(srv.is_serving())

        self.loop.run_until_complete(start_server_ephemeral_ports())

    def test_create_server_4(self):
        """Binding to an address that is already in use raises OSError."""
        # bind() runs inside the ``with`` so the socket is closed even if
        # binding itself fails.
        with socket.socket() as sock:
            sock.bind(('127.0.0.1', 0))
            addr = sock.getsockname()

            with self.assertRaisesRegex(OSError,
                                        r"error while attempting.*\('127.*: "
                                        r"address already in use"):

                self.loop.run_until_complete(
                    self.loop.create_server(object, *addr))

    def test_create_server_5(self):
        # Test that create_server sets the IPV6_V6ONLY flag,
        # so it can bind to ipv4 and ipv6 addresses
        # simultaneously.

        port = tb.find_free_port()

        async def runner():
            srv = await self.loop.create_server(
                asyncio.Protocol,
                None, port)

            srv.close()
            await srv.wait_closed()

        self.loop.run_until_complete(runner())

    def test_create_server_6(self):
        """Two servers can share a port when reuse_port=True."""
        if not hasattr(socket, 'SO_REUSEPORT'):
            raise unittest.SkipTest(
                'The system does not support SO_REUSEPORT')

        port = tb.find_free_port()

        async def runner():
            srv1 = await self.loop.create_server(
                asyncio.Protocol,
                None, port,
                reuse_port=True)

            srv2 = await self.loop.create_server(
                asyncio.Protocol,
                None, port,
                reuse_port=True)

            srv1.close()
            srv2.close()

            await srv1.wait_closed()
            await srv2.wait_closed()

        self.loop.run_until_complete(runner())

    def test_create_server_7(self):
        # Test that create_server() stores a hard ref to the server object
        # somewhere in the loop.  In asyncio it so happens that
        # loop.sock_accept() has a reference to the server object so it
        # never gets GCed.

        class Proto(asyncio.Protocol):
            def connection_made(self, tr):
                self.tr = tr
                self.tr.write(b'hello')

        async def test():
            port = tb.find_free_port()
            srv = await self.loop.create_server(Proto, '127.0.0.1', port)
            wsrv = weakref.ref(srv)
            del srv

            gc.collect()
            gc.collect()
            gc.collect()

            s = socket.socket(socket.AF_INET)
            with s:
                s.setblocking(False)
                await self.loop.sock_connect(s, ('127.0.0.1', port))
                d = await self.loop.sock_recv(s, 100)
                self.assertEqual(d, b'hello')

            srv = wsrv()
            # Fail with a clear message instead of an AttributeError on None.
            self.assertIsNotNone(srv, 'server was garbage-collected')
            srv.close()
            await srv.wait_closed()
            del srv

            # Let all transports shutdown.
            await asyncio.sleep(0.1)

            gc.collect()
            gc.collect()
            gc.collect()

            self.assertIsNone(wsrv())

        self.loop.run_until_complete(test())

    def test_create_server_8(self):
        """ssl_handshake_timeout without ssl is rejected."""
        with self.assertRaisesRegex(
                ValueError, 'ssl_handshake_timeout is only meaningful'):
            self.loop.run_until_complete(
                self.loop.create_server(
                    lambda: None, host='::', port=0, ssl_handshake_timeout=10))

    def test_create_server_9(self):
        """start_serving=False defers serving; start_serving() is repeatable."""
        async def handle_client(reader, writer):
            pass

        async def start_server():
            srv = await asyncio.start_server(
                handle_client,
                '127.0.0.1', 0,
                family=socket.AF_INET,
                start_serving=False)

            await srv.start_serving()
            self.assertTrue(srv.is_serving())

            # call start_serving again
            await srv.start_serving()
            self.assertTrue(srv.is_serving())

            srv.close()
            await srv.wait_closed()
            self.assertFalse(srv.is_serving())

        self.loop.run_until_complete(start_server())

    def test_create_server_10(self):
        """Cancelling serve_forever() stops the server."""
        async def handle_client(reader, writer):
            pass

        async def start_server():
            srv = await asyncio.start_server(
                handle_client,
                '127.0.0.1', 0,
                family=socket.AF_INET,
                start_serving=False)

            async with srv:
                fut = asyncio.ensure_future(srv.serve_forever())
                await asyncio.sleep(0)
                self.assertTrue(srv.is_serving())

                fut.cancel()
                with self.assertRaises(asyncio.CancelledError):
                    await fut
                self.assertFalse(srv.is_serving())

        self.loop.run_until_complete(start_server())

    def test_create_connection_open_con_addr(self):
        """open_connection(host, port) round-trips data; rejects str writes."""
        async def client(addr):
            reader, writer = await asyncio.open_connection(*addr)

            writer.write(b'AAAA')
            self.assertEqual(await reader.readexactly(2), b'OK')

            re = r'(a bytes-like object)|(must be byte-ish)'
            with self.assertRaisesRegex(TypeError, re):
                writer.write('AAAA')

            writer.write(b'BBBB')
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            if self.implementation == 'uvloop':
                tr = writer.transport
                sock = tr.get_extra_info('socket')
                self.assertTrue(
                    sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY))

            writer.close()
            await self.wait_closed(writer)

        self._test_create_connection_1(client)

    def test_create_connection_open_con_sock(self):
        """open_connection(sock=...) works with a pre-connected socket."""
        async def client(addr):
            sock = socket.socket()
            try:
                sock.connect(addr)
                reader, writer = await asyncio.open_connection(sock=sock)
            except BaseException:
                # Until open_connection() succeeds nothing else will close
                # the socket, so do it here to avoid leaking it.
                sock.close()
                raise

            writer.write(b'AAAA')
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(b'BBBB')
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            if self.implementation == 'uvloop':
                tr = writer.transport
                sock = tr.get_extra_info('socket')
                self.assertTrue(
                    sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY))

            writer.close()
            await self.wait_closed(writer)

        self._test_create_connection_1(client)

    def _test_create_connection_1(self, client):
        """Run ``client(addr)`` for TOTAL_CNT concurrent connections.

        The blocking server handler expects b'AAAA' and replies b'OK', then
        expects b'BBBB' and replies b'SPAM'.  Every client must complete.
        """
        CNT = 0
        TOTAL_CNT = 100

        def server(sock):
            data = sock.recv_all(4)
            self.assertEqual(data, b'AAAA')
            sock.send(b'OK')

            data = sock.recv_all(4)
            self.assertEqual(data, b'BBBB')
            sock.send(b'SPAM')

        async def client_wrapper(addr):
            nonlocal CNT
            await client(addr)
            CNT += 1

        def run(coro):
            nonlocal CNT
            CNT = 0

            with self.tcp_server(server,
                                 max_clients=TOTAL_CNT,
                                 backlog=TOTAL_CNT) as srv:
                tasks = [coro(srv.addr) for _ in range(TOTAL_CNT)]
                self.loop.run_until_complete(asyncio.gather(*tasks))

            self.assertEqual(CNT, TOTAL_CNT)

        run(client_wrapper)

    def test_create_connection_2(self):
        """Connecting to a closed (non-listening) port is refused."""
        sock = socket.socket()
        with sock:
            sock.bind(('127.0.0.1', 0))
            addr = sock.getsockname()

        async def client():
            reader, writer = await asyncio.open_connection(*addr)
            writer.close()
            await self.wait_closed(writer)

        async def runner():
            with self.assertRaises(ConnectionRefusedError):
                await client()

        self.loop.run_until_complete(runner())

    def test_create_connection_3(self):
        """readexactly() raises IncompleteReadError when the peer closes."""
        CNT = 0
        TOTAL_CNT = 100

        def server(sock):
            data = sock.recv_all(4)
            self.assertEqual(data, b'AAAA')
            sock.close()

        async def client(addr):
            nonlocal CNT
            reader, writer = await asyncio.open_connection(*addr)

            writer.write(b'AAAA')

            with self.assertRaises(asyncio.IncompleteReadError):
                await reader.readexactly(10)

            writer.close()
            await self.wait_closed(writer)

            CNT += 1

        def run(coro):
            nonlocal CNT
            CNT = 0

            with self.tcp_server(server,
                                 max_clients=TOTAL_CNT,
                                 backlog=TOTAL_CNT) as srv:
                tasks = [coro(srv.addr) for _ in range(TOTAL_CNT)]
                self.loop.run_until_complete(asyncio.gather(*tasks))

            self.assertEqual(CNT, TOTAL_CNT)

        run(client)

    def test_create_connection_4(self):
        """open_connection() with an already-closed socket raises OSError."""
        sock = socket.socket()
        sock.close()

        async def client():
            reader, writer = await asyncio.open_connection(sock=sock)
            writer.close()
            await self.wait_closed(writer)

        async def runner():
            with self.assertRaisesRegex(OSError, 'Bad file'):
                await client()

        self.loop.run_until_complete(runner())

    def test_create_connection_5(self):
        """Cancelling a pending create_connection() raises CancelledError."""
        def server(sock):
            try:
                data = sock.recv_all(4)
            except ConnectionError:
                return
            self.assertEqual(data, b'AAAA')
            sock.send(b'OK')

        async def client(addr):
            fut = asyncio.ensure_future(
                self.loop.create_connection(asyncio.Protocol, *addr))
            await asyncio.sleep(0)
            fut.cancel()
            with self.assertRaises(asyncio.CancelledError):
                await fut

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:
            self.loop.run_until_complete(client(srv.addr))

    def test_create_connection_6(self):
        """ssl_handshake_timeout without ssl is rejected."""
        with self.assertRaisesRegex(
                ValueError, 'ssl_handshake_timeout is only meaningful'):
            self.loop.run_until_complete(
                self.loop.create_connection(
                    lambda: None, host='::', port=0, ssl_handshake_timeout=10))

    def test_transport_shutdown(self):
        """write_eof() half-closes the stream; calling it twice is allowed."""
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 100   # total number of clients that test will create
        TIMEOUT = 5.0     # timeout for this test

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(4)
            self.assertEqual(data, b'AAAA')

            writer.write(b'OK')
            # The second write_eof() checks that repeating it is harmless.
            writer.write_eof()
            writer.write_eof()

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            reader, writer = await asyncio.open_connection(*addr)

            writer.write(b'AAAA')
            data = await reader.readexactly(2)
            self.assertEqual(data, b'OK')

            writer.close()
            await self.wait_closed(writer)

        async def start_server():
            nonlocal CNT
            CNT = 0

            srv = await asyncio.start_server(
                handle_client,
                '127.0.0.1', 0,
                family=socket.AF_INET)

            srv_socks = srv.sockets
            self.assertTrue(srv_socks)

            addr = srv_socks[0].getsockname()

            tasks = [test_client(addr) for _ in range(TOTAL_CNT)]
            await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

            srv.close()
            await srv.wait_closed()

        self.loop.run_until_complete(start_server())
        self.assertEqual(CNT, TOTAL_CNT)

    def test_tcp_handle_exception_in_connection_made(self):
        # Test that if connection_made raises an exception,
        # 'create_connection' still returns.

        # Silence error logging
        self.loop.set_exception_handler(lambda *args: None)

        # Bind the futures to the test loop explicitly instead of relying
        # on the implicit "current event loop" lookup.
        fut = self.loop.create_future()
        connection_lost_called = self.loop.create_future()

        async def server(reader, writer):
            try:
                await reader.read()
            finally:
                writer.close()

        class Proto(asyncio.Protocol):
            def connection_made(self, tr):
                1 / 0

            def connection_lost(self, exc):
                connection_lost_called.set_result(exc)

        srv = self.loop.run_until_complete(asyncio.start_server(
            server,
            '127.0.0.1', 0,
            family=socket.AF_INET))

        async def runner():
            tr, pr = await asyncio.wait_for(
                self.loop.create_connection(
                    Proto, *srv.sockets[0].getsockname()),
                timeout=1.0)
            fut.set_result(None)
            tr.close()

        self.loop.run_until_complete(runner())
        srv.close()
        self.loop.run_until_complete(srv.wait_closed())
        self.loop.run_until_complete(fut)

        self.assertIsNone(
            self.loop.run_until_complete(connection_lost_called))

    def test_context_run_segfault(self):
        # Regression test (judging by its name) for a crash when a protocol
        # replaces itself via transport.set_protocol() from data_received().
        is_new = False
        done = self.loop.create_future()

        def server(sock):
            sock.sendall(b'hello')

        class Protocol(asyncio.Protocol):
            def __init__(self):
                self.transport = None

            def connection_made(self, transport):
                self.transport = transport

            def data_received(self, data):
                try:
                    # Deliberately drop the strong local reference to the
                    # current protocol before it is swapped out below.
                    self = weakref.ref(self)
                    nonlocal is_new
                    if is_new:
                        done.set_result(data)
                    else:
                        is_new = True
                        new_proto = Protocol()
                        self().transport.set_protocol(new_proto)
                        new_proto.connection_made(self().transport)
                        new_proto.data_received(data)
                except Exception as e:
                    done.set_exception(e)

        async def test(addr):
            await self.loop.create_connection(Protocol, *addr)
            data = await done
            self.assertEqual(data, b'hello')

        with self.tcp_server(server) as srv:
            self.loop.run_until_complete(test(srv.addr))
691
692
693class Test_UV_TCP(_TestTCP, tb.UVTestCase):
694
695    def test_create_server_buffered_1(self):
696        SIZE = 123123
697        eof = False
698        fut = asyncio.Future()
699
700        class Proto(asyncio.BaseProtocol):
701            def connection_made(self, tr):
702                self.tr = tr
703                self.recvd = b''
704                self.data = bytearray(50)
705                self.buf = memoryview(self.data)
706
707            def get_buffer(self, sizehint):
708                return self.buf
709
710            def buffer_updated(self, nbytes):
711                self.recvd += self.buf[:nbytes]
712                if self.recvd == b'a' * SIZE:
713                    self.tr.write(b'hello')
714
715            def eof_received(self):
716                nonlocal eof
717                eof = True
718
719            def connection_lost(self, exc):
720                fut.set_result(exc)
721
722        async def test():
723            port = tb.find_free_port()
724            srv = await self.loop.create_server(Proto, '127.0.0.1', port)
725
726            s = socket.socket(socket.AF_INET)
727            with s:
728                s.setblocking(False)
729                await self.loop.sock_connect(s, ('127.0.0.1', port))
730                await self.loop.sock_sendall(s, b'a' * SIZE)
731                d = await self.loop.sock_recv(s, 100)
732                self.assertEqual(d, b'hello')
733
734            srv.close()
735            await srv.wait_closed()
736
737        self.loop.run_until_complete(test())
738        self.loop.run_until_complete(fut)
739        self.assertTrue(eof)
740        self.assertIsNone(fut.result())
741
742    def test_create_server_buffered_2(self):
743        class ProtoExc(asyncio.BaseProtocol):
744            def __init__(self):
745                self._lost_exc = None
746
747            def get_buffer(self, sizehint):
748                1 / 0
749
750            def buffer_updated(self, nbytes):
751                pass
752
753            def connection_lost(self, exc):
754                self._lost_exc = exc
755
756            def eof_received(self):
757                pass
758
759        class ProtoZeroBuf1(asyncio.BaseProtocol):
760            def __init__(self):
761                self._lost_exc = None
762
763            def get_buffer(self, sizehint):
764                return bytearray(0)
765
766            def buffer_updated(self, nbytes):
767                pass
768
769            def connection_lost(self, exc):
770                self._lost_exc = exc
771
772            def eof_received(self):
773                pass
774
775        class ProtoZeroBuf2(asyncio.BaseProtocol):
776            def __init__(self):
777                self._lost_exc = None
778
779            def get_buffer(self, sizehint):
780                return memoryview(bytearray(0))
781
782            def buffer_updated(self, nbytes):
783                pass
784
785            def connection_lost(self, exc):
786                self._lost_exc = exc
787
788            def eof_received(self):
789                pass
790
791        class ProtoUpdatedError(asyncio.BaseProtocol):
792            def __init__(self):
793                self._lost_exc = None
794
795            def get_buffer(self, sizehint):
796                return memoryview(bytearray(100))
797
798            def buffer_updated(self, nbytes):
799                raise RuntimeError('oups')
800
801            def connection_lost(self, exc):
802                self._lost_exc = exc
803
804            def eof_received(self):
805                pass
806
807        async def test(proto_factory, exc_type, exc_re):
808            port = tb.find_free_port()
809            proto = proto_factory()
810            srv = await self.loop.create_server(
811                lambda: proto, '127.0.0.1', port)
812
813            try:
814                s = socket.socket(socket.AF_INET)
815                with s:
816                    s.setblocking(False)
817                    await self.loop.sock_connect(s, ('127.0.0.1', port))
818                    await self.loop.sock_sendall(s, b'a')
819                    d = await self.loop.sock_recv(s, 100)
820                    if not d:
821                        raise ConnectionResetError
822            except ConnectionResetError:
823                pass
824            else:
825                self.fail("server didn't abort the connection")
826                return
827            finally:
828                srv.close()
829                await srv.wait_closed()
830
831            if proto._lost_exc is None:
832                self.fail("connection_lost() was not called")
833                return
834
835            with self.assertRaisesRegex(exc_type, exc_re):
836                raise proto._lost_exc
837
838        self.loop.set_exception_handler(lambda loop, ctx: None)
839
840        self.loop.run_until_complete(
841            test(ProtoExc, RuntimeError, 'unhandled error .* get_buffer'))
842
843        self.loop.run_until_complete(
844            test(ProtoZeroBuf1, RuntimeError, 'unhandled error .* get_buffer'))
845
846        self.loop.run_until_complete(
847            test(ProtoZeroBuf2, RuntimeError, 'unhandled error .* get_buffer'))
848
849        self.loop.run_until_complete(
850            test(ProtoUpdatedError, RuntimeError, r'^oups$'))
851
852    def test_transport_get_extra_info(self):
853        # This tests is only for uvloop.  asyncio should pass it
854        # too in Python 3.6.
855
856        fut = asyncio.Future()
857
858        async def handle_client(reader, writer):
859            with self.assertRaises(asyncio.IncompleteReadError):
860                await reader.readexactly(4)
861            writer.close()
862
863            # Previously, when we used socket.fromfd to create a socket
864            # for UVTransports (to make get_extra_info() work), a duplicate
865            # of the socket was created, preventing UVTransport from being
866            # properly closed.
867            # This test ensures that server handle will receive an EOF
868            # and finish the request.
869            fut.set_result(None)
870
871        async def test_client(addr):
872            t, p = await self.loop.create_connection(
873                lambda: asyncio.Protocol(), *addr)
874
875            if hasattr(t, 'get_protocol'):
876                p2 = asyncio.Protocol()
877                self.assertIs(t.get_protocol(), p)
878                t.set_protocol(p2)
879                self.assertIs(t.get_protocol(), p2)
880                t.set_protocol(p)
881
882            self.assertFalse(t._paused)
883            self.assertTrue(t.is_reading())
884            t.pause_reading()
885            t.pause_reading()  # Check that it's OK to call it 2nd time.
886            self.assertTrue(t._paused)
887            self.assertFalse(t.is_reading())
888            t.resume_reading()
889            t.resume_reading()  # Check that it's OK to call it 2nd time.
890            self.assertFalse(t._paused)
891            self.assertTrue(t.is_reading())
892
893            sock = t.get_extra_info('socket')
894            self.assertIs(sock, t.get_extra_info('socket'))
895            sockname = sock.getsockname()
896            peername = sock.getpeername()
897
898            with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
899                self.loop.add_writer(sock.fileno(), lambda: None)
900            with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
901                self.loop.remove_writer(sock.fileno())
902            with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
903                self.loop.add_reader(sock.fileno(), lambda: None)
904            with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
905                self.loop.remove_reader(sock.fileno())
906
907            self.assertEqual(t.get_extra_info('sockname'),
908                             sockname)
909            self.assertEqual(t.get_extra_info('peername'),
910                             peername)
911
912            t.write(b'OK')  # We want server to fail.
913
914            self.assertFalse(t._closing)
915            t.abort()
916            self.assertTrue(t._closing)
917
918            self.assertFalse(t.is_reading())
919            # Check that pause_reading and resume_reading don't raise
920            # errors if called after the transport is closed.
921            t.pause_reading()
922            t.resume_reading()
923
924            await fut
925
926            # Test that peername and sockname are available after
927            # the transport is closed.
928            self.assertEqual(t.get_extra_info('peername'),
929                             peername)
930            self.assertEqual(t.get_extra_info('sockname'),
931                             sockname)
932
933        async def start_server():
934            srv = await asyncio.start_server(
935                handle_client,
936                '127.0.0.1', 0,
937                family=socket.AF_INET)
938
939            addr = srv.sockets[0].getsockname()
940            await test_client(addr)
941
942            srv.close()
943            await srv.wait_closed()
944
945        self.loop.run_until_complete(start_server())
946
947    def test_create_server_float_backlog(self):
948        # asyncio spits out a warning we cannot suppress
949
950        async def runner(bl):
951            await self.loop.create_server(
952                asyncio.Protocol,
953                None, 0, backlog=bl)
954
955        for bl in (1.1, '1'):
956            with self.subTest(backlog=bl):
957                with self.assertRaisesRegex(TypeError, 'integer'):
958                    self.loop.run_until_complete(runner(bl))
959
960    def test_many_small_writes(self):
961        N = 10000
962        TOTAL = 0
963
964        fut = self.loop.create_future()
965
966        async def server(reader, writer):
967            nonlocal TOTAL
968            while True:
969                d = await reader.read(10000)
970                if not d:
971                    break
972                TOTAL += len(d)
973            fut.set_result(True)
974            writer.close()
975
976        async def run():
977            srv = await asyncio.start_server(
978                server,
979                '127.0.0.1', 0,
980                family=socket.AF_INET)
981
982            addr = srv.sockets[0].getsockname()
983            r, w = await asyncio.open_connection(*addr)
984
985            DATA = b'x' * 102400
986
987            # Test _StreamWriteContext with short sequences of writes
988            w.write(DATA)
989            await w.drain()
990            for _ in range(3):
991                w.write(DATA)
992            await w.drain()
993            for _ in range(10):
994                w.write(DATA)
995            await w.drain()
996
997            for _ in range(N):
998                w.write(DATA)
999
1000                try:
1001                    w.write('a')
1002                except TypeError:
1003                    pass
1004
1005            await w.drain()
1006            for _ in range(N):
1007                w.write(DATA)
1008                await w.drain()
1009
1010            w.close()
1011            await fut
1012            await self.wait_closed(w)
1013
1014            srv.close()
1015            await srv.wait_closed()
1016
1017            self.assertEqual(TOTAL, N * 2 * len(DATA) + 14 * len(DATA))
1018
1019        self.loop.run_until_complete(run())
1020
    @unittest.skipIf(sys.version_info[:3] >= (3, 8, 0),
                     "3.8 has a different method of GCing unclosed streams")
    def test_tcp_handle_unclosed_gc(self):
        """A TCP transport left unclosed must emit a ResourceWarning.

        Expects a ``ResourceWarning`` matching ``rx`` while the loop runs
        and garbage collection is forced. Skipped on Python 3.8+.
        """
        fut = self.loop.create_future()

        async def server(reader, writer):
            # Abort the server side right away, then let the client go on.
            writer.transport.abort()
            fut.set_result(True)

        async def run():
            addr = srv.sockets[0].getsockname()
            # The (reader, writer) pair is discarded without being closed;
            # presumably this is the unclosed resource the warning reports
            # (TODO confirm).
            await asyncio.open_connection(*addr)
            await fut
            srv.close()
            await srv.wait_closed()

        srv = self.loop.run_until_complete(asyncio.start_server(
            server,
            '127.0.0.1', 0,
            family=socket.AF_INET))

        if self.loop.get_debug():
            # In debug mode the warning is expected to carry the object's
            # creation traceback, which mentions this test method's name.
            rx = r'unclosed resource <TCP.*; ' \
                 r'object created at(.|\n)*test_tcp_handle_unclosed_gc'
        else:
            rx = r'unclosed resource <TCP.*'

        with self.assertWarnsRegex(ResourceWarning, rx):
            self.loop.create_task(run())
            self.loop.run_until_complete(srv.wait_closed())
            self.loop.run_until_complete(asyncio.sleep(0.1))

            # Rebind ``srv`` (a closure cell shared with ``run``) to None and
            # force collection, so leaked objects can be finalized while the
            # warning is still being captured.
            srv = None
            gc.collect()
            gc.collect()
            gc.collect()

            self.loop.run_until_complete(asyncio.sleep(0.1))

        # Since one TCPTransport handle wasn't closed correctly,
        # we need to disable this check:
        self.skip_unclosed_handles_check()
1063
1064    def test_tcp_handle_abort_in_connection_made(self):
1065        async def server(reader, writer):
1066            try:
1067                await reader.read()
1068            finally:
1069                writer.close()
1070
1071        class Proto(asyncio.Protocol):
1072            def connection_made(self, tr):
1073                tr.abort()
1074
1075        srv = self.loop.run_until_complete(asyncio.start_server(
1076            server,
1077            '127.0.0.1', 0,
1078            family=socket.AF_INET))
1079
1080        async def runner():
1081            tr, pr = await asyncio.wait_for(
1082                self.loop.create_connection(
1083                    Proto, *srv.sockets[0].getsockname()),
1084                timeout=1.0)
1085
1086            # Asyncio would return a closed socket, which we
1087            # can't do: the transport was aborted, hence there
1088            # is no FD to attach a socket to (to make
1089            # get_extra_info() work).
1090            self.assertIsNone(tr.get_extra_info('socket'))
1091            tr.close()
1092
1093        self.loop.run_until_complete(runner())
1094        srv.close()
1095        self.loop.run_until_complete(srv.wait_closed())
1096
1097    def test_connect_accepted_socket_ssl_args(self):
1098        with self.assertRaisesRegex(
1099                ValueError, 'ssl_handshake_timeout is only meaningful'):
1100            with socket.socket() as s:
1101                self.loop.run_until_complete(
1102                    self.loop.connect_accepted_socket(
1103                        (lambda: None),
1104                        s,
1105                        ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT
1106                    )
1107                )
1108
1109    def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
1110        loop = self.loop
1111
1112        class MyProto(MyBaseProto):
1113
1114            def connection_lost(self, exc):
1115                super().connection_lost(exc)
1116                loop.call_soon(loop.stop)
1117
1118            def data_received(self, data):
1119                super().data_received(data)
1120                self.transport.write(expected_response)
1121
1122        lsock = socket.socket(socket.AF_INET)
1123        lsock.bind(('127.0.0.1', 0))
1124        lsock.listen(1)
1125        addr = lsock.getsockname()
1126
1127        message = b'test data'
1128        response = None
1129        expected_response = b'roger'
1130
1131        def client():
1132            nonlocal response
1133            try:
1134                csock = socket.socket(socket.AF_INET)
1135                if client_ssl is not None:
1136                    csock = client_ssl.wrap_socket(csock)
1137                csock.connect(addr)
1138                csock.sendall(message)
1139                response = csock.recv(99)
1140                csock.close()
1141            except Exception as exc:
1142                print(
1143                    "Failure in client thread in test_connect_accepted_socket",
1144                    exc)
1145
1146        thread = threading.Thread(target=client, daemon=True)
1147        thread.start()
1148
1149        conn, _ = lsock.accept()
1150        proto = MyProto(loop=loop)
1151        proto.loop = loop
1152
1153        extras = {}
1154        if server_ssl:
1155            extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT)
1156
1157        f = loop.create_task(
1158            loop.connect_accepted_socket(
1159                (lambda: proto), conn, ssl=server_ssl,
1160                **extras))
1161        loop.run_forever()
1162        conn.close()
1163        lsock.close()
1164
1165        thread.join(1)
1166        self.assertFalse(thread.is_alive())
1167        self.assertEqual(proto.state, 'CLOSED')
1168        self.assertEqual(proto.nbytes, len(message))
1169        self.assertEqual(response, expected_response)
1170        tr, _ = f.result()
1171
1172        if server_ssl:
1173            self.assertIn('SSL', tr.__class__.__name__)
1174
1175        tr.close()
1176        # let it close
1177        self.loop.run_until_complete(asyncio.sleep(0.1))
1178
1179    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
1180    def test_create_connection_wrong_sock(self):
1181        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1182        with sock:
1183            coro = self.loop.create_connection(MyBaseProto, sock=sock)
1184            with self.assertRaisesRegex(ValueError,
1185                                        'A Stream Socket was expected'):
1186                self.loop.run_until_complete(coro)
1187
1188    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
1189    def test_create_server_wrong_sock(self):
1190        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1191        with sock:
1192            coro = self.loop.create_server(MyBaseProto, sock=sock)
1193            with self.assertRaisesRegex(ValueError,
1194                                        'A Stream Socket was expected'):
1195                self.loop.run_until_complete(coro)
1196
1197    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
1198                         'no socket.SOCK_NONBLOCK (linux only)')
1199    def test_create_server_stream_bittype(self):
1200        sock = socket.socket(
1201            socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
1202        with sock:
1203            coro = self.loop.create_server(lambda: None, sock=sock)
1204            srv = self.loop.run_until_complete(coro)
1205            srv.close()
1206            self.loop.run_until_complete(srv.wait_closed())
1207
1208    def test_flowcontrol_mixin_set_write_limits(self):
1209        async def client(addr):
1210            paused = False
1211
1212            class Protocol(asyncio.Protocol):
1213                def pause_writing(self):
1214                    nonlocal paused
1215                    paused = True
1216
1217                def resume_writing(self):
1218                    nonlocal paused
1219                    paused = False
1220
1221            t, p = await self.loop.create_connection(Protocol, *addr)
1222
1223            t.write(b'q' * 512)
1224            self.assertEqual(t.get_write_buffer_size(), 512)
1225
1226            t.set_write_buffer_limits(low=16385)
1227            self.assertFalse(paused)
1228            self.assertEqual(t.get_write_buffer_limits(), (16385, 65540))
1229
1230            with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
1231                t.set_write_buffer_limits(high=0, low=1)
1232
1233            t.set_write_buffer_limits(high=1024, low=128)
1234            self.assertFalse(paused)
1235            self.assertEqual(t.get_write_buffer_limits(), (128, 1024))
1236
1237            t.set_write_buffer_limits(high=256, low=128)
1238            self.assertTrue(paused)
1239            self.assertEqual(t.get_write_buffer_limits(), (128, 256))
1240
1241            t.close()
1242
1243        with self.tcp_server(lambda sock: sock.recv_all(1),
1244                             max_clients=1,
1245                             backlog=1) as srv:
1246            self.loop.run_until_complete(client(srv.addr))
1247
1248
class Test_AIO_TCP(_TestTCP, tb.AIOTestCase):
    """Run the shared ``_TestTCP`` tests under ``tb.AIOTestCase``."""
1251
1252
1253class _TestSSL(tb.SSLTestCase):
1254
    # Certificate / private-key files (resolved via tb._cert_fullname() from
    # this module's __file__) used to build the server-side SSL contexts.
    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    # Size in bytes of the HELLO_MSG payloads used by the start_tls tests.
    PAYLOAD_SIZE = 1024 * 100
    # Timeout in seconds for server sockets and tcp_server() in those tests.
    TIMEOUT = 60
1260
1261    def test_create_server_ssl_1(self):
1262        CNT = 0           # number of clients that were successful
1263        TOTAL_CNT = 25    # total number of clients that test will create
1264        TIMEOUT = 10.0    # timeout for this test
1265
1266        A_DATA = b'A' * 1024 * 1024
1267        B_DATA = b'B' * 1024 * 1024
1268
1269        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
1270        client_sslctx = self._create_client_ssl_context()
1271
1272        clients = []
1273
1274        async def handle_client(reader, writer):
1275            nonlocal CNT
1276
1277            data = await reader.readexactly(len(A_DATA))
1278            self.assertEqual(data, A_DATA)
1279            writer.write(b'OK')
1280
1281            data = await reader.readexactly(len(B_DATA))
1282            self.assertEqual(data, B_DATA)
1283            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
1284
1285            await writer.drain()
1286            writer.close()
1287
1288            CNT += 1
1289
1290        async def test_client(addr):
1291            fut = asyncio.Future()
1292
1293            def prog(sock):
1294                try:
1295                    sock.starttls(client_sslctx)
1296                    sock.connect(addr)
1297                    sock.send(A_DATA)
1298
1299                    data = sock.recv_all(2)
1300                    self.assertEqual(data, b'OK')
1301
1302                    sock.send(B_DATA)
1303                    data = sock.recv_all(4)
1304                    self.assertEqual(data, b'SPAM')
1305
1306                    sock.close()
1307
1308                except Exception as ex:
1309                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
1310                else:
1311                    self.loop.call_soon_threadsafe(fut.set_result, None)
1312
1313            client = self.tcp_client(prog)
1314            client.start()
1315            clients.append(client)
1316
1317            await fut
1318
1319        async def start_server():
1320            extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT)
1321
1322            srv = await asyncio.start_server(
1323                handle_client,
1324                '127.0.0.1', 0,
1325                family=socket.AF_INET,
1326                ssl=sslctx,
1327                **extras)
1328
1329            try:
1330                srv_socks = srv.sockets
1331                self.assertTrue(srv_socks)
1332
1333                addr = srv_socks[0].getsockname()
1334
1335                tasks = []
1336                for _ in range(TOTAL_CNT):
1337                    tasks.append(test_client(addr))
1338
1339                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
1340
1341            finally:
1342                self.loop.call_soon(srv.close)
1343                await srv.wait_closed()
1344
1345        with self._silence_eof_received_warning():
1346            self.loop.run_until_complete(start_server())
1347
1348        self.assertEqual(CNT, TOTAL_CNT)
1349
1350        for client in clients:
1351            client.stop()
1352
1353    def test_create_connection_ssl_1(self):
1354        if self.implementation == 'asyncio':
1355            # Don't crash on asyncio errors
1356            self.loop.set_exception_handler(None)
1357
1358        CNT = 0
1359        TOTAL_CNT = 25
1360
1361        A_DATA = b'A' * 1024 * 1024
1362        B_DATA = b'B' * 1024 * 1024
1363
1364        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
1365        client_sslctx = self._create_client_ssl_context()
1366
1367        def server(sock):
1368            sock.starttls(
1369                sslctx,
1370                server_side=True)
1371
1372            data = sock.recv_all(len(A_DATA))
1373            self.assertEqual(data, A_DATA)
1374            sock.send(b'OK')
1375
1376            data = sock.recv_all(len(B_DATA))
1377            self.assertEqual(data, B_DATA)
1378            sock.send(b'SPAM')
1379
1380            sock.close()
1381
1382        async def client(addr):
1383            extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT)
1384
1385            reader, writer = await asyncio.open_connection(
1386                *addr,
1387                ssl=client_sslctx,
1388                server_hostname='',
1389                **extras)
1390
1391            writer.write(A_DATA)
1392            self.assertEqual(await reader.readexactly(2), b'OK')
1393
1394            writer.write(B_DATA)
1395            self.assertEqual(await reader.readexactly(4), b'SPAM')
1396
1397            nonlocal CNT
1398            CNT += 1
1399
1400            writer.close()
1401            await self.wait_closed(writer)
1402
1403        async def client_sock(addr):
1404            sock = socket.socket()
1405            sock.connect(addr)
1406            reader, writer = await asyncio.open_connection(
1407                sock=sock,
1408                ssl=client_sslctx,
1409                server_hostname='')
1410
1411            writer.write(A_DATA)
1412            self.assertEqual(await reader.readexactly(2), b'OK')
1413
1414            writer.write(B_DATA)
1415            self.assertEqual(await reader.readexactly(4), b'SPAM')
1416
1417            nonlocal CNT
1418            CNT += 1
1419
1420            writer.close()
1421            await self.wait_closed(writer)
1422            sock.close()
1423
1424        def run(coro):
1425            nonlocal CNT
1426            CNT = 0
1427
1428            with self.tcp_server(server,
1429                                 max_clients=TOTAL_CNT,
1430                                 backlog=TOTAL_CNT) as srv:
1431                tasks = []
1432                for _ in range(TOTAL_CNT):
1433                    tasks.append(coro(srv.addr))
1434
1435                self.loop.run_until_complete(asyncio.gather(*tasks))
1436
1437            self.assertEqual(CNT, TOTAL_CNT)
1438
1439        with self._silence_eof_received_warning():
1440            run(client)
1441
1442        with self._silence_eof_received_warning():
1443            run(client_sock)
1444
1445    def test_create_connection_ssl_slow_handshake(self):
1446        if self.implementation == 'asyncio':
1447            raise unittest.SkipTest()
1448
1449        client_sslctx = self._create_client_ssl_context()
1450
1451        # silence error logger
1452        self.loop.set_exception_handler(lambda *args: None)
1453
1454        def server(sock):
1455            try:
1456                sock.recv_all(1024 * 1024)
1457            except ConnectionAbortedError:
1458                pass
1459            finally:
1460                sock.close()
1461
1462        async def client(addr):
1463            reader, writer = await asyncio.open_connection(
1464                *addr,
1465                ssl=client_sslctx,
1466                server_hostname='',
1467                ssl_handshake_timeout=1.0)
1468            writer.close()
1469            await self.wait_closed(writer)
1470
1471        with self.tcp_server(server,
1472                             max_clients=1,
1473                             backlog=1) as srv:
1474
1475            with self.assertRaisesRegex(
1476                    ConnectionAbortedError,
1477                    r'SSL handshake.*is taking longer'):
1478
1479                self.loop.run_until_complete(client(srv.addr))
1480
1481    def test_create_connection_ssl_failed_certificate(self):
1482        if self.implementation == 'asyncio':
1483            raise unittest.SkipTest()
1484
1485        # silence error logger
1486        self.loop.set_exception_handler(lambda *args: None)
1487
1488        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
1489        client_sslctx = self._create_client_ssl_context(disable_verify=False)
1490
1491        def server(sock):
1492            try:
1493                sock.starttls(
1494                    sslctx,
1495                    server_side=True)
1496                sock.connect()
1497            except (ssl.SSLError, OSError):
1498                pass
1499            finally:
1500                sock.close()
1501
1502        async def client(addr):
1503            reader, writer = await asyncio.open_connection(
1504                *addr,
1505                ssl=client_sslctx,
1506                server_hostname='',
1507                ssl_handshake_timeout=1.0)
1508            writer.close()
1509            await self.wait_closed(writer)
1510
1511        with self.tcp_server(server,
1512                             max_clients=1,
1513                             backlog=1) as srv:
1514
1515            with self.assertRaises(ssl.SSLCertVerificationError):
1516                self.loop.run_until_complete(client(srv.addr))
1517
1518    def test_start_tls_wrong_args(self):
1519        if self.implementation == 'asyncio':
1520            raise unittest.SkipTest()
1521
1522        async def main():
1523            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
1524                await self.loop.start_tls(None, None, None)
1525
1526            sslctx = self._create_server_ssl_context(
1527                self.ONLYCERT, self.ONLYKEY)
1528            with self.assertRaisesRegex(TypeError, 'is not supported'):
1529                await self.loop.start_tls(None, None, sslctx)
1530
1531        self.loop.run_until_complete(main())
1532
1533    def test_ssl_handshake_timeout(self):
1534        if self.implementation == 'asyncio':
1535            raise unittest.SkipTest()
1536
1537        # bpo-29970: Check that a connection is aborted if handshake is not
1538        # completed in timeout period, instead of remaining open indefinitely
1539        client_sslctx = self._create_client_ssl_context()
1540
1541        # silence error logger
1542        messages = []
1543        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
1544
1545        server_side_aborted = False
1546
1547        def server(sock):
1548            nonlocal server_side_aborted
1549            try:
1550                sock.recv_all(1024 * 1024)
1551            except ConnectionAbortedError:
1552                server_side_aborted = True
1553            finally:
1554                sock.close()
1555
1556        async def client(addr):
1557            await asyncio.wait_for(
1558                self.loop.create_connection(
1559                    asyncio.Protocol,
1560                    *addr,
1561                    ssl=client_sslctx,
1562                    server_hostname='',
1563                    ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT
1564                ),
1565                0.5
1566            )
1567
1568        with self.tcp_server(server,
1569                             max_clients=1,
1570                             backlog=1) as srv:
1571
1572            with self.assertRaises(asyncio.TimeoutError):
1573                self.loop.run_until_complete(client(srv.addr))
1574
1575        self.assertTrue(server_side_aborted)
1576
1577        # Python issue #23197: cancelling a handshake must not raise an
1578        # exception or log an error, even if the handshake failed
1579        self.assertEqual(messages, [])
1580
1581    def test_ssl_handshake_connection_lost(self):
1582        # #246: make sure that no connection_lost() is called before
1583        # connection_made() is called first
1584
1585        client_sslctx = self._create_client_ssl_context()
1586
1587        # silence error logger
1588        self.loop.set_exception_handler(lambda loop, ctx: None)
1589
1590        connection_made_called = False
1591        connection_lost_called = False
1592
1593        def server(sock):
1594            sock.recv(1024)
1595            # break the connection during handshake
1596            sock.close()
1597
1598        class ClientProto(asyncio.Protocol):
1599            def connection_made(self, transport):
1600                nonlocal connection_made_called
1601                connection_made_called = True
1602
1603            def connection_lost(self, exc):
1604                nonlocal connection_lost_called
1605                connection_lost_called = True
1606
1607        async def client(addr):
1608            await self.loop.create_connection(
1609                ClientProto,
1610                *addr,
1611                ssl=client_sslctx,
1612                server_hostname=''),
1613
1614        with self.tcp_server(server,
1615                             max_clients=1,
1616                             backlog=1) as srv:
1617
1618            with self.assertRaises(ConnectionResetError):
1619                self.loop.run_until_complete(client(srv.addr))
1620
1621        if connection_lost_called:
1622            if connection_made_called:
1623                self.fail("unexpected call to connection_lost()")
1624            else:
1625                self.fail("unexpected call to connection_lost() without"
1626                          "calling connection_made()")
1627        elif connection_made_called:
1628            self.fail("unexpected call to connection_made()")
1629
1630    def test_ssl_connect_accepted_socket(self):
1631        if hasattr(ssl, 'PROTOCOL_TLS'):
1632            proto = ssl.PROTOCOL_TLS
1633        else:
1634            proto = ssl.PROTOCOL_SSLv23
1635        server_context = ssl.SSLContext(proto)
1636        server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY)
1637        if hasattr(server_context, 'check_hostname'):
1638            server_context.check_hostname = False
1639        server_context.verify_mode = ssl.CERT_NONE
1640
1641        client_context = ssl.SSLContext(proto)
1642        if hasattr(server_context, 'check_hostname'):
1643            client_context.check_hostname = False
1644        client_context.verify_mode = ssl.CERT_NONE
1645
1646        Test_UV_TCP.test_connect_accepted_socket(
1647            self, server_context, client_context)
1648
1649    def test_start_tls_client_corrupted_ssl(self):
1650        if self.implementation == 'asyncio':
1651            raise unittest.SkipTest()
1652
1653        self.loop.set_exception_handler(lambda loop, ctx: None)
1654
1655        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
1656        client_sslctx = self._create_client_ssl_context()
1657
1658        def server(sock):
1659            orig_sock = sock.dup()
1660            try:
1661                sock.starttls(
1662                    sslctx,
1663                    server_side=True)
1664                sock.sendall(b'A\n')
1665                sock.recv_all(1)
1666                orig_sock.send(b'please corrupt the SSL connection')
1667            except ssl.SSLError:
1668                pass
1669            finally:
1670                sock.close()
1671                orig_sock.close()
1672
1673        async def client(addr):
1674            reader, writer = await asyncio.open_connection(
1675                *addr,
1676                ssl=client_sslctx,
1677                server_hostname='')
1678
1679            self.assertEqual(await reader.readline(), b'A\n')
1680            writer.write(b'B')
1681            with self.assertRaises(ssl.SSLError):
1682                await reader.readline()
1683            writer.close()
1684            try:
1685                await self.wait_closed(writer)
1686            except ssl.SSLError:
1687                pass
1688            return 'OK'
1689
1690        with self.tcp_server(server,
1691                             max_clients=1,
1692                             backlog=1) as srv:
1693
1694            res = self.loop.run_until_complete(client(srv.addr))
1695
1696        self.assertEqual(res, 'OK')
1697
1698    def test_start_tls_client_reg_proto_1(self):
1699        if self.implementation == 'asyncio':
1700            raise unittest.SkipTest()
1701
1702        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1703
1704        server_context = self._create_server_ssl_context(
1705            self.ONLYCERT, self.ONLYKEY)
1706        client_context = self._create_client_ssl_context()
1707
1708        def serve(sock):
1709            sock.settimeout(self.TIMEOUT)
1710
1711            data = sock.recv_all(len(HELLO_MSG))
1712            self.assertEqual(len(data), len(HELLO_MSG))
1713
1714            sock.starttls(server_context, server_side=True)
1715
1716            sock.sendall(b'O')
1717            data = sock.recv_all(len(HELLO_MSG))
1718            self.assertEqual(len(data), len(HELLO_MSG))
1719
1720            sock.unwrap()
1721            sock.close()
1722
1723        class ClientProto(asyncio.Protocol):
1724            def __init__(self, on_data, on_eof):
1725                self.on_data = on_data
1726                self.on_eof = on_eof
1727                self.con_made_cnt = 0
1728
1729            def connection_made(proto, tr):
1730                proto.con_made_cnt += 1
1731                # Ensure connection_made gets called only once.
1732                self.assertEqual(proto.con_made_cnt, 1)
1733
1734            def data_received(self, data):
1735                self.on_data.set_result(data)
1736
1737            def eof_received(self):
1738                self.on_eof.set_result(True)
1739
1740        async def client(addr):
1741            await asyncio.sleep(0.5)
1742
1743            on_data = self.loop.create_future()
1744            on_eof = self.loop.create_future()
1745
1746            tr, proto = await self.loop.create_connection(
1747                lambda: ClientProto(on_data, on_eof), *addr)
1748
1749            tr.write(HELLO_MSG)
1750            new_tr = await self.loop.start_tls(tr, proto, client_context)
1751
1752            self.assertEqual(await on_data, b'O')
1753            new_tr.write(HELLO_MSG)
1754            await on_eof
1755
1756            new_tr.close()
1757
1758        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
1759            self.loop.run_until_complete(
1760                asyncio.wait_for(client(srv.addr), timeout=10))
1761
    def test_create_connection_memory_leak(self):
        """The client SSLContext must not be kept alive by the connection.

        The protocol stores its SSL transport on itself (``proto.tr``).
        After the exchange, the test drops its own reference to
        ``client_context`` and expects the context to be gone immediately.
        No ``gc.collect()`` is run, so a leftover reference (for example a
        reference cycle) would presumably make the final assertion fail.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = self._create_server_ssl_context(
            self.ONLYCERT, self.ONLYKEY)
        client_context = self._create_client_ssl_context()

        def serve(sock):
            # Server side: TLS from the start, send b'O', read one
            # HELLO_MSG, then unwrap() and close.
            sock.settimeout(self.TIMEOUT)

            sock.starttls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.unwrap()
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                # (``self`` here is the enclosing test case.)
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr), timeout=10))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        # Rebinding the name (a closure cell shared with ``client``) drops
        # the test's strong reference; the weakref must then be dead.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())
1827
1828    def test_start_tls_client_buf_proto_1(self):
1829        if self.implementation == 'asyncio':
1830            raise unittest.SkipTest()
1831
1832        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1833
1834        server_context = self._create_server_ssl_context(
1835            self.ONLYCERT, self.ONLYKEY)
1836        client_context = self._create_client_ssl_context()
1837
1838        client_con_made_calls = 0
1839
1840        def serve(sock):
1841            sock.settimeout(self.TIMEOUT)
1842
1843            data = sock.recv_all(len(HELLO_MSG))
1844            self.assertEqual(len(data), len(HELLO_MSG))
1845
1846            sock.starttls(server_context, server_side=True)
1847
1848            sock.sendall(b'O')
1849            data = sock.recv_all(len(HELLO_MSG))
1850            self.assertEqual(len(data), len(HELLO_MSG))
1851
1852            sock.sendall(b'2')
1853            data = sock.recv_all(len(HELLO_MSG))
1854            self.assertEqual(len(data), len(HELLO_MSG))
1855
1856            sock.unwrap()
1857            sock.close()
1858
1859        class ClientProtoFirst(asyncio.BaseProtocol):
1860            def __init__(self, on_data):
1861                self.on_data = on_data
1862                self.buf = bytearray(1)
1863
1864            def connection_made(self, tr):
1865                nonlocal client_con_made_calls
1866                client_con_made_calls += 1
1867
1868            def get_buffer(self, sizehint):
1869                return self.buf
1870
1871            def buffer_updated(self, nsize):
1872                assert nsize == 1
1873                self.on_data.set_result(bytes(self.buf[:nsize]))
1874
1875            def eof_received(self):
1876                pass
1877
1878        class ClientProtoSecond(asyncio.Protocol):
1879            def __init__(self, on_data, on_eof):
1880                self.on_data = on_data
1881                self.on_eof = on_eof
1882                self.con_made_cnt = 0
1883
1884            def connection_made(self, tr):
1885                nonlocal client_con_made_calls
1886                client_con_made_calls += 1
1887
1888            def data_received(self, data):
1889                self.on_data.set_result(data)
1890
1891            def eof_received(self):
1892                self.on_eof.set_result(True)
1893
1894        async def client(addr):
1895            await asyncio.sleep(0.5)
1896
1897            on_data1 = self.loop.create_future()
1898            on_data2 = self.loop.create_future()
1899            on_eof = self.loop.create_future()
1900
1901            tr, proto = await self.loop.create_connection(
1902                lambda: ClientProtoFirst(on_data1), *addr)
1903
1904            tr.write(HELLO_MSG)
1905            new_tr = await self.loop.start_tls(tr, proto, client_context)
1906
1907            self.assertEqual(await on_data1, b'O')
1908            new_tr.write(HELLO_MSG)
1909
1910            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
1911            self.assertEqual(await on_data2, b'2')
1912            new_tr.write(HELLO_MSG)
1913            await on_eof
1914
1915            new_tr.close()
1916
1917            # connection_made() should be called only once -- when
1918            # we establish connection for the first time. Start TLS
1919            # doesn't call connection_made() on application protocols.
1920            self.assertEqual(client_con_made_calls, 1)
1921
1922        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
1923            self.loop.run_until_complete(
1924                asyncio.wait_for(client(srv.addr),
1925                                 timeout=self.TIMEOUT))
1926
1927    def test_start_tls_slow_client_cancel(self):
1928        if self.implementation == 'asyncio':
1929            raise unittest.SkipTest()
1930
1931        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1932
1933        client_context = self._create_client_ssl_context()
1934        server_waits_on_handshake = self.loop.create_future()
1935
1936        def serve(sock):
1937            sock.settimeout(self.TIMEOUT)
1938
1939            data = sock.recv_all(len(HELLO_MSG))
1940            self.assertEqual(len(data), len(HELLO_MSG))
1941
1942            try:
1943                self.loop.call_soon_threadsafe(
1944                    server_waits_on_handshake.set_result, None)
1945                data = sock.recv_all(1024 * 1024)
1946            except ConnectionAbortedError:
1947                pass
1948            finally:
1949                sock.close()
1950
1951        class ClientProto(asyncio.Protocol):
1952            def __init__(self, on_data, on_eof):
1953                self.on_data = on_data
1954                self.on_eof = on_eof
1955                self.con_made_cnt = 0
1956
1957            def connection_made(proto, tr):
1958                proto.con_made_cnt += 1
1959                # Ensure connection_made gets called only once.
1960                self.assertEqual(proto.con_made_cnt, 1)
1961
1962            def data_received(self, data):
1963                self.on_data.set_result(data)
1964
1965            def eof_received(self):
1966                self.on_eof.set_result(True)
1967
1968        async def client(addr):
1969            await asyncio.sleep(0.5)
1970
1971            on_data = self.loop.create_future()
1972            on_eof = self.loop.create_future()
1973
1974            tr, proto = await self.loop.create_connection(
1975                lambda: ClientProto(on_data, on_eof), *addr)
1976
1977            tr.write(HELLO_MSG)
1978
1979            await server_waits_on_handshake
1980
1981            with self.assertRaises(asyncio.TimeoutError):
1982                await asyncio.wait_for(
1983                    self.loop.start_tls(tr, proto, client_context),
1984                    0.5)
1985
1986        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
1987            self.loop.run_until_complete(
1988                asyncio.wait_for(client(srv.addr), timeout=10))
1989
1990    def test_start_tls_server_1(self):
1991        if self.implementation == 'asyncio':
1992            raise unittest.SkipTest()
1993
1994        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
1995
1996        server_context = self._create_server_ssl_context(
1997            self.ONLYCERT, self.ONLYKEY)
1998        client_context = self._create_client_ssl_context()
1999
2000        def client(sock, addr):
2001            sock.settimeout(self.TIMEOUT)
2002
2003            sock.connect(addr)
2004            data = sock.recv_all(len(HELLO_MSG))
2005            self.assertEqual(len(data), len(HELLO_MSG))
2006
2007            sock.starttls(client_context)
2008            sock.sendall(HELLO_MSG)
2009
2010            sock.unwrap()
2011            sock.close()
2012
2013        class ServerProto(asyncio.Protocol):
2014            def __init__(self, on_con, on_eof, on_con_lost):
2015                self.on_con = on_con
2016                self.on_eof = on_eof
2017                self.on_con_lost = on_con_lost
2018                self.data = b''
2019
2020            def connection_made(self, tr):
2021                self.on_con.set_result(tr)
2022
2023            def data_received(self, data):
2024                self.data += data
2025
2026            def eof_received(self):
2027                self.on_eof.set_result(1)
2028
2029            def connection_lost(self, exc):
2030                if exc is None:
2031                    self.on_con_lost.set_result(None)
2032                else:
2033                    self.on_con_lost.set_exception(exc)
2034
2035        async def main(proto, on_con, on_eof, on_con_lost):
2036            tr = await on_con
2037            tr.write(HELLO_MSG)
2038
2039            self.assertEqual(proto.data, b'')
2040
2041            new_tr = await self.loop.start_tls(
2042                tr, proto, server_context,
2043                server_side=True,
2044                ssl_handshake_timeout=self.TIMEOUT)
2045
2046            await on_eof
2047            await on_con_lost
2048            self.assertEqual(proto.data, HELLO_MSG)
2049            new_tr.close()
2050
2051        async def run_main():
2052            on_con = self.loop.create_future()
2053            on_eof = self.loop.create_future()
2054            on_con_lost = self.loop.create_future()
2055            proto = ServerProto(on_con, on_eof, on_con_lost)
2056
2057            server = await self.loop.create_server(
2058                lambda: proto, '127.0.0.1', 0)
2059            addr = server.sockets[0].getsockname()
2060
2061            with self.tcp_client(lambda sock: client(sock, addr),
2062                                 timeout=self.TIMEOUT):
2063                await asyncio.wait_for(
2064                    main(proto, on_con, on_eof, on_con_lost),
2065                    timeout=self.TIMEOUT)
2066
2067            server.close()
2068            await server.wait_closed()
2069
2070        self.loop.run_until_complete(run_main())
2071
2072    def test_create_server_ssl_over_ssl(self):
2073        if self.implementation == 'asyncio':
2074            raise unittest.SkipTest('asyncio does not support SSL over SSL')
2075
2076        CNT = 0           # number of clients that were successful
2077        TOTAL_CNT = 25    # total number of clients that test will create
2078        TIMEOUT = 20.0    # timeout for this test
2079
2080        A_DATA = b'A' * 1024 * 1024
2081        B_DATA = b'B' * 1024 * 1024
2082
2083        sslctx_1 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
2084        client_sslctx_1 = self._create_client_ssl_context()
2085        sslctx_2 = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
2086        client_sslctx_2 = self._create_client_ssl_context()
2087
2088        clients = []
2089
2090        async def handle_client(reader, writer):
2091            nonlocal CNT
2092
2093            data = await reader.readexactly(len(A_DATA))
2094            self.assertEqual(data, A_DATA)
2095            writer.write(b'OK')
2096
2097            data = await reader.readexactly(len(B_DATA))
2098            self.assertEqual(data, B_DATA)
2099            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
2100
2101            await writer.drain()
2102            writer.close()
2103
2104            CNT += 1
2105
2106        class ServerProtocol(asyncio.StreamReaderProtocol):
2107            def connection_made(self, transport):
2108                super_ = super()
2109                transport.pause_reading()
2110                fut = self._loop.create_task(self._loop.start_tls(
2111                    transport, self, sslctx_2, server_side=True))
2112
2113                def cb(_):
2114                    try:
2115                        tr = fut.result()
2116                    except Exception as ex:
2117                        super_.connection_lost(ex)
2118                    else:
2119                        super_.connection_made(tr)
2120                fut.add_done_callback(cb)
2121
2122        def server_protocol_factory():
2123            reader = asyncio.StreamReader()
2124            protocol = ServerProtocol(reader, handle_client)
2125            return protocol
2126
2127        async def test_client(addr):
2128            fut = asyncio.Future()
2129
2130            def prog(sock):
2131                try:
2132                    sock.connect(addr)
2133                    sock.starttls(client_sslctx_1)
2134
2135                    # because wrap_socket() doesn't work correctly on
2136                    # SSLSocket, we have to do the 2nd level SSL manually
2137                    incoming = ssl.MemoryBIO()
2138                    outgoing = ssl.MemoryBIO()
2139                    sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)
2140
2141                    def do(func, *args):
2142                        while True:
2143                            try:
2144                                rv = func(*args)
2145                                break
2146                            except ssl.SSLWantReadError:
2147                                if outgoing.pending:
2148                                    sock.send(outgoing.read())
2149                                incoming.write(sock.recv(65536))
2150                        if outgoing.pending:
2151                            sock.send(outgoing.read())
2152                        return rv
2153
2154                    do(sslobj.do_handshake)
2155
2156                    do(sslobj.write, A_DATA)
2157                    data = do(sslobj.read, 2)
2158                    self.assertEqual(data, b'OK')
2159
2160                    do(sslobj.write, B_DATA)
2161                    data = b''
2162                    while True:
2163                        chunk = do(sslobj.read, 4)
2164                        if not chunk:
2165                            break
2166                        data += chunk
2167                    self.assertEqual(data, b'SPAM')
2168
2169                    do(sslobj.unwrap)
2170                    sock.close()
2171
2172                except Exception as ex:
2173                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
2174                    sock.close()
2175                else:
2176                    self.loop.call_soon_threadsafe(fut.set_result, None)
2177
2178            client = self.tcp_client(prog)
2179            client.start()
2180            clients.append(client)
2181
2182            await fut
2183
2184        async def start_server():
2185            extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT)
2186
2187            srv = await self.loop.create_server(
2188                server_protocol_factory,
2189                '127.0.0.1', 0,
2190                family=socket.AF_INET,
2191                ssl=sslctx_1,
2192                **extras)
2193
2194            try:
2195                srv_socks = srv.sockets
2196                self.assertTrue(srv_socks)
2197
2198                addr = srv_socks[0].getsockname()
2199
2200                tasks = []
2201                for _ in range(TOTAL_CNT):
2202                    tasks.append(test_client(addr))
2203
2204                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
2205
2206            finally:
2207                self.loop.call_soon(srv.close)
2208                await srv.wait_closed()
2209
2210        with self._silence_eof_received_warning():
2211            self.loop.run_until_complete(start_server())
2212
2213        self.assertEqual(CNT, TOTAL_CNT)
2214
2215        for client in clients:
2216            client.stop()
2217
2218    def test_renegotiation(self):
2219        if self.implementation == 'asyncio':
2220            raise unittest.SkipTest('asyncio does not support renegotiation')
2221
2222        CNT = 0
2223        TOTAL_CNT = 25
2224
2225        A_DATA = b'A' * 1024 * 1024
2226        B_DATA = b'B' * 1024 * 1024
2227
2228        sslctx = openssl_ssl.Context(openssl_ssl.TLSv1_2_METHOD)
2229        if hasattr(openssl_ssl, 'OP_NO_SSLV2'):
2230            sslctx.set_options(openssl_ssl.OP_NO_SSLV2)
2231        sslctx.use_privatekey_file(self.ONLYKEY)
2232        sslctx.use_certificate_chain_file(self.ONLYCERT)
2233        client_sslctx = self._create_client_ssl_context()
2234        if hasattr(ssl, 'OP_NO_TLSv1_3'):
2235            client_sslctx.options |= ssl.OP_NO_TLSv1_3
2236
2237        def server(sock):
2238            conn = openssl_ssl.Connection(sslctx, sock)
2239            conn.set_accept_state()
2240
2241            data = b''
2242            while len(data) < len(A_DATA):
2243                try:
2244                    chunk = conn.recv(len(A_DATA) - len(data))
2245                    if not chunk:
2246                        break
2247                    data += chunk
2248                except openssl_ssl.WantReadError:
2249                    pass
2250            self.assertEqual(data, A_DATA)
2251            conn.renegotiate()
2252            if conn.renegotiate_pending():
2253                conn.send(b'OK')
2254            else:
2255                conn.send(b'ER')
2256
2257            data = b''
2258            while len(data) < len(B_DATA):
2259                try:
2260                    chunk = conn.recv(len(B_DATA) - len(data))
2261                    if not chunk:
2262                        break
2263                    data += chunk
2264                except openssl_ssl.WantReadError:
2265                    pass
2266            self.assertEqual(data, B_DATA)
2267            if conn.renegotiate_pending():
2268                conn.send(b'ERRO')
2269            else:
2270                conn.send(b'SPAM')
2271
2272            conn.shutdown()
2273
2274        async def client(addr):
2275            extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT)
2276
2277            reader, writer = await asyncio.open_connection(
2278                *addr,
2279                ssl=client_sslctx,
2280                server_hostname='',
2281                **extras)
2282
2283            writer.write(A_DATA)
2284            self.assertEqual(await reader.readexactly(2), b'OK')
2285
2286            writer.write(B_DATA)
2287            self.assertEqual(await reader.readexactly(4), b'SPAM')
2288
2289            nonlocal CNT
2290            CNT += 1
2291
2292            writer.close()
2293            await self.wait_closed(writer)
2294
2295        async def client_sock(addr):
2296            sock = socket.socket()
2297            sock.connect(addr)
2298            reader, writer = await asyncio.open_connection(
2299                sock=sock,
2300                ssl=client_sslctx,
2301                server_hostname='')
2302
2303            writer.write(A_DATA)
2304            self.assertEqual(await reader.readexactly(2), b'OK')
2305
2306            writer.write(B_DATA)
2307            self.assertEqual(await reader.readexactly(4), b'SPAM')
2308
2309            nonlocal CNT
2310            CNT += 1
2311
2312            writer.close()
2313            await self.wait_closed(writer)
2314            sock.close()
2315
2316        def run(coro):
2317            nonlocal CNT
2318            CNT = 0
2319
2320            with self.tcp_server(server,
2321                                 max_clients=TOTAL_CNT,
2322                                 backlog=TOTAL_CNT) as srv:
2323                tasks = []
2324                for _ in range(TOTAL_CNT):
2325                    tasks.append(coro(srv.addr))
2326
2327                self.loop.run_until_complete(
2328                    asyncio.gather(*tasks))
2329
2330            self.assertEqual(CNT, TOTAL_CNT)
2331
2332        with self._silence_eof_received_warning():
2333            run(client)
2334
2335        with self._silence_eof_received_warning():
2336            run(client_sock)
2337
2338    def test_shutdown_timeout(self):
2339        if self.implementation == 'asyncio':
2340            raise unittest.SkipTest()
2341
2342        CNT = 0           # number of clients that were successful
2343        TOTAL_CNT = 25    # total number of clients that test will create
2344        TIMEOUT = 10.0    # timeout for this test
2345
2346        A_DATA = b'A' * 1024 * 1024
2347
2348        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
2349        client_sslctx = self._create_client_ssl_context()
2350
2351        clients = []
2352
2353        async def handle_client(reader, writer):
2354            nonlocal CNT
2355
2356            data = await reader.readexactly(len(A_DATA))
2357            self.assertEqual(data, A_DATA)
2358            writer.write(b'OK')
2359            await writer.drain()
2360            writer.close()
2361            with self.assertRaisesRegex(asyncio.TimeoutError,
2362                                        'SSL shutdown timed out'):
2363                await reader.read()
2364            CNT += 1
2365
2366        async def test_client(addr):
2367            fut = asyncio.Future()
2368
2369            def prog(sock):
2370                try:
2371                    sock.starttls(client_sslctx)
2372                    sock.connect(addr)
2373                    sock.send(A_DATA)
2374
2375                    data = sock.recv_all(2)
2376                    self.assertEqual(data, b'OK')
2377
2378                    data = sock.recv(1024)
2379                    self.assertEqual(data, b'')
2380
2381                    fd = sock.detach()
2382                    try:
2383                        select.select([fd], [], [], 3)
2384                    finally:
2385                        os.close(fd)
2386
2387                except Exception as ex:
2388                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
2389                else:
2390                    self.loop.call_soon_threadsafe(fut.set_result, None)
2391
2392            client = self.tcp_client(prog)
2393            client.start()
2394            clients.append(client)
2395
2396            await fut
2397
2398        async def start_server():
2399            extras = {'ssl_handshake_timeout': SSL_HANDSHAKE_TIMEOUT}
2400            if self.implementation != 'asyncio':  # or self.PY38
2401                extras['ssl_shutdown_timeout'] = 0.5
2402
2403            srv = await asyncio.start_server(
2404                handle_client,
2405                '127.0.0.1', 0,
2406                family=socket.AF_INET,
2407                ssl=sslctx,
2408                **extras)
2409
2410            try:
2411                srv_socks = srv.sockets
2412                self.assertTrue(srv_socks)
2413
2414                addr = srv_socks[0].getsockname()
2415
2416                tasks = []
2417                for _ in range(TOTAL_CNT):
2418                    tasks.append(test_client(addr))
2419
2420                await asyncio.wait_for(
2421                    asyncio.gather(*tasks),
2422                    TIMEOUT)
2423
2424            finally:
2425                self.loop.call_soon(srv.close)
2426                await srv.wait_closed()
2427
2428        with self._silence_eof_received_warning():
2429            self.loop.run_until_complete(start_server())
2430
2431        self.assertEqual(CNT, TOTAL_CNT)
2432
2433        for client in clients:
2434            client.stop()
2435
2436    def test_shutdown_cleanly(self):
2437        if self.implementation == 'asyncio':
2438            raise unittest.SkipTest()
2439
2440        CNT = 0
2441        TOTAL_CNT = 25
2442
2443        A_DATA = b'A' * 1024 * 1024
2444
2445        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
2446        client_sslctx = self._create_client_ssl_context()
2447
2448        def server(sock):
2449            sock.starttls(
2450                sslctx,
2451                server_side=True)
2452
2453            data = sock.recv_all(len(A_DATA))
2454            self.assertEqual(data, A_DATA)
2455            sock.send(b'OK')
2456
2457            sock.unwrap()
2458
2459            sock.close()
2460
2461        async def client(addr):
2462            extras = dict(ssl_handshake_timeout=SSL_HANDSHAKE_TIMEOUT)
2463
2464            reader, writer = await asyncio.open_connection(
2465                *addr,
2466                ssl=client_sslctx,
2467                server_hostname='',
2468                **extras)
2469
2470            writer.write(A_DATA)
2471            self.assertEqual(await reader.readexactly(2), b'OK')
2472
2473            self.assertEqual(await reader.read(), b'')
2474
2475            nonlocal CNT
2476            CNT += 1
2477
2478            writer.close()
2479            await self.wait_closed(writer)
2480
2481        def run(coro):
2482            nonlocal CNT
2483            CNT = 0
2484
2485            with self.tcp_server(server,
2486                                 max_clients=TOTAL_CNT,
2487                                 backlog=TOTAL_CNT) as srv:
2488                tasks = []
2489                for _ in range(TOTAL_CNT):
2490                    tasks.append(coro(srv.addr))
2491
2492                self.loop.run_until_complete(
2493                    asyncio.gather(*tasks))
2494
2495            self.assertEqual(CNT, TOTAL_CNT)
2496
2497        with self._silence_eof_received_warning():
2498            run(client)
2499
    def test_write_to_closed_transport(self):
        """Writes made after the peer has gone away must not be buffered.

        The client runs against two servers.  ``server`` completes the
        handshake and then shuts down the TCP socket without calling
        ``unwrap()``.  ``unwrap_server`` completes the handshake and then
        attempts ``unwrap()``.  In both cases the client's extra writes,
        issued after its read finishes, must leave the transport's write
        buffer size at 0.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()
        # Recreated at the start of every client() run and resolved from
        # the server thread by the wrapper returned from run().
        future = None

        def server(sock):
            # Handshake, then drop the TCP connection without unwrap().
            sock.starttls(sslctx, server_side=True)
            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        def unwrap_server(sock):
            # Handshake, then attempt a TLS unwrap before closing.
            sock.starttls(sslctx, server_side=True)
            while True:
                try:
                    sock.unwrap()
                    break
                except ssl.SSLError as ex:
                    # Since OpenSSL 1.1.1, it raises "application data after
                    # close notify"
                    # Python < 3.8:
                    if ex.reason == 'KRB5_S_INIT':
                        break
                    # Python >= 3.8:
                    if ex.reason == 'APPLICATION_DATA_AFTER_CLOSE_NOTIFY':
                        break
                    raise ex
                except OSError as ex:
                    # OpenSSL < 1.1.1
                    # An OSError with errno 0 is not re-raised: the loop
                    # calls unwrap() again.
                    if ex.errno != 0:
                        raise
            sock.close()

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            writer.write(b'I AM WRITING NOWHERE1' * 100)

            # Wait for the server side to go away; the read may end with
            # EOF or with a reset/broken-pipe error, and both are accepted.
            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except (ConnectionResetError, BrokenPipeError):
                pass

            # Further writes to the closed transport must not accumulate.
            for i in range(25):
                writer.write(b'I AM WRITING NOWHERE2' * 100)

            self.assertEqual(
                writer.transport.get_write_buffer_size(), 0)

            # Re-raises any exception recorded by the server thread.
            await future

            writer.close()
            await self.wait_closed(writer)

        def run(meth):
            # Wrap a server function so that its outcome (success or the
            # raised exception) is delivered to ``future`` on the loop.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        with self._silence_eof_received_warning():
            with self.tcp_server(run(server)) as srv:
                self.loop.run_until_complete(client(srv.addr))

            with self.tcp_server(run(unwrap_server)) as srv:
                self.loop.run_until_complete(client(srv.addr))
2578
2579    def test_flush_before_shutdown(self):
2580        if self.implementation == 'asyncio':
2581            raise unittest.SkipTest()
2582
2583        CHUNK = 1024 * 128
2584        SIZE = 32
2585
2586        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
2587        sslctx_openssl = openssl_ssl.Context(openssl_ssl.TLSv1_2_METHOD)
2588        if hasattr(openssl_ssl, 'OP_NO_SSLV2'):
2589            sslctx_openssl.set_options(openssl_ssl.OP_NO_SSLV2)
2590        sslctx_openssl.use_privatekey_file(self.ONLYKEY)
2591        sslctx_openssl.use_certificate_chain_file(self.ONLYCERT)
2592        client_sslctx = self._create_client_ssl_context()
2593        if hasattr(ssl, 'OP_NO_TLSv1_3'):
2594            client_sslctx.options |= ssl.OP_NO_TLSv1_3
2595
2596        future = None
2597
2598        def server(sock):
2599            sock.starttls(sslctx, server_side=True)
2600            self.assertEqual(sock.recv_all(4), b'ping')
2601            sock.send(b'pong')
2602            time.sleep(0.5)  # hopefully stuck the TCP buffer
2603            data = sock.recv_all(CHUNK * SIZE)
2604            self.assertEqual(len(data), CHUNK * SIZE)
2605            sock.close()
2606
2607        def run(meth):
2608            def wrapper(sock):
2609                try:
2610                    meth(sock)
2611                except Exception as ex:
2612                    self.loop.call_soon_threadsafe(future.set_exception, ex)
2613                else:
2614                    self.loop.call_soon_threadsafe(future.set_result, None)
2615            return wrapper
2616
2617        async def client(addr):
2618            nonlocal future
2619            future = self.loop.create_future()
2620            reader, writer = await asyncio.open_connection(
2621                *addr,
2622                ssl=client_sslctx,
2623                server_hostname='')
2624            sslprotocol = writer.get_extra_info('uvloop.sslproto')
2625            writer.write(b'ping')
2626            data = await reader.readexactly(4)
2627            self.assertEqual(data, b'pong')
2628
2629            sslprotocol.pause_writing()
2630            for _ in range(SIZE):
2631                writer.write(b'x' * CHUNK)
2632
2633            writer.close()
2634            sslprotocol.resume_writing()
2635
2636            await self.wait_closed(writer)
2637            try:
2638                data = await reader.read()
2639                self.assertEqual(data, b'')
2640            except ConnectionResetError:
2641                pass
2642            await future
2643
2644        with self.tcp_server(run(server)) as srv:
2645            self.loop.run_until_complete(client(srv.addr))
2646
    def test_remote_shutdown_receives_trailing_data(self):
        """Check that data still in flight when the server starts TLS
        shutdown is delivered to the server before the connection closes.

        The client keeps writing until ``drain()`` stalls. The server then
        sends close_notify (without waiting for the reply) and keeps reading.
        It must end up with exactly ``CHUNK * count`` bytes.
        """
        CHUNK = 1024 * 16
        SIZE = 8  # extra chunks pushed straight into uvloop's write backlog
        count = 0  # number of CHUNK-sized writes issued by the client

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()
        future = None  # created in client(); resolved by the run() wrapper
        # The client acquires both locks before connecting, so the server
        # blocks on each one until the client reaches the matching stage:
        # ``filled`` is released once the client stopped writing (drain
        # stalled), ``eof_received`` once the client has read EOF.
        filled = threading.Lock()
        eof_received = threading.Lock()

        def server(sock):
            # Server side. Its outcome is reported back to the loop with
            # call_soon_threadsafe (see run()), so it runs off the loop
            # thread. TLS is driven manually through memory BIOs.
            incoming = ssl.MemoryBIO()
            outgoing = ssl.MemoryBIO()
            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

            # Complete the handshake, flushing pending TLS output before
            # each blocking recv().
            while True:
                try:
                    sslobj.do_handshake()
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    incoming.write(sock.recv(16384))
                else:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    break

            # Read the client's 4-byte b'ping' ...
            while True:
                try:
                    data = sslobj.read(4)
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                else:
                    break

            # ... and answer with b'pong'.
            self.assertEqual(data, b'ping')
            sslobj.write(b'pong')
            sock.send(outgoing.read())

            data_len = 0

            # Blocks until the client has filled its write backlog.
            with filled:
                # trigger peer's resume_writing()
                incoming.write(sock.recv(65536 * 4))
                # Decrypt everything that single recv() delivered.
                while True:
                    try:
                        chunk = len(sslobj.read(16384))
                        data_len += chunk
                    except ssl.SSLWantReadError:
                        break

                # send close_notify but don't wait for response
                # (unwrap() raises SSLWantReadError because the peer's
                # close_notify has not arrived yet)
                with self.assertRaises(ssl.SSLWantReadError):
                    sslobj.unwrap()
                sock.send(outgoing.read())

            # Blocks until the client has read EOF from this side.
            with eof_received:
                # should receive all data
                while True:
                    try:
                        chunk = len(sslobj.read(16384))
                        data_len += chunk
                    except ssl.SSLWantReadError:
                        incoming.write(sock.recv(16384))
                        if not incoming.pending:
                            # EOF received
                            break
                    except ssl.SSLZeroReturnError:
                        break

            # Every byte the client wrote must have arrived.
            self.assertEqual(data_len, CHUNK * count)

            if self.implementation == 'uvloop':
                # Verify that close_notify is received. asyncio is currently
                # not guaranteed to send close_notify before dropping off
                sslobj.unwrap()

            sock.close()

        async def client(addr):
            nonlocal future, count
            future = self.loop.create_future()

            with eof_received:
                with filled:
                    reader, writer = await asyncio.open_connection(
                        *addr,
                        ssl=client_sslctx,
                        server_hostname='')
                    writer.write(b'ping')
                    data = await reader.readexactly(4)
                    self.assertEqual(data, b'pong')

                    count = 0
                    try:
                        # Keep writing until drain() fails to finish within
                        # 0.5 s (the write buffer has presumably stopped
                        # flushing because the server is not reading).
                        while True:
                            writer.write(b'x' * CHUNK)
                            count += 1
                            await asyncio.wait_for(
                                asyncio.ensure_future(writer.drain()), 0.5)
                    except asyncio.TimeoutError:
                        # fill write backlog in a hacky way for uvloop
                        if self.implementation == 'uvloop':
                            for _ in range(SIZE):
                                writer.transport._test__append_write_backlog(
                                    b'x' * CHUNK)
                                count += 1

                # The server's close_notify must surface here as EOF.
                data = await reader.read()
                self.assertEqual(data, b'')

            # Re-raises any assertion failure from the server thread.
            await future

            writer.close()
            await self.wait_closed(writer)

        def run(meth):
            # Wrap a server function so that its outcome (success or the
            # exception it raised) is delivered to ``future`` on the loop.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
2776
2777    def test_connect_timeout_warning(self):
2778        s = socket.socket(socket.AF_INET)
2779        s.bind(('127.0.0.1', 0))
2780        addr = s.getsockname()
2781
2782        async def test():
2783            try:
2784                await asyncio.wait_for(
2785                    self.loop.create_connection(asyncio.Protocol,
2786                                                *addr, ssl=True),
2787                    0.1)
2788            except (ConnectionRefusedError, asyncio.TimeoutError):
2789                pass
2790            else:
2791                self.fail('TimeoutError is not raised')
2792
2793        with s:
2794            try:
2795                with self.assertWarns(ResourceWarning) as cm:
2796                    self.loop.run_until_complete(test())
2797                    gc.collect()
2798                    gc.collect()
2799                    gc.collect()
2800            except AssertionError as e:
2801                self.assertEqual(str(e), 'ResourceWarning not triggered')
2802            else:
2803                self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
2804
    def test_handshake_timeout_handler_leak(self):
        """Check that a TLS connection whose handshake never completes does
        not keep its SSLContext alive after the caller gives up.

        The server socket listens but never accepts or answers, so the
        ``create_connection`` call is abandoned by ``wait_for`` after 0.1 s
        (or refused). Afterwards the context must have no strong
        references left.
        """
        if self.implementation == 'asyncio':
            # Okay this turns out to be an issue for asyncio.sslproto too
            raise unittest.SkipTest()

        s = socket.socket(socket.AF_INET)
        s.bind(('127.0.0.1', 0))
        s.listen(1)
        addr = s.getsockname()

        async def test(ctx):
            try:
                await asyncio.wait_for(
                    self.loop.create_connection(asyncio.Protocol, *addr,
                                                ssl=ctx),
                    0.1)
            except (ConnectionRefusedError, asyncio.TimeoutError):
                pass
            else:
                self.fail('TimeoutError is not raised')

        with s:
            ctx = ssl.create_default_context()
            self.loop.run_until_complete(test(ctx))
            # Replace the test's own strong reference with a weakref so that
            # only the loop/transport machinery could keep the context alive.
            ctx = weakref.ref(ctx)

        # SSLProtocol should be DECREF to 0
        # (no gc.collect() here: reference counting alone must free it)
        self.assertIsNone(ctx())
2833
    def test_shutdown_timeout_handler_leak(self):
        """Check that closing a TLS transport does not keep its SSLContext
        alive after the connection is lost.

        On asyncio the context is only freed after garbage collection (see
        the comment below). Otherwise reference counting alone must free it.
        """
        loop = self.loop

        def server(sock):
            # Accept the TLS handshake, do one recv() of up to 32 bytes,
            # then close.
            sslctx = self._create_server_ssl_context(self.ONLYCERT,
                                                     self.ONLYKEY)
            sock = sslctx.wrap_socket(sock, server_side=True)
            sock.recv(32)
            sock.close()

        class Protocol(asyncio.Protocol):
            # Resolves ``fut`` when the connection is lost.
            def __init__(self):
                self.fut = asyncio.Future(loop=loop)

            def connection_lost(self, exc):
                self.fut.set_result(None)

        async def client(addr, ctx):
            # Connect, close straight away, and wait for connection_lost().
            tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
            tr.close()
            await pr.fut

        with self.tcp_server(server) as srv:
            ctx = self._create_client_ssl_context()
            loop.run_until_complete(client(srv.addr, ctx))
            # Keep only a weak reference to the context from here on.
            ctx = weakref.ref(ctx)

        if self.implementation == 'asyncio':
            # asyncio has no shutdown timeout, but it ends up with a circular
            # reference loop - not ideal (introduces gc glitches), but at least
            # not leaking
            gc.collect()
            gc.collect()
            gc.collect()

        # SSLProtocol should be DECREF to 0
        self.assertIsNone(ctx())
2871
    def test_shutdown_timeout_handler_not_set(self):
        """Check TLS shutdown when the peer sends EOF while the client has
        paused reading and decrypted data is still pending.

        After b'hello' the client pauses reading, so b'extra bytes' stays
        buffered in the SSL layer. When the server shuts down its write
        side, the client's eof_received() resumes reading. The connection
        must then close without error, and b'extra bytes' must never reach
        data_received().
        """
        if self.implementation == 'asyncio':
            # asyncio doesn't call SSL eof_received() so we can't run this test
            raise unittest.SkipTest()

        loop = self.loop
        extra = None  # set if data other than b'hello' reaches the protocol

        def server(sock):
            sslctx = self._create_server_ssl_context(self.ONLYCERT,
                                                     self.ONLYKEY)
            sock = sslctx.wrap_socket(sock, server_side=True)
            sock.send(b'hello')
            assert sock.recv(1024) == b'world'
            sock.send(b'extra bytes')
            # sending EOF here
            sock.shutdown(socket.SHUT_WR)
            # make sure we have enough time to reproduce the issue
            # (blocks until the client side closes the connection)
            self.assertEqual(sock.recv(1024), b'')
            sock.close()

        class Protocol(asyncio.Protocol):
            def __init__(self):
                self.fut = asyncio.Future(loop=loop)
                self.transport = None

            def connection_made(self, transport):
                self.transport = transport

            def data_received(self, data):
                if data == b'hello':
                    self.transport.write(b'world')
                    # pause reading would make incoming data stay in the sslobj
                    self.transport.pause_reading()
                else:
                    nonlocal extra
                    extra = data

            def connection_lost(self, exc):
                # Propagate a failed close to the awaiting client().
                if exc is None:
                    self.fut.set_result(None)
                else:
                    self.fut.set_exception(exc)

            def eof_received(self):
                # Resume so the buffered data and shutdown can proceed.
                self.transport.resume_reading()

        async def client(addr):
            ctx = self._create_client_ssl_context()
            tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
            await pr.fut
            tr.close()
            # extra data received after transport.close() should be ignored
            self.assertIsNone(extra)

        with self.tcp_server(server) as srv:
            loop.run_until_complete(client(srv.addr))
2929
    def test_shutdown_while_pause_reading(self):
        """Check that closing a TLS transport while reading is paused still
        completes the TLS shutdown.

        Right after the handshake the server writes b'trailing data'. It
        then waits for the client's close_notify and completes unwrap().
        The client pauses reading and closes. It must still see
        eof_received() and an error-free connection_lost(), each within
        10 seconds.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        loop = self.loop
        conn_made = loop.create_future()
        eof_recvd = loop.create_future()
        conn_lost = loop.create_future()
        data_recv = False  # becomes True if data_received() is ever called

        def server(sock):
            sslctx = self._create_server_ssl_context(self.ONLYCERT,
                                                     self.ONLYKEY)
            incoming = ssl.MemoryBIO()
            outgoing = ssl.MemoryBIO()
            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

            # Handshake. As soon as it completes, queue b'trailing data' so
            # that it is flushed together with the final handshake output.
            while True:
                try:
                    sslobj.do_handshake()
                    sslobj.write(b'trailing data')
                    break
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    incoming.write(sock.recv(16384))
            if outgoing.pending:
                sock.send(outgoing.read())

            # Wait for the client's close_notify.
            while True:
                try:
                    self.assertEqual(sslobj.read(), b'')  # close_notify
                    break
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))

            # Finish the shutdown, sending this side's close_notify.
            while True:
                try:
                    sslobj.unwrap()
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    incoming.write(sock.recv(16384))
                else:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    break

            self.assertEqual(sock.recv(16384), b'')  # socket closed

        class Protocol(asyncio.Protocol):
            def connection_made(self, transport):
                conn_made.set_result(None)

            def data_received(self, data):
                nonlocal data_recv
                data_recv = True

            def eof_received(self):
                eof_recvd.set_result(None)

            def connection_lost(self, exc):
                if exc is None:
                    conn_lost.set_result(None)
                else:
                    conn_lost.set_exception(exc)

        async def client(addr):
            ctx = self._create_client_ssl_context()
            tr, _ = await loop.create_connection(Protocol, *addr, ssl=ctx)
            await conn_made
            # No application data should have been delivered yet.
            self.assertFalse(data_recv)

            # Close while paused; the shutdown must still complete.
            tr.pause_reading()
            tr.close()

            await asyncio.wait_for(eof_recvd, 10)
            await asyncio.wait_for(conn_lost, 10)

        with self.tcp_server(server) as srv:
            loop.run_until_complete(client(srv.addr))
3011
    def test_bpo_39951_discard_trailing_data(self):
        """Check that application data the server sends after the client
        has called ``writer.close()`` does not make the close fail
        (bpo-39951).

        The client holds ``close_notify`` until it has called
        ``writer.close()``. Only then does the server send b'trailing'.
        On asyncio, the known 'application data after close notify' SSL
        error turns the test into a skip.
        """
        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()
        future = None  # created in client(); resolved by the run() wrapper
        # Held by the client until writer.close() has been called.
        close_notify = threading.Lock()

        def server(sock):
            incoming = ssl.MemoryBIO()
            outgoing = ssl.MemoryBIO()
            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

            # Complete the handshake, flushing output before each recv().
            while True:
                try:
                    sslobj.do_handshake()
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    incoming.write(sock.recv(16384))
                else:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    break

            # Exchange b'ping' / b'pong'.
            while True:
                try:
                    data = sslobj.read(4)
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                else:
                    break

            self.assertEqual(data, b'ping')
            sslobj.write(b'pong')
            sock.send(outgoing.read())

            # Blocks until the client has called writer.close().
            with close_notify:
                sslobj.write(b'trailing')
                sock.send(outgoing.read())
                time.sleep(0.5)  # allow time for the client to receive

            # Read the client's shutdown data and answer with our own
            # close_notify before closing the socket.
            incoming.write(sock.recv(16384))
            sslobj.unwrap()
            sock.send(outgoing.read())
            sock.close()

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()

            with close_notify:
                reader, writer = await asyncio.open_connection(
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='')
                writer.write(b'ping')
                data = await reader.readexactly(4)
                self.assertEqual(data, b'pong')

                writer.close()

            try:
                await self.wait_closed(writer)
            except ssl.SSLError as e:
                if self.implementation == 'asyncio' and \
                        'application data after close notify' in str(e):
                    raise unittest.SkipTest('bpo-39951')
                raise
            # Re-raises any assertion failure from the server side.
            await future

        def run(meth):
            # Wrap a server function so that its outcome (success or the
            # exception it raised) is delivered to ``future`` on the loop.
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
3093
    def test_first_data_after_wakeup(self):
        """Check that application data sent in the same ``send()`` as the
        client's final handshake output is handled by the TLS layer.

        The server protocol steps ``loop.start_tls()`` by hand from
        connection_made() and echoes everything it receives through
        ``self.tr``. The client expects its b'hello' to come back as data
        it can decrypt with its TLS object.
        """
        if self.implementation == 'asyncio':
            raise unittest.SkipTest()

        server_context = self._create_server_ssl_context(
            self.ONLYCERT, self.ONLYKEY)
        client_context = self._create_client_ssl_context()
        loop = self.loop
        this = self  # the test case, for use inside EchoProto methods
        fut = self.loop.create_future()  # outcome of the client thread

        def client(sock, addr):
            try:
                sock.connect(addr)

                incoming = ssl.MemoryBIO()
                outgoing = ssl.MemoryBIO()
                sslobj = client_context.wrap_bio(incoming, outgoing)

                # Do handshake manually so that we could collect the last piece
                while True:
                    try:
                        sslobj.do_handshake()
                        break
                    except ssl.SSLWantReadError:
                        if outgoing.pending:
                            sock.send(outgoing.read())
                        incoming.write(sock.recv(65536))

                # Send the first data together with the last handshake payload
                sslobj.write(b'hello')
                sock.send(outgoing.read())

                # NOTE(review): if the server closes without echoing, recv()
                # keeps returning b'' and this loop spins until the test's
                # timeouts intervene. Consider treating b'' as a failure.
                while True:
                    try:
                        incoming.write(sock.recv(65536))
                        self.assertEqual(sslobj.read(1024), b'hello')
                        break
                    except ssl.SSLWantReadError:
                        pass

                sock.close()

            except Exception as ex:
                loop.call_soon_threadsafe(fut.set_exception, ex)
                sock.close()
            else:
                loop.call_soon_threadsafe(fut.set_result, None)

        class EchoProto(asyncio.Protocol):
            def connection_made(self, tr):
                self.tr = tr
                # manually run the coroutine, in order to avoid accidental data
                coro = loop.start_tls(
                    tr, self, server_context,
                    server_side=True,
                    ssl_handshake_timeout=this.TIMEOUT,
                )
                # First step: send(None) runs start_tls until it yields the
                # future it awaits.
                waiter = coro.send(None)

                def tls_started(_):
                    # Resume start_tls once that future is done.
                    # NOTE(review): this assumes start_tls awaits only this
                    # one future. A second yield would leave it suspended.
                    try:
                        coro.send(None)
                    except StopIteration as e:
                        # update self.tr to SSL transport as soon as we know it
                        self.tr = e.value

                waiter.add_done_callback(tls_started)

            def data_received(self, data):
                # This is a dumb protocol that writes back whatever it receives
                # regardless of whether self.tr is SSL or not
                self.tr.write(data)

        async def run_main():
            proto = EchoProto()

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(fut, timeout=self.TIMEOUT)
                proto.tr.close()

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
3184
3185
class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
    """Run the shared _TestSSL cases with tb.UVTestCase (the uvloop loop)."""
3188
3189
class Test_AIO_TCPSSL(_TestSSL, tb.AIOTestCase):
    """Run the shared _TestSSL cases with tb.AIOTestCase (the asyncio loop)."""
3192