1""" 2Test suite for socketserver. 3""" 4 5import contextlib 6import io 7import os 8import select 9import signal 10import socket 11import tempfile 12import threading 13import unittest 14import socketserver 15 16import test.support 17from test.support import reap_children, verbose 18from test.support import os_helper 19from test.support import socket_helper 20from test.support import threading_helper 21 22 23test.support.requires("network") 24 25TEST_STR = b"hello world\n" 26HOST = socket_helper.HOST 27 28HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") 29requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS, 30 'requires Unix sockets') 31HAVE_FORKING = hasattr(os, "fork") 32requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking') 33 34def signal_alarm(n): 35 """Call signal.alarm when it exists (i.e. not on Windows).""" 36 if hasattr(signal, 'alarm'): 37 signal.alarm(n) 38 39# Remember real select() to avoid interferences with mocking 40_real_select = select.select 41 42def receive(sock, n, timeout=test.support.SHORT_TIMEOUT): 43 r, w, x = _real_select([sock], [], [], timeout) 44 if sock in r: 45 return sock.recv(n) 46 else: 47 raise RuntimeError("timed out on %r" % (sock,)) 48 49if HAVE_UNIX_SOCKETS and HAVE_FORKING: 50 class ForkingUnixStreamServer(socketserver.ForkingMixIn, 51 socketserver.UnixStreamServer): 52 pass 53 54 class ForkingUnixDatagramServer(socketserver.ForkingMixIn, 55 socketserver.UnixDatagramServer): 56 pass 57 58 59@contextlib.contextmanager 60def simple_subprocess(testcase): 61 """Tests that a custom child process is not waited on (Issue 1540386)""" 62 pid = os.fork() 63 if pid == 0: 64 # Don't raise an exception; it would be caught by the test harness. 65 os._exit(72) 66 try: 67 yield None 68 except: 69 raise 70 finally: 71 test.support.wait_process(pid, exitcode=72) 72 73 74class SocketServerTest(unittest.TestCase): 75 """Test all socket servers.""" 76 77 def setUp(self): 78 signal_alarm(60) # Kill deadlocks after 60 seconds. 79 self.port_seed = 0 80 self.test_files = [] 81 82 def tearDown(self): 83 signal_alarm(0) # Didn't deadlock. 84 reap_children() 85 86 for fn in self.test_files: 87 try: 88 os.remove(fn) 89 except OSError: 90 pass 91 self.test_files[:] = [] 92 93 def pickaddr(self, proto): 94 if proto == socket.AF_INET: 95 return (HOST, 0) 96 else: 97 # XXX: We need a way to tell AF_UNIX to pick its own name 98 # like AF_INET provides port==0. 99 dir = None 100 fn = tempfile.mktemp(prefix='unix_socket.', dir=dir) 101 self.test_files.append(fn) 102 return fn 103 104 def make_server(self, addr, svrcls, hdlrbase): 105 class MyServer(svrcls): 106 def handle_error(self, request, client_address): 107 self.close_request(request) 108 raise 109 110 class MyHandler(hdlrbase): 111 def handle(self): 112 line = self.rfile.readline() 113 self.wfile.write(line) 114 115 if verbose: print("creating server") 116 try: 117 server = MyServer(addr, MyHandler) 118 except PermissionError as e: 119 # Issue 29184: cannot bind() a Unix socket on Android. 120 self.skipTest('Cannot create server (%s, %s): %s' % 121 (svrcls, addr, e)) 122 self.assertEqual(server.server_address, server.socket.getsockname()) 123 return server 124 125 @threading_helper.reap_threads 126 def run_server(self, svrcls, hdlrbase, testfunc): 127 server = self.make_server(self.pickaddr(svrcls.address_family), 128 svrcls, hdlrbase) 129 # We had the OS pick a port, so pull the real address out of 130 # the server. 131 addr = server.server_address 132 if verbose: 133 print("ADDR =", addr) 134 print("CLASS =", svrcls) 135 136 t = threading.Thread( 137 name='%s serving' % svrcls, 138 target=server.serve_forever, 139 # Short poll interval to make the test finish quickly. 140 # Time between requests is short enough that we won't wake 141 # up spuriously too many times. 142 kwargs={'poll_interval':0.01}) 143 t.daemon = True # In case this function raises. 144 t.start() 145 if verbose: print("server running") 146 for i in range(3): 147 if verbose: print("test client", i) 148 testfunc(svrcls.address_family, addr) 149 if verbose: print("waiting for server") 150 server.shutdown() 151 t.join() 152 server.server_close() 153 self.assertEqual(-1, server.socket.fileno()) 154 if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn): 155 # bpo-31151: Check that ForkingMixIn.server_close() waits until 156 # all children completed 157 self.assertFalse(server.active_children) 158 if verbose: print("done") 159 160 def stream_examine(self, proto, addr): 161 with socket.socket(proto, socket.SOCK_STREAM) as s: 162 s.connect(addr) 163 s.sendall(TEST_STR) 164 buf = data = receive(s, 100) 165 while data and b'\n' not in buf: 166 data = receive(s, 100) 167 buf += data 168 self.assertEqual(buf, TEST_STR) 169 170 def dgram_examine(self, proto, addr): 171 with socket.socket(proto, socket.SOCK_DGRAM) as s: 172 if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX: 173 s.bind(self.pickaddr(proto)) 174 s.sendto(TEST_STR, addr) 175 buf = data = receive(s, 100) 176 while data and b'\n' not in buf: 177 data = receive(s, 100) 178 buf += data 179 self.assertEqual(buf, TEST_STR) 180 181 def test_TCPServer(self): 182 self.run_server(socketserver.TCPServer, 183 socketserver.StreamRequestHandler, 184 self.stream_examine) 185 186 def test_ThreadingTCPServer(self): 187 self.run_server(socketserver.ThreadingTCPServer, 188 socketserver.StreamRequestHandler, 189 self.stream_examine) 190 191 @requires_forking 192 def test_ForkingTCPServer(self): 193 with simple_subprocess(self): 194 self.run_server(socketserver.ForkingTCPServer, 195 socketserver.StreamRequestHandler, 196 self.stream_examine) 197 198 @requires_unix_sockets 199 def test_UnixStreamServer(self): 200 self.run_server(socketserver.UnixStreamServer, 201 socketserver.StreamRequestHandler, 202 self.stream_examine) 203 204 @requires_unix_sockets 205 def test_ThreadingUnixStreamServer(self): 206 self.run_server(socketserver.ThreadingUnixStreamServer, 207 socketserver.StreamRequestHandler, 208 self.stream_examine) 209 210 @requires_unix_sockets 211 @requires_forking 212 def test_ForkingUnixStreamServer(self): 213 with simple_subprocess(self): 214 self.run_server(ForkingUnixStreamServer, 215 socketserver.StreamRequestHandler, 216 self.stream_examine) 217 218 def test_UDPServer(self): 219 self.run_server(socketserver.UDPServer, 220 socketserver.DatagramRequestHandler, 221 self.dgram_examine) 222 223 def test_ThreadingUDPServer(self): 224 self.run_server(socketserver.ThreadingUDPServer, 225 socketserver.DatagramRequestHandler, 226 self.dgram_examine) 227 228 @requires_forking 229 def test_ForkingUDPServer(self): 230 with simple_subprocess(self): 231 self.run_server(socketserver.ForkingUDPServer, 232 socketserver.DatagramRequestHandler, 233 self.dgram_examine) 234 235 @requires_unix_sockets 236 def test_UnixDatagramServer(self): 237 self.run_server(socketserver.UnixDatagramServer, 238 socketserver.DatagramRequestHandler, 239 self.dgram_examine) 240 241 @requires_unix_sockets 242 def test_ThreadingUnixDatagramServer(self): 243 self.run_server(socketserver.ThreadingUnixDatagramServer, 244 socketserver.DatagramRequestHandler, 245 self.dgram_examine) 246 247 @requires_unix_sockets 248 @requires_forking 249 def test_ForkingUnixDatagramServer(self): 250 self.run_server(ForkingUnixDatagramServer, 251 socketserver.DatagramRequestHandler, 252 self.dgram_examine) 253 254 @threading_helper.reap_threads 255 def test_shutdown(self): 256 # Issue #2302: shutdown() should always succeed in making an 257 # other thread leave serve_forever(). 258 class MyServer(socketserver.TCPServer): 259 pass 260 261 class MyHandler(socketserver.StreamRequestHandler): 262 pass 263 264 threads = [] 265 for i in range(20): 266 s = MyServer((HOST, 0), MyHandler) 267 t = threading.Thread( 268 name='MyServer serving', 269 target=s.serve_forever, 270 kwargs={'poll_interval':0.01}) 271 t.daemon = True # In case this function raises. 272 threads.append((t, s)) 273 for t, s in threads: 274 t.start() 275 s.shutdown() 276 for t, s in threads: 277 t.join() 278 s.server_close() 279 280 def test_close_immediately(self): 281 class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer): 282 pass 283 284 server = MyServer((HOST, 0), lambda: None) 285 server.server_close() 286 287 def test_tcpserver_bind_leak(self): 288 # Issue #22435: the server socket wouldn't be closed if bind()/listen() 289 # failed. 290 # Create many servers for which bind() will fail, to see if this result 291 # in FD exhaustion. 292 for i in range(1024): 293 with self.assertRaises(OverflowError): 294 socketserver.TCPServer((HOST, -1), 295 socketserver.StreamRequestHandler) 296 297 def test_context_manager(self): 298 with socketserver.TCPServer((HOST, 0), 299 socketserver.StreamRequestHandler) as server: 300 pass 301 self.assertEqual(-1, server.socket.fileno()) 302 303 304class ErrorHandlerTest(unittest.TestCase): 305 """Test that the servers pass normal exceptions from the handler to 306 handle_error(), and that exiting exceptions like SystemExit and 307 KeyboardInterrupt are not passed.""" 308 309 def tearDown(self): 310 os_helper.unlink(os_helper.TESTFN) 311 312 def test_sync_handled(self): 313 BaseErrorTestServer(ValueError) 314 self.check_result(handled=True) 315 316 def test_sync_not_handled(self): 317 with self.assertRaises(SystemExit): 318 BaseErrorTestServer(SystemExit) 319 self.check_result(handled=False) 320 321 def test_threading_handled(self): 322 ThreadingErrorTestServer(ValueError) 323 self.check_result(handled=True) 324 325 def test_threading_not_handled(self): 326 with threading_helper.catch_threading_exception() as cm: 327 ThreadingErrorTestServer(SystemExit) 328 self.check_result(handled=False) 329 330 self.assertIs(cm.exc_type, SystemExit) 331 332 @requires_forking 333 def test_forking_handled(self): 334 ForkingErrorTestServer(ValueError) 335 self.check_result(handled=True) 336 337 @requires_forking 338 def test_forking_not_handled(self): 339 ForkingErrorTestServer(SystemExit) 340 self.check_result(handled=False) 341 342 def check_result(self, handled): 343 with open(os_helper.TESTFN) as log: 344 expected = 'Handler called\n' + 'Error handled\n' * handled 345 self.assertEqual(log.read(), expected) 346 347 348class BaseErrorTestServer(socketserver.TCPServer): 349 def __init__(self, exception): 350 self.exception = exception 351 super().__init__((HOST, 0), BadHandler) 352 with socket.create_connection(self.server_address): 353 pass 354 try: 355 self.handle_request() 356 finally: 357 self.server_close() 358 self.wait_done() 359 360 def handle_error(self, request, client_address): 361 with open(os_helper.TESTFN, 'a') as log: 362 log.write('Error handled\n') 363 364 def wait_done(self): 365 pass 366 367 368class BadHandler(socketserver.BaseRequestHandler): 369 def handle(self): 370 with open(os_helper.TESTFN, 'a') as log: 371 log.write('Handler called\n') 372 raise self.server.exception('Test error') 373 374 375class ThreadingErrorTestServer(socketserver.ThreadingMixIn, 376 BaseErrorTestServer): 377 def __init__(self, *pos, **kw): 378 self.done = threading.Event() 379 super().__init__(*pos, **kw) 380 381 def shutdown_request(self, *pos, **kw): 382 super().shutdown_request(*pos, **kw) 383 self.done.set() 384 385 def wait_done(self): 386 self.done.wait() 387 388 389if HAVE_FORKING: 390 class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer): 391 pass 392 393 394class SocketWriterTest(unittest.TestCase): 395 def test_basics(self): 396 class Handler(socketserver.StreamRequestHandler): 397 def handle(self): 398 self.server.wfile = self.wfile 399 self.server.wfile_fileno = self.wfile.fileno() 400 self.server.request_fileno = self.request.fileno() 401 402 server = socketserver.TCPServer((HOST, 0), Handler) 403 self.addCleanup(server.server_close) 404 s = socket.socket( 405 server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) 406 with s: 407 s.connect(server.server_address) 408 server.handle_request() 409 self.assertIsInstance(server.wfile, io.BufferedIOBase) 410 self.assertEqual(server.wfile_fileno, server.request_fileno) 411 412 def test_write(self): 413 # Test that wfile.write() sends data immediately, and that it does 414 # not truncate sends when interrupted by a Unix signal 415 pthread_kill = test.support.get_attribute(signal, 'pthread_kill') 416 417 class Handler(socketserver.StreamRequestHandler): 418 def handle(self): 419 self.server.sent1 = self.wfile.write(b'write data\n') 420 # Should be sent immediately, without requiring flush() 421 self.server.received = self.rfile.readline() 422 big_chunk = b'\0' * test.support.SOCK_MAX_SIZE 423 self.server.sent2 = self.wfile.write(big_chunk) 424 425 server = socketserver.TCPServer((HOST, 0), Handler) 426 self.addCleanup(server.server_close) 427 interrupted = threading.Event() 428 429 def signal_handler(signum, frame): 430 interrupted.set() 431 432 original = signal.signal(signal.SIGUSR1, signal_handler) 433 self.addCleanup(signal.signal, signal.SIGUSR1, original) 434 response1 = None 435 received2 = None 436 main_thread = threading.get_ident() 437 438 def run_client(): 439 s = socket.socket(server.address_family, socket.SOCK_STREAM, 440 socket.IPPROTO_TCP) 441 with s, s.makefile('rb') as reader: 442 s.connect(server.server_address) 443 nonlocal response1 444 response1 = reader.readline() 445 s.sendall(b'client response\n') 446 447 reader.read(100) 448 # The main thread should now be blocking in a send() syscall. 449 # But in theory, it could get interrupted by other signals, 450 # and then retried. So keep sending the signal in a loop, in 451 # case an earlier signal happens to be delivered at an 452 # inconvenient moment. 453 while True: 454 pthread_kill(main_thread, signal.SIGUSR1) 455 if interrupted.wait(timeout=float(1)): 456 break 457 nonlocal received2 458 received2 = len(reader.read()) 459 460 background = threading.Thread(target=run_client) 461 background.start() 462 server.handle_request() 463 background.join() 464 self.assertEqual(server.sent1, len(response1)) 465 self.assertEqual(response1, b'write data\n') 466 self.assertEqual(server.received, b'client response\n') 467 self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE) 468 self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100) 469 470 471class MiscTestCase(unittest.TestCase): 472 473 def test_all(self): 474 # objects defined in the module should be in __all__ 475 expected = [] 476 for name in dir(socketserver): 477 if not name.startswith('_'): 478 mod_object = getattr(socketserver, name) 479 if getattr(mod_object, '__module__', None) == 'socketserver': 480 expected.append(name) 481 self.assertCountEqual(socketserver.__all__, expected) 482 483 def test_shutdown_request_called_if_verify_request_false(self): 484 # Issue #26309: BaseServer should call shutdown_request even if 485 # verify_request is False 486 487 class MyServer(socketserver.TCPServer): 488 def verify_request(self, request, client_address): 489 return False 490 491 shutdown_called = 0 492 def shutdown_request(self, request): 493 self.shutdown_called += 1 494 socketserver.TCPServer.shutdown_request(self, request) 495 496 server = MyServer((HOST, 0), socketserver.StreamRequestHandler) 497 s = socket.socket(server.address_family, socket.SOCK_STREAM) 498 s.connect(server.server_address) 499 s.close() 500 server.handle_request() 501 self.assertEqual(server.shutdown_called, 1) 502 server.server_close() 503 504 def test_threads_reaped(self): 505 """ 506 In #37193, users reported a memory leak 507 due to the saving of every request thread. Ensure that 508 not all threads are kept forever. 509 """ 510 class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer): 511 pass 512 513 server = MyServer((HOST, 0), socketserver.StreamRequestHandler) 514 for n in range(10): 515 with socket.create_connection(server.server_address): 516 server.handle_request() 517 self.assertLess(len(server._threads), 10) 518 server.server_close() 519 520 521if __name__ == "__main__": 522 unittest.main() 523