1import asyncio
2import os
3import pathlib
4import socket
5import tempfile
6import time
7import unittest
8import sys
9
10from uvloop import _testbase as tb
11
12
class _TestUnix:
    """Unix-domain stream tests shared by the uvloop and asyncio test cases.

    This is a mixin: it relies on ``self.loop``, ``self.unix_server``,
    ``self.unix_sock_name`` and ``self.wait_closed`` being provided by the
    ``_testbase`` test-case class it is combined with.
    """

    def test_create_unix_server_1(self):
        """Serve TOTAL_CNT concurrent clients through the streams API.

        Runs three variants: ``start_unix_server`` listening on a path, and
        ``start_unix_server`` / ``start_server`` each given a pre-bound socket.
        """
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 100   # total number of clients that test will create
        TIMEOUT = 5.0     # timeout for this test

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(4)
            self.assertEqual(data, b'AAAA')
            writer.write(b'OK')

            data = await reader.readexactly(4)
            self.assertEqual(data, b'BBBB')
            writer.write(b'SPAM')

            await writer.drain()
            writer.close()
            await self.wait_closed(writer)

            CNT += 1

        async def recv_exactly(sock, size):
            # Read one byte at a time until `size` bytes have arrived.  An
            # empty read means the peer closed: fail right away instead of
            # spinning on empty reads until TIMEOUT cancels the client.
            buf = b''
            while len(buf) < size:
                chunk = await self.loop.sock_recv(sock, 1)
                if not chunk:
                    raise ConnectionError(
                        'unexpected EOF after {!r}'.format(buf))
                buf += chunk
            return buf

        async def test_client(addr):
            sock = socket.socket(socket.AF_UNIX)
            with sock:
                sock.setblocking(False)
                await self.loop.sock_connect(sock, addr)

                await self.loop.sock_sendall(sock, b'AAAA')
                self.assertEqual(await recv_exactly(sock, 2), b'OK')

                await self.loop.sock_sendall(sock, b'BBBB')
                self.assertEqual(await recv_exactly(sock, 4), b'SPAM')

        async def serve_clients(srv, sock_name):
            # Drive TOTAL_CNT clients against the already-started `srv`,
            # then close it and verify that its sockets were released.
            srv_socks = []  # keeps `finally` safe if srv.sockets raises
            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)
                self.assertTrue(srv.is_serving())

                tasks = [test_client(sock_name) for _ in range(TOTAL_CNT)]
                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

                # Check that the server cleaned-up proxy-sockets
                for srv_sock in srv_socks:
                    self.assertEqual(srv_sock.fileno(), -1)

                self.assertFalse(srv.is_serving())

            # asyncio doesn't cleanup the sock file
            self.assertTrue(os.path.exists(sock_name))

        async def start_server():
            nonlocal CNT
            CNT = 0

            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')
                srv = await asyncio.start_unix_server(
                    handle_client,
                    sock_name)
                await serve_clients(srv, sock_name)

        async def start_server_sock(server_factory):
            # `server_factory(sock)` must return an awaitable that starts a
            # server on the already-bound unix socket `sock`.
            nonlocal CNT
            CNT = 0

            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')
                sock = socket.socket(socket.AF_UNIX)
                try:
                    sock.bind(sock_name)
                except BaseException:
                    # Nothing owns the socket yet; don't leak it.
                    sock.close()
                    raise

                srv = await server_factory(sock)
                await serve_clients(srv, sock_name)

        with self.subTest(func='start_unix_server(path)'):
            self.loop.run_until_complete(start_server())
            self.assertEqual(CNT, TOTAL_CNT)

        with self.subTest(func='start_unix_server(sock)'):
            self.loop.run_until_complete(start_server_sock(
                lambda sock: asyncio.start_unix_server(
                    handle_client,
                    None,
                    sock=sock)))
            self.assertEqual(CNT, TOTAL_CNT)

        with self.subTest(func='start_server(sock)'):
            self.loop.run_until_complete(start_server_sock(
                lambda sock: asyncio.start_server(
                    handle_client,
                    None, None,
                    sock=sock)))
            self.assertEqual(CNT, TOTAL_CNT)

    def test_create_unix_server_2(self):
        """A regular file at the target path makes server creation fail."""
        with tempfile.TemporaryDirectory() as td:
            sock_name = os.path.join(td, 'sock')
            with open(sock_name, 'wt') as f:
                f.write('x')

            with self.assertRaisesRegex(
                    OSError, "Address '{}' is already in use".format(
                        sock_name)):

                self.loop.run_until_complete(
                    self.loop.create_unix_server(object, sock_name))

    def test_create_unix_server_3(self):
        """ssl_handshake_timeout without ssl is rejected with ValueError."""
        with self.assertRaisesRegex(
                ValueError, 'ssl_handshake_timeout is only meaningful'):
            self.loop.run_until_complete(
                self.loop.create_unix_server(
                    lambda: None, path='/tmp/a', ssl_handshake_timeout=10))

    def test_create_unix_server_existing_path_sock(self):
        """A leftover socket file at `path` does not block a new server."""
        with self.unix_sock_name() as path:
            sock = socket.socket(socket.AF_UNIX)
            with sock:
                sock.bind(path)
                sock.listen(1)

            # Check that no error is raised -- `path` is removed.
            coro = self.loop.create_unix_server(lambda: None, path)
            srv = self.loop.run_until_complete(coro)
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_create_unix_connection_open_unix_con_addr(self):
        """open_unix_connection(path) talks to the shared test server."""
        async def client(addr):
            reader, writer = await asyncio.open_unix_connection(addr)

            writer.write(b'AAAA')
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(b'BBBB')
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            writer.close()
            await self.wait_closed(writer)

        self._test_create_unix_connection_1(client)

    def test_create_unix_connection_open_unix_con_sock(self):
        """open_unix_connection(sock=...) with an already-connected socket."""
        async def client(addr):
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(addr)
            reader, writer = await asyncio.open_unix_connection(sock=sock)

            writer.write(b'AAAA')
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(b'BBBB')
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            writer.close()
            await self.wait_closed(writer)

        self._test_create_unix_connection_1(client)

    def test_create_unix_connection_open_con_sock(self):
        """open_connection(sock=...) also accepts a connected unix socket."""
        async def client(addr):
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(addr)
            reader, writer = await asyncio.open_connection(sock=sock)

            writer.write(b'AAAA')
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(b'BBBB')
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            writer.close()
            await self.wait_closed(writer)

        self._test_create_unix_connection_1(client)

    def _test_create_unix_connection_1(self, client):
        """Run TOTAL_CNT concurrent copies of `client` against a test server.

        `client` is an ``async def client(addr)`` that must send b'AAAA',
        expect b'OK', send b'BBBB' and expect b'SPAM'.  The server side uses
        the blocking-socket server provided by ``self.unix_server``.
        """
        CNT = 0
        TOTAL_CNT = 100

        def server(sock):
            data = sock.recv_all(4)
            self.assertEqual(data, b'AAAA')
            sock.send(b'OK')

            data = sock.recv_all(4)
            self.assertEqual(data, b'BBBB')
            sock.send(b'SPAM')

        async def client_wrapper(addr):
            await client(addr)
            nonlocal CNT
            CNT += 1

        def run(coro):
            nonlocal CNT
            CNT = 0

            with self.unix_server(server,
                                  max_clients=TOTAL_CNT,
                                  backlog=TOTAL_CNT) as srv:
                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(coro(srv.addr))

                self.loop.run_until_complete(asyncio.gather(*tasks))

                # Give time for all transports to close.
                self.loop.run_until_complete(asyncio.sleep(0.1))

            self.assertEqual(CNT, TOTAL_CNT)

        run(client_wrapper)

    def test_create_unix_connection_2(self):
        """Connecting to a path that does not exist raises FileNotFoundError."""
        # NamedTemporaryFile deletes the file on exit, leaving a unique path
        # that is known not to exist.
        with tempfile.NamedTemporaryFile() as tmp:
            path = tmp.name

        async def client():
            reader, writer = await asyncio.open_unix_connection(path)
            writer.close()
            await self.wait_closed(writer)

        async def runner():
            with self.assertRaises(FileNotFoundError):
                await client()

        self.loop.run_until_complete(runner())

    def test_create_unix_connection_3(self):
        """A server that closes early causes IncompleteReadError on read."""
        CNT = 0
        TOTAL_CNT = 100

        def server(sock):
            data = sock.recv_all(4)
            self.assertEqual(data, b'AAAA')
            sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_unix_connection(addr)

            sock = writer._transport.get_extra_info('socket')
            self.assertEqual(sock.family, socket.AF_UNIX)

            writer.write(b'AAAA')

            with self.assertRaises(asyncio.IncompleteReadError):
                await reader.readexactly(10)

            writer.close()
            await self.wait_closed(writer)

            nonlocal CNT
            CNT += 1

        def run(coro):
            nonlocal CNT
            CNT = 0

            with self.unix_server(server,
                                  max_clients=TOTAL_CNT,
                                  backlog=TOTAL_CNT) as srv:
                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(coro(srv.addr))

                self.loop.run_until_complete(asyncio.gather(*tasks))

            self.assertEqual(CNT, TOTAL_CNT)

        run(client)

    def test_create_unix_connection_4(self):
        """Passing an already-closed socket fails with a 'Bad file' error."""
        sock = socket.socket(socket.AF_UNIX)
        sock.close()

        async def client():
            reader, writer = await asyncio.open_unix_connection(sock=sock)
            writer.close()
            await self.wait_closed(writer)

        async def runner():
            with self.assertRaisesRegex(OSError, 'Bad file'):
                await client()

        self.loop.run_until_complete(runner())

    def test_create_unix_connection_5(self):
        """Writing after the peer closed reports connection_lost exactly once.

        The reported exception must be BrokenPipeError or
        ConnectionResetError.
        """
        s1, s2 = socket.socketpair(socket.AF_UNIX)

        excs = []

        class Proto(asyncio.Protocol):
            def connection_lost(self, exc):
                excs.append(exc)

        proto = Proto()

        async def client():
            # s2 is handed to the transport; s1 plays the remote peer.
            t, _ = await self.loop.create_unix_connection(
                lambda: proto,
                None,
                sock=s2)

            t.write(b'AAAAA')
            s1.close()
            t.write(b'AAAAA')
            await asyncio.sleep(0.1)

        self.loop.run_until_complete(client())

        self.assertEqual(len(excs), 1)
        self.assertIn(excs[0].__class__,
                      (BrokenPipeError, ConnectionResetError))

    def test_create_unix_connection_6(self):
        """ssl_handshake_timeout without ssl is rejected with ValueError."""
        with self.assertRaisesRegex(
                ValueError, 'ssl_handshake_timeout is only meaningful'):
            self.loop.run_until_complete(
                self.loop.create_unix_connection(
                    lambda: None, path='/tmp/a', ssl_handshake_timeout=10))
372
373
class Test_UV_Unix(_TestUnix, tb.UVTestCase):
    """Shared unix tests plus checks that only run under ``tb.UVTestCase``."""

    @unittest.skipUnless(hasattr(os, 'fspath'), 'no os.fspath()')
    def test_create_unix_connection_pathlib(self):
        """create_unix_connection accepts a pathlib.Path address."""
        async def run(addr):
            t, _ = await self.loop.create_unix_connection(
                asyncio.Protocol, addr)
            t.close()

        with self.unix_server(lambda sock: time.sleep(0.01)) as srv:
            addr = pathlib.Path(srv.addr)
            self.loop.run_until_complete(run(addr))

    @unittest.skipUnless(hasattr(os, 'fspath'), 'no os.fspath()')
    def test_create_unix_server_pathlib(self):
        """create_unix_server accepts a pathlib.Path address."""
        with self.unix_sock_name() as srv_path:
            srv_path = pathlib.Path(srv_path)
            srv = self.loop.run_until_complete(
                self.loop.create_unix_server(asyncio.Protocol, srv_path))
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_transport_fromsock_get_extra_info(self):
        """The transport's 'socket' extra is stable and its fd is guarded."""
        # This test is only for uvloop.  asyncio should pass it
        # too in Python 3.6.

        async def test(sock):
            t, _ = await self.loop.create_unix_connection(
                asyncio.Protocol,
                sock=sock)

            # Keep the transport's socket object under its own name so it is
            # not confused with the `sock` passed in: repeated lookups must
            # return the very same object.
            tr_sock = t.get_extra_info('socket')
            self.assertIs(t.get_extra_info('socket'), tr_sock)

            with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
                self.loop.add_writer(tr_sock.fileno(), lambda: None)
            with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
                self.loop.remove_writer(tr_sock.fileno())

            t.close()

        s1, s2 = socket.socketpair(socket.AF_UNIX)
        with s1, s2:
            self.loop.run_until_complete(test(s1))

    def test_create_unix_server_path_dgram(self):
        """A SOCK_DGRAM unix socket is rejected with ValueError."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        with sock:
            coro = self.loop.create_unix_server(lambda: None,
                                                sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
                         'no socket.SOCK_NONBLOCK (linux only)')
    def test_create_unix_server_path_stream_bittype(self):
        """A SOCK_STREAM socket created with SOCK_NONBLOCK is accepted."""
        # NamedTemporaryFile deletes the file on exit; we only want a unique
        # path that sock.bind() can create.
        with tempfile.NamedTemporaryFile() as file:
            fn = file.name
        try:
            sock = socket.socket(
                socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
            with sock:
                sock.bind(fn)
                coro = self.loop.create_unix_server(lambda: None, path=None,
                                                    sock=sock)
                srv = self.loop.run_until_complete(coro)
                srv.close()
                self.loop.run_until_complete(srv.wait_closed())
        finally:
            # If bind() never ran or failed, the path does not exist; don't
            # let that FileNotFoundError mask the original error.
            try:
                os.unlink(fn)
            except FileNotFoundError:
                pass

    @unittest.skipUnless(sys.platform.startswith('linux'), 'requires epoll')
    def test_epollhup(self):
        """Data sent right before the peer closes is fully delivered.

        The client sends one byte more than the protocol's SIZE-byte receive
        buffer and closes at once.  All bytes must still arrive, followed by
        eof_received() and connection_lost(None).
        """
        SIZE = 50
        eof = False
        done = False  # becomes the connection_lost() argument once called
        recvd = b''

        class Proto(asyncio.BaseProtocol):
            # Receives data via get_buffer()/buffer_updated() into a fixed
            # SIZE-byte buffer.
            def connection_made(self, tr):
                tr.write(b'hello')
                self.data = bytearray(SIZE)
                self.buf = memoryview(self.data)

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nbytes):
                nonlocal recvd
                recvd += self.buf[:nbytes]

            def eof_received(self):
                nonlocal eof
                eof = True

            def connection_lost(self, exc):
                nonlocal done
                done = exc

        async def test():
            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')
                srv = await self.loop.create_unix_server(Proto, sock_name)

                s = socket.socket(socket.AF_UNIX)
                with s:
                    s.setblocking(False)
                    await self.loop.sock_connect(s, sock_name)
                    d = await self.loop.sock_recv(s, 100)
                    self.assertEqual(d, b'hello')

                    # IMPORTANT: overflow recv buffer and close immediately
                    await self.loop.sock_sendall(s, b'a' * (SIZE + 1))

                srv.close()
                await srv.wait_closed()

        self.loop.run_until_complete(test())
        self.assertTrue(eof)
        self.assertIsNone(done)
        self.assertEqual(recvd, b'a' * (SIZE + 1))
496
497
class Test_AIO_Unix(_TestUnix, tb.AIOTestCase):
    """Run the shared ``_TestUnix`` tests under ``tb.AIOTestCase``."""
500
501
class _TestSSL(tb.SSLTestCase):
    """SSL-over-unix-socket tests, mixed into the concrete loop test cases."""

    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_create_unix_server_ssl_1(self):
        """Serve TOTAL_CNT SSL clients through start_unix_server(ssl=...).

        Each client is driven by ``self.unix_client`` and reports back to the
        loop via call_soon_threadsafe.  It exchanges 1 MiB payloads with the
        stream handler.
        """
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = 10.0    # timeout for this test

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            fut = asyncio.Future(loop=self.loop)

            def prog(sock):
                try:
                    sock.starttls(client_sslctx)

                    sock.connect(addr)
                    sock.send(A_DATA)

                    data = sock.recv_all(2)
                    self.assertEqual(data, b'OK')

                    sock.send(B_DATA)
                    data = sock.recv_all(4)
                    self.assertEqual(data, b'SPAM')

                    sock.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(
                        lambda ex=ex:
                            (fut.cancelled() or fut.set_exception(ex)))
                else:
                    self.loop.call_soon_threadsafe(
                        lambda: (fut.cancelled() or fut.set_result(None)))

            client = self.unix_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            extras = dict(ssl_handshake_timeout=10.0)

            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')

                srv = await asyncio.start_unix_server(
                    handle_client,
                    sock_name,
                    ssl=sslctx,
                    **extras)

                try:
                    tasks = []
                    for _ in range(TOTAL_CNT):
                        tasks.append(test_client(sock_name))

                    await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

                finally:
                    self.loop.call_soon(srv.close)
                    await srv.wait_closed()

        try:
            try:
                with self._silence_eof_received_warning():
                    self.loop.run_until_complete(start_server())
            except asyncio.TimeoutError:
                if os.environ.get('TRAVIS_OS_NAME') == 'osx':
                    # XXX: figure out why this fails on macOS on Travis
                    raise unittest.SkipTest(
                        'unexplained error on Travis macOS')
                else:
                    raise

            self.assertEqual(CNT, TOTAL_CNT)
        finally:
            # Stop every started client even when the test fails or is
            # skipped; previously only the success path reached this loop.
            for client in clients:
                client.stop()

    def test_create_unix_connection_ssl_1(self):
        """open_unix_connection(ssl=...) against a blocking SSL test server.

        TOTAL_CNT concurrent clients each exchange 1 MiB payloads with the
        server run by ``self.unix_server``.
        """
        CNT = 0
        TOTAL_CNT = 25

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            sock.starttls(sslctx, server_side=True)

            data = sock.recv_all(len(A_DATA))
            self.assertEqual(data, A_DATA)
            sock.send(b'OK')

            data = sock.recv_all(len(B_DATA))
            self.assertEqual(data, B_DATA)
            sock.send(b'SPAM')

            sock.close()

        async def client(addr):
            extras = dict(ssl_handshake_timeout=10.0)

            reader, writer = await asyncio.open_unix_connection(
                addr,
                ssl=client_sslctx,
                server_hostname='',
                **extras)

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            nonlocal CNT
            CNT += 1

            writer.close()
            await self.wait_closed(writer)

        def run(coro):
            nonlocal CNT
            CNT = 0

            with self.unix_server(server,
                                  max_clients=TOTAL_CNT,
                                  backlog=TOTAL_CNT) as srv:
                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(coro(srv.addr))

                self.loop.run_until_complete(asyncio.gather(*tasks))

            self.assertEqual(CNT, TOTAL_CNT)

        with self._silence_eof_received_warning():
            run(client)
668
669
class Test_UV_UnixSSL(_TestSSL, tb.UVTestCase):
    """Run the shared ``_TestSSL`` tests under ``tb.UVTestCase``."""
672
673
class Test_AIO_UnixSSL(_TestSSL, tb.AIOTestCase):
    """Run the shared ``_TestSSL`` tests under ``tb.AIOTestCase``."""
676