"""
Test suite for socketserver.
"""

import contextlib
import io
import os
import select
import signal
import socket
import tempfile
import threading
import unittest
import socketserver

import test.support
from test.support import reap_children, reap_threads, verbose


test.support.requires("network")

TEST_STR = b"hello world\n"
HOST = test.support.HOST

HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')


def signal_alarm(n):
    """Call signal.alarm when it exists (i.e. not on Windows)."""
    if hasattr(signal, 'alarm'):
        signal.alarm(n)


# Remember real select() to avoid interferences with mocking
_real_select = select.select


def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits with select() for at most *timeout* seconds for the socket to
    become readable and raises RuntimeError if it does not.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)


if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        """UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        """UnixDatagramServer combined with ForkingMixIn."""


@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)"""
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # The child must still be waitable here: if the code under test had
        # already reaped it, os.waitpid() would fail instead of returning
        # the child's pid and its exit status (72 in the high byte).
        pid2, status = os.waitpid(pid, 0)
        testcase.assertEqual(pid2, pid)
        testcase.assertEqual(72 << 8, status)


class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Best-effort removal of the Unix socket paths handed out by
        # pickaddr(); a path that was never bound simply does not exist.
        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                pass
        self.test_files.clear()

    def pickaddr(self, proto):
        """Return an address for *proto* that a server or client can bind.

        For AF_INET this is (HOST, 0), which lets the OS pick the port.
        Otherwise a fresh temporary file name is returned and recorded in
        self.test_files so that tearDown() removes it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a *svrcls* server bound to *addr* that echoes one line.

        The handler derives from *hdlrbase*.  The test is skipped when
        creating the server raises PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                # Close the request, then re-raise the exception that is
                # currently being handled.
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose:
            print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a thread and run *testfunc* against it.

        *testfunc* is called three times with (address_family, address).
        Afterwards the server is shut down and closed, and the listening
        socket is checked to be closed.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose:
            print("server running")
        for i in range(3):
            if verbose:
                print("test client", i)
            testfunc(svrcls.address_family, addr)
        if verbose:
            print("waiting for server")
        server.shutdown()
        t.join()
        server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose:
            print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            # Keep reading until a full line arrived or the peer closed.
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Bind the client to its own path before sending.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for _ in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval': 0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_close_immediately(self):
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), lambda: None)
        server.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this
        # results in FD exhaustion.
        for _ in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())


class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed."""

    def tearDown(self):
        test.support.unlink(test.support.TESTFN)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def check_result(self, handled):
        """Check the log file written by BadHandler and handle_error().

        *handled* tells whether an 'Error handled' line is expected after
        the 'Handler called' line.
        """
        with open(test.support.TESTFN) as log:
            # Multiplying by the bool keeps or drops the second line.
            expected = 'Handler called\n' + 'Error handled\n' * handled
            self.assertEqual(log.read(), expected)


class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that handles a single request whose handler raises.

    Constructing an instance connects to the server once, handles that
    request with BadHandler, closes the server and calls wait_done().
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # Queue one connection so that handle_request() has work to do.
        with socket.create_connection(self.server_address):
            pass
        try:
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        pass


class BadHandler(socketserver.BaseRequestHandler):
    """Handler that logs its invocation and raises the server's exception."""

    def handle(self):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Handler called\n')
        raise self.server.exception('Test error')


class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """BaseErrorTestServer variant using ThreadingMixIn.

    wait_done() blocks until shutdown_request() has been called.
    """

    def __init__(self, *pos, **kw):
        # Must exist before the base __init__ runs, because that handles
        # the request and then calls wait_done().
        self.done = threading.Event()
        super().__init__(*pos, **kw)

    def shutdown_request(self, *pos, **kw):
        super().shutdown_request(*pos, **kw)
        self.done.set()

    def wait_done(self):
        self.done.wait()


if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        """BaseErrorTestServer variant using ForkingMixIn."""


class SocketWriterTest(unittest.TestCase):
    """Tests for the wfile object of StreamRequestHandler."""

    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Publish what the test wants to inspect on the server.
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        s = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with s:
            s.connect(server.server_address)
            server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                              socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                response1 = reader.readline()
                s.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=1.0):
                        break
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)


class MiscTestCase(unittest.TestCase):
    """Miscellaneous socketserver tests."""

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = [
            name
            for name in dir(socketserver)
            if not name.startswith('_')
            and getattr(getattr(socketserver, name),
                        '__module__', None) == 'socketserver'
        ]
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0

            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Registered as a cleanup so the listening socket is closed even
        # when an assertion below fails.
        self.addCleanup(server.server_close)
        # Connect and immediately close, leaving one pending connection.
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)

    def test_threads_reaped(self):
        """
        In #37193, users reported a memory leak
        due to the saving of every request thread. Ensure that
        not all threads are kept forever.
        """
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Registered as a cleanup so the server is closed even when the
        # assertion below fails.
        self.addCleanup(server.server_close)
        for _ in range(10):
            with socket.create_connection(server.server_address):
                server.handle_request()
        self.assertLess(len(server._threads), 10)


if __name__ == "__main__":
    unittest.main()