"""Tests for UNIX-domain stream servers and connections.

The test bodies live in the ``_TestUnix`` and ``_TestSSL`` mixins; they are
executed through the ``Test_UV_*`` (``tb.UVTestCase``) and ``Test_AIO_*``
(``tb.AIOTestCase``) subclasses defined after each mixin.
"""

import asyncio
import os
import pathlib
import re
import socket
import sys
import tempfile
import time
import unittest

from uvloop import _testbase as tb


class _TestUnix:
    """UNIX-socket tests shared by the uvloop and asyncio test cases.

    Not a ``TestCase`` on its own: ``self.loop`` and helpers such as
    ``wait_closed``, ``unix_server`` and ``unix_sock_name`` come from the
    ``tb`` base class it is mixed into.
    """

    async def _client_exchange(self, reader, writer):
        """Client half of the AAAA->OK, BBBB->SPAM exchange, then close."""
        writer.write(b'AAAA')
        self.assertEqual(await reader.readexactly(2), b'OK')

        writer.write(b'BBBB')
        self.assertEqual(await reader.readexactly(4), b'SPAM')

        writer.close()
        await self.wait_closed(writer)

    def test_create_unix_server_1(self):
        """Serve TOTAL_CNT concurrent clients via three server entry points.

        Covers ``start_unix_server(path)``, ``start_unix_server(sock=...)``
        and ``start_server(sock=...)``. Each run must finish every client
        exchange within TIMEOUT, release the listening sockets on close and
        leave the socket file on disk.
        """
        CNT = 0  # number of clients that were successful
        TOTAL_CNT = 100  # total number of clients that test will create
        TIMEOUT = 5.0  # timeout for this test

        async def handle_client(reader, writer):
            # Server side: expect AAAA -> reply OK; expect BBBB -> reply SPAM.
            nonlocal CNT

            data = await reader.readexactly(4)
            self.assertEqual(data, b'AAAA')
            writer.write(b'OK')

            data = await reader.readexactly(4)
            self.assertEqual(data, b'BBBB')
            writer.write(b'SPAM')

            await writer.drain()
            writer.close()
            await self.wait_closed(writer)

            CNT += 1

        async def recv_exactly(sock, n):
            # Read exactly n bytes one at a time. An empty read means EOF;
            # fail immediately instead of spinning until TIMEOUT expires.
            buf = b''
            while len(buf) < n:
                chunk = await self.loop.sock_recv(sock, 1)
                if not chunk:
                    self.fail('unexpected EOF after {!r}'.format(buf))
                buf += chunk
            return buf

        async def test_client(addr):
            # Raw non-blocking socket client driven via the loop.sock_* API.
            sock = socket.socket(socket.AF_UNIX)
            with sock:
                sock.setblocking(False)
                await self.loop.sock_connect(sock, addr)

                await self.loop.sock_sendall(sock, b'AAAA')
                self.assertEqual(await recv_exactly(sock, 2), b'OK')

                await self.loop.sock_sendall(sock, b'BBBB')
                self.assertEqual(await recv_exactly(sock, 4), b'SPAM')

        async def serve_clients(srv, sock_name):
            # Run TOTAL_CNT clients against srv, close it, then verify the
            # post-close state. Must be awaited while sock_name still exists.
            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)
                self.assertTrue(srv.is_serving())

                tasks = [test_client(sock_name) for _ in range(TOTAL_CNT)]
                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

            # Check that the server cleaned-up proxy-sockets
            for srv_sock in srv_socks:
                self.assertEqual(srv_sock.fileno(), -1)

            self.assertFalse(srv.is_serving())

            # asyncio doesn't cleanup the sock file
            self.assertTrue(os.path.exists(sock_name))

        async def start_server():
            nonlocal CNT
            CNT = 0

            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')
                srv = await asyncio.start_unix_server(
                    handle_client,
                    sock_name)
                await serve_clients(srv, sock_name)

        async def start_server_sock(server_factory):
            # server_factory(sock) -> awaitable Server built on a pre-bound
            # socket.
            nonlocal CNT
            CNT = 0

            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')
                sock = socket.socket(socket.AF_UNIX)
                try:
                    sock.bind(sock_name)
                    srv = await server_factory(sock)
                except BaseException:
                    # The server only takes ownership of sock on success;
                    # otherwise close it here so it doesn't leak.
                    sock.close()
                    raise
                await serve_clients(srv, sock_name)

        with self.subTest(func='start_unix_server(path)'):
            self.loop.run_until_complete(start_server())
            self.assertEqual(CNT, TOTAL_CNT)

        with self.subTest(func='start_unix_server(sock)'):
            self.loop.run_until_complete(start_server_sock(
                lambda sock: asyncio.start_unix_server(
                    handle_client,
                    None,
                    sock=sock)))
            self.assertEqual(CNT, TOTAL_CNT)

        with self.subTest(func='start_server(sock)'):
            self.loop.run_until_complete(start_server_sock(
                lambda sock: asyncio.start_server(
                    handle_client,
                    None, None,
                    sock=sock)))
            self.assertEqual(CNT, TOTAL_CNT)

    def test_create_unix_server_2(self):
        """create_unix_server() refuses a path occupied by a regular file."""
        with tempfile.TemporaryDirectory() as td:
            sock_name = os.path.join(td, 'sock')
            with open(sock_name, 'wt') as f:
                f.write('x')

            # The path is embedded in a regex, so escape any metacharacters
            # that may appear in the temp directory name.
            with self.assertRaisesRegex(
                    OSError, "Address '{}' is already in use".format(
                        re.escape(sock_name))):

                self.loop.run_until_complete(
                    self.loop.create_unix_server(object, sock_name))

    def test_create_unix_server_3(self):
        """ssl_handshake_timeout without ssl is rejected with ValueError."""
        with self.assertRaisesRegex(
                ValueError, 'ssl_handshake_timeout is only meaningful'):
            self.loop.run_until_complete(
                self.loop.create_unix_server(
                    lambda: None, path='/tmp/a', ssl_handshake_timeout=10))

    def test_create_unix_server_existing_path_sock(self):
        """A leftover socket file at `path` does not block a new server."""
        with self.unix_sock_name() as path:
            sock = socket.socket(socket.AF_UNIX)
            with sock:
                sock.bind(path)
                sock.listen(1)

            # Check that no error is raised -- `path` is removed.
            coro = self.loop.create_unix_server(lambda: None, path)
            srv = self.loop.run_until_complete(coro)
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_create_unix_connection_open_unix_con_addr(self):
        """open_unix_connection(path) completes the client exchange."""
        async def client(addr):
            reader, writer = await asyncio.open_unix_connection(addr)
            await self._client_exchange(reader, writer)

        self._test_create_unix_connection_1(client)

    def test_create_unix_connection_open_unix_con_sock(self):
        """open_unix_connection(sock=...) on a pre-connected socket."""
        async def client(addr):
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(addr)
            reader, writer = await asyncio.open_unix_connection(sock=sock)
            await self._client_exchange(reader, writer)

        self._test_create_unix_connection_1(client)

    def test_create_unix_connection_open_con_sock(self):
        """open_connection(sock=...) accepts a connected AF_UNIX socket."""
        async def client(addr):
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(addr)
            reader, writer = await asyncio.open_connection(sock=sock)
            await self._client_exchange(reader, writer)

        self._test_create_unix_connection_1(client)

    def _test_create_unix_connection_1(self, client):
        """Run TOTAL_CNT concurrent ``client(addr)`` coroutines.

        The server side is the blocking ``server(sock)`` function below,
        driven by ``self.unix_server`` (presumably on a helper thread --
        see ``tb``). Asserts that every client coroutine completed.
        """
        CNT = 0
        TOTAL_CNT = 100

        def server(sock):
            data = sock.recv_all(4)
            self.assertEqual(data, b'AAAA')
            sock.send(b'OK')

            data = sock.recv_all(4)
            self.assertEqual(data, b'BBBB')
            sock.send(b'SPAM')

        async def client_wrapper(addr):
            nonlocal CNT
            await client(addr)
            CNT += 1

        def run(coro):
            nonlocal CNT
            CNT = 0

            with self.unix_server(server,
                                  max_clients=TOTAL_CNT,
                                  backlog=TOTAL_CNT) as srv:
                tasks = [coro(srv.addr) for _ in range(TOTAL_CNT)]
                self.loop.run_until_complete(asyncio.gather(*tasks))

            # Give time for all transports to close.
            self.loop.run_until_complete(asyncio.sleep(0.1))

            self.assertEqual(CNT, TOTAL_CNT)

        run(client_wrapper)

    def test_create_unix_connection_2(self):
        """Connecting to a path that no longer exists raises
        FileNotFoundError."""
        # NamedTemporaryFile deletes the file on exit, leaving a dead path.
        with tempfile.NamedTemporaryFile() as tmp:
            path = tmp.name

        async def client():
            reader, writer = await asyncio.open_unix_connection(path)
            writer.close()
            await self.wait_closed(writer)

        async def runner():
            with self.assertRaises(FileNotFoundError):
                await client()

        self.loop.run_until_complete(runner())

    def test_create_unix_connection_3(self):
        """A server closing mid-read surfaces as IncompleteReadError."""
        CNT = 0
        TOTAL_CNT = 100

        def server(sock):
            data = sock.recv_all(4)
            self.assertEqual(data, b'AAAA')
            sock.close()

        async def client(addr):
            nonlocal CNT

            reader, writer = await asyncio.open_unix_connection(addr)

            # Use the public StreamWriter API rather than the private
            # `_transport` attribute.
            sock = writer.get_extra_info('socket')
            self.assertEqual(sock.family, socket.AF_UNIX)

            writer.write(b'AAAA')

            with self.assertRaises(asyncio.IncompleteReadError):
                await reader.readexactly(10)

            writer.close()
            await self.wait_closed(writer)

            CNT += 1

        def run(coro):
            nonlocal CNT
            CNT = 0

            with self.unix_server(server,
                                  max_clients=TOTAL_CNT,
                                  backlog=TOTAL_CNT) as srv:
                tasks = [coro(srv.addr) for _ in range(TOTAL_CNT)]
                self.loop.run_until_complete(asyncio.gather(*tasks))

            self.assertEqual(CNT, TOTAL_CNT)

        run(client)

    def test_create_unix_connection_4(self):
        """Passing an already-closed socket raises a 'Bad file' OSError."""
        sock = socket.socket(socket.AF_UNIX)
        sock.close()

        async def client():
            reader, writer = await asyncio.open_unix_connection(sock=sock)
            writer.close()
            await self.wait_closed(writer)

        async def runner():
            with self.assertRaisesRegex(OSError, 'Bad file'):
                await client()

        self.loop.run_until_complete(runner())

    def test_create_unix_connection_5(self):
        """Writing after the peer closed reports exactly one
        connection_lost."""
        s1, s2 = socket.socketpair(socket.AF_UNIX)

        excs = []

        class Proto(asyncio.Protocol):
            def connection_lost(self, exc):
                excs.append(exc)

        proto = Proto()

        async def client():
            t, _ = await self.loop.create_unix_connection(
                lambda: proto,
                None,
                sock=s2)

            t.write(b'AAAAA')
            s1.close()
            t.write(b'AAAAA')
            await asyncio.sleep(0.1)

        self.loop.run_until_complete(client())

        self.assertEqual(len(excs), 1)
        self.assertIn(excs[0].__class__,
                      (BrokenPipeError, ConnectionResetError))

    def test_create_unix_connection_6(self):
        """ssl_handshake_timeout without ssl is rejected with ValueError."""
        with self.assertRaisesRegex(
                ValueError, 'ssl_handshake_timeout is only meaningful'):
            self.loop.run_until_complete(
                self.loop.create_unix_connection(
                    lambda: None, path='/tmp/a', ssl_handshake_timeout=10))


class Test_UV_Unix(_TestUnix, tb.UVTestCase):
    """_TestUnix on ``tb.UVTestCase`` plus uvloop-only UNIX socket tests."""

    @unittest.skipUnless(hasattr(os, 'fspath'), 'no os.fspath()')
    def test_create_unix_connection_pathlib(self):
        """create_unix_connection accepts a pathlib.Path address."""
        async def run(addr):
            t, _ = await self.loop.create_unix_connection(
                asyncio.Protocol, addr)
            t.close()

        with self.unix_server(lambda sock: time.sleep(0.01)) as srv:
            addr = pathlib.Path(srv.addr)
            self.loop.run_until_complete(run(addr))

    @unittest.skipUnless(hasattr(os, 'fspath'), 'no os.fspath()')
    def test_create_unix_server_pathlib(self):
        """create_unix_server accepts a pathlib.Path address."""
        with self.unix_sock_name() as srv_path:
            srv_path = pathlib.Path(srv_path)
            srv = self.loop.run_until_complete(
                self.loop.create_unix_server(asyncio.Protocol, srv_path))
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_transport_fromsock_get_extra_info(self):
        """The transport's socket is stable and its fd is guarded.

        Repeated ``get_extra_info('socket')`` calls return the same object,
        and ``add_writer``/``remove_writer`` on its fd raise RuntimeError.
        """
        # This tests is only for uvloop. asyncio should pass it
        # too in Python 3.6.

        async def test(sock):
            t, _ = await self.loop.create_unix_connection(
                asyncio.Protocol,
                sock=sock)

            tsock = t.get_extra_info('socket')
            self.assertIs(t.get_extra_info('socket'), tsock)

            with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
                self.loop.add_writer(tsock.fileno(), lambda: None)
            with self.assertRaisesRegex(RuntimeError, 'is used by transport'):
                self.loop.remove_writer(tsock.fileno())

            t.close()

        s1, s2 = socket.socketpair(socket.AF_UNIX)
        with s1, s2:
            self.loop.run_until_complete(test(s1))

    def test_create_unix_server_path_dgram(self):
        """A SOCK_DGRAM socket is rejected by create_unix_server."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        with sock:
            coro = self.loop.create_unix_server(lambda: None,
                                                sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
                         'no socket.SOCK_NONBLOCK (linux only)')
    def test_create_unix_server_path_stream_bittype(self):
        """A stream socket created with SOCK_NONBLOCK is still accepted."""
        sock = socket.socket(
            socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
        # Only reserve a unique name; the file is deleted on exit so the
        # socket can bind to that path.
        with tempfile.NamedTemporaryFile() as file:
            fn = file.name
        try:
            with sock:
                sock.bind(fn)
                coro = self.loop.create_unix_server(lambda: None, path=None,
                                                    sock=sock)
                srv = self.loop.run_until_complete(coro)
                srv.close()
                self.loop.run_until_complete(srv.wait_closed())
        finally:
            os.unlink(fn)

    @unittest.skipUnless(sys.platform.startswith('linux'), 'requires epoll')
    def test_epollhup(self):
        """Peer overflows the protocol buffer then closes immediately.

        The buffered protocol must still receive every byte, see EOF and
        get ``connection_lost(None)``.
        """
        SIZE = 50
        eof = False
        done = False
        recvd = b''

        class Proto(asyncio.BaseProtocol):
            def connection_made(self, tr):
                tr.write(b'hello')
                self.data = bytearray(SIZE)
                self.buf = memoryview(self.data)

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nbytes):
                nonlocal recvd
                recvd += self.buf[:nbytes]

            def eof_received(self):
                nonlocal eof
                eof = True

            def connection_lost(self, exc):
                nonlocal done
                done = exc

        async def test():
            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')
                srv = await self.loop.create_unix_server(Proto, sock_name)

                s = socket.socket(socket.AF_UNIX)
                with s:
                    s.setblocking(False)
                    await self.loop.sock_connect(s, sock_name)
                    d = await self.loop.sock_recv(s, 100)
                    self.assertEqual(d, b'hello')

                    # IMPORTANT: overflow recv buffer and close immediately
                    await self.loop.sock_sendall(s, b'a' * (SIZE + 1))

                srv.close()
                await srv.wait_closed()

        self.loop.run_until_complete(test())
        self.assertTrue(eof)
        self.assertIsNone(done)
        self.assertEqual(recvd, b'a' * (SIZE + 1))


class Test_AIO_Unix(_TestUnix, tb.AIOTestCase):
    pass


class _TestSSL(tb.SSLTestCase):
    """SSL-over-UNIX-socket tests shared by the uvloop and asyncio cases."""

    ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem')
    ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem')

    def test_create_unix_server_ssl_1(self):
        """SSL UNIX server exchanges 1 MiB payloads with threaded clients."""
        CNT = 0  # number of clients that were successful
        TOTAL_CNT = 25  # total number of clients that test will create
        TIMEOUT = 10.0  # timeout for this test

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        clients = []

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # Mixed buffer types exercise writelines() with non-bytes items.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        async def test_client(addr):
            fut = self.loop.create_future()

            def prog(sock):
                # Runs off the event loop; report the outcome back to `fut`
                # thread-safely (and ignore it if the test was cancelled).
                try:
                    sock.starttls(client_sslctx)

                    sock.connect(addr)
                    sock.send(A_DATA)

                    data = sock.recv_all(2)
                    self.assertEqual(data, b'OK')

                    sock.send(B_DATA)
                    data = sock.recv_all(4)
                    self.assertEqual(data, b'SPAM')

                    sock.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(
                        lambda ex=ex:
                            (fut.cancelled() or fut.set_exception(ex)))
                else:
                    self.loop.call_soon_threadsafe(
                        lambda: (fut.cancelled() or fut.set_result(None)))

            client = self.unix_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            extras = dict(ssl_handshake_timeout=10.0)

            with tempfile.TemporaryDirectory() as td:
                sock_name = os.path.join(td, 'sock')

                srv = await asyncio.start_unix_server(
                    handle_client,
                    sock_name,
                    ssl=sslctx,
                    **extras)

                try:
                    tasks = [test_client(sock_name) for _ in range(TOTAL_CNT)]
                    await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

                finally:
                    self.loop.call_soon(srv.close)
                    await srv.wait_closed()

        try:
            try:
                with self._silence_eof_received_warning():
                    self.loop.run_until_complete(start_server())
            except asyncio.TimeoutError:
                if os.environ.get('TRAVIS_OS_NAME') == 'osx':
                    # XXX: figure out why this fails on macOS on Travis
                    raise unittest.SkipTest('unexplained error on Travis macOS')
                else:
                    raise

            self.assertEqual(CNT, TOTAL_CNT)
        finally:
            # Stop client threads on every exit path (success, failure,
            # timeout or skip) so they don't outlive the test.
            for client in clients:
                client.stop()

    def test_create_unix_connection_ssl_1(self):
        """SSL UNIX client exchanges 1 MiB payloads with a threaded server."""
        CNT = 0
        TOTAL_CNT = 25

        A_DATA = b'A' * 1024 * 1024
        B_DATA = b'B' * 1024 * 1024

        sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        def server(sock):
            sock.starttls(sslctx, server_side=True)

            data = sock.recv_all(len(A_DATA))
            self.assertEqual(data, A_DATA)
            sock.send(b'OK')

            data = sock.recv_all(len(B_DATA))
            self.assertEqual(data, B_DATA)
            sock.send(b'SPAM')

            sock.close()

        async def client(addr):
            nonlocal CNT

            extras = dict(ssl_handshake_timeout=10.0)

            reader, writer = await asyncio.open_unix_connection(
                addr,
                ssl=client_sslctx,
                server_hostname='',
                **extras)

            writer.write(A_DATA)
            self.assertEqual(await reader.readexactly(2), b'OK')

            writer.write(B_DATA)
            self.assertEqual(await reader.readexactly(4), b'SPAM')

            CNT += 1

            writer.close()
            await self.wait_closed(writer)

        def run(coro):
            nonlocal CNT
            CNT = 0

            with self.unix_server(server,
                                  max_clients=TOTAL_CNT,
                                  backlog=TOTAL_CNT) as srv:
                tasks = [coro(srv.addr) for _ in range(TOTAL_CNT)]
                self.loop.run_until_complete(asyncio.gather(*tasks))

            self.assertEqual(CNT, TOTAL_CNT)

        with self._silence_eof_received_warning():
            run(client)


class Test_UV_UnixSSL(_TestSSL, tb.UVTestCase):
    pass


class Test_AIO_UnixSSL(_TestSSL, tb.AIOTestCase):
    pass