1"""
2Test suite for socketserver.
3"""
4
5import contextlib
6import io
7import os
8import select
9import signal
10import socket
11import tempfile
12import threading
13import unittest
14import socketserver
15
16import test.support
17from test.support import reap_children, reap_threads, verbose
18
19
# Gate this whole module on the "network" test resource (see
# test.support.requires for exactly when this skips or raises).
test.support.requires("network")

# Payload every echo test sends; it ends in a newline because the server
# handlers use readline() and the clients read until they see b'\n'.
TEST_STR = b"hello world\n"
# Host the test servers bind to and clients connect to, as chosen by
# test.support.
HOST = test.support.HOST

# Platform capability flags, plus matching skip decorators for tests that
# need AF_UNIX sockets or os.fork().
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
30
def signal_alarm(n):
    """Forward *n* to signal.alarm() where the platform provides it.

    Platforms lacking signal.alarm (i.e. Windows) make this a no-op, so
    callers can arm and disarm the watchdog unconditionally.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is None:
        return
    alarm(n)
35
# Capture the genuine select() at import time so that tests which mock
# select.select cannot interfere with receive().
_real_select = select.select

def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for *sock* to become readable and
    raises RuntimeError if it does not.
    """
    readable, _writable, _exceptional = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
45
# Forking variants of the Unix-domain servers, composed here from
# socketserver.ForkingMixIn.  They are only defined when the platform has
# both AF_UNIX sockets and os.fork(); the tests that use them carry the
# requires_unix_sockets and requires_forking decorators.
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        pass

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        pass
54
55
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)

    Forks a child that exits immediately with status 72, runs the body of
    the ``with`` statement, then reaps that child itself.  The final
    waitpid() is expected to collect the child with exit status 72, which
    would not be possible had the code under test (e.g. a forking server)
    already reaped this unrelated process.

    *testcase* provides assertEqual() for those checks.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Runs whether or not the body raised; any body exception keeps
        # propagating after the checks (a bare ``except: raise`` adds nothing).
        pid2, status = os.waitpid(pid, 0)
        testcase.assertEqual(pid2, pid)
        # os.waitpid() reports a normal exit code in the high byte.
        testcase.assertEqual(72 << 8, status)
71
72
class SocketServerTest(unittest.TestCase):
    """Test all socket servers.

    Most test_* methods start one server class in a background thread, run
    a client check against it three times and then shut the server down
    (see run_server()).
    """

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        # AF_UNIX socket paths handed out by pickaddr(); removed in tearDown().
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Best effort: a path may never have been bound, or may be gone.
        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address a server of family *proto* can bind to.

        AF_INET gets (HOST, 0) so the OS chooses a free port.  Any other
        family gets a fresh temporary filesystem path (AF_UNIX), recorded
        in self.test_files so tearDown() can remove it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.  mktemp() only picks a name; the
        # path must not exist yet for bind() to succeed.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create an echo server of class *svrcls* bound to *addr*.

        The request handler (derived from *hdlrbase*) reads one line and
        writes it back.  The server's handle_error() closes the request and
        re-raises, so handler failures are not silently swallowed.  The
        test is skipped if bind() is refused with PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                # Bare raise: re-raise the exception currently being handled.
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a thread and call *testfunc* three times.

        *testfunc* is called as testfunc(address_family, server_address).
        The server is always shut down, joined and closed, even when
        *testfunc* fails, so a failing check does not leave a serving
        thread and an open socket behind.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose: print("server running")
        try:
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
        finally:
            # Stop and close the server even if a client check failed.
            if verbose: print("waiting for server")
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose: print("done")

    def _read_line(self, sock):
        """Receive from *sock* until the data contains a newline.

        Stops early, returning what was read so far, if a receive yields
        no data.
        """
        buf = data = receive(sock, 100)
        while data and b'\n' not in buf:
            data = receive(sock, 100)
            buf += data
        return buf

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            self.assertEqual(self._read_line(s), TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR in a datagram and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Bind the client to a temporary path; presumably needed so
                # the server has an address to send its reply to.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            self.assertEqual(self._read_line(s), TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        # Call shutdown() right after start() to exercise the race where
        # serve_forever() may not have begun looping yet.
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_close_immediately(self):
        # A ThreadingMixIn server that never served a request must still
        # close cleanly.
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), lambda: None)
        server.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        # Leaving the with block must close the server socket.
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
301
302
class ErrorHandlerTest(unittest.TestCase):
    """Check how servers route exceptions raised by a request handler.

    Ordinary exceptions (ValueError here) must reach handle_error(), while
    exiting exceptions such as SystemExit and KeyboardInterrupt must not.
    Each server run appends to a log file that check_result() inspects.
    """

    def tearDown(self):
        # Discard the shared log file so every test starts from scratch.
        test.support.unlink(test.support.TESTFN)

    def check_result(self, handled):
        """Assert the log shows the handler ran, followed by handle_error()
        exactly when *handled* is true."""
        expected = 'Handler called\n'
        if handled:
            expected += 'Error handled\n'
        with open(test.support.TESTFN) as log:
            contents = log.read()
        self.assertEqual(contents, expected)

    # Synchronous server: SystemExit escapes to the caller.
    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    # Threading server.
    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    # Forking server (POSIX only).
    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)
342
343
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that serves exactly one request with BadHandler.

    All the work happens in the constructor: bind, open and close a client
    connection, handle that single request, close the server and finally
    call wait_done().  *exception* is the exception class BadHandler raises.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        try:
            # Connect and immediately disconnect so handle_request() has a
            # connection to accept.  Done inside the try so the listening
            # socket is closed even if connecting fails.
            with socket.create_connection(self.server_address):
                pass
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        # Record that the server's error hook was invoked.
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses whose request finishes asynchronously; the
        synchronous base server has nothing to wait for."""
        pass
362
363
class BadHandler(socketserver.BaseRequestHandler):
    """Request handler that logs that it ran, then always fails.

    The exception class comes from ``self.server.exception``.
    """

    def handle(self):
        log_path = test.support.TESTFN
        with open(log_path, 'a') as log:
            log.write('Handler called\n')
        error_cls = self.server.exception
        raise error_cls('Test error')
369
370
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """BaseErrorTestServer combined with socketserver.ThreadingMixIn.

    wait_done() blocks until shutdown_request() has run for the request.
    """

    def __init__(self, *args, **kwargs):
        # The Event must exist before the base constructor runs, because
        # that constructor serves the request and then calls wait_done().
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def wait_done(self):
        self.done.wait()

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        # Release wait_done() now that the request has been shut down.
        self.done.set()
383
384
# Forking counterpart of ThreadingErrorTestServer, only definable where
# os.fork() exists; the tests using it carry @requires_forking.  It keeps
# the base no-op wait_done(): per bpo-31151, ForkingMixIn.server_close()
# already waits for its children.
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        pass
388
389
class SocketWriterTest(unittest.TestCase):
    """Tests for the wfile stream of socketserver.StreamRequestHandler."""

    def test_basics(self):
        # wfile should be a binary buffered stream over the request socket.
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Stash the stream and both file descriptors on the server
                # so the assertions can run after handle_request() returns.
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        s = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        # Connect and immediately close; handle_request() then accepts that
        # connection and runs the handler once.
        with s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        # NOTE(review): get_attribute presumably skips the test when
        # signal.pthread_kill is unavailable — confirm in test.support.
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                # A write of SOCK_MAX_SIZE bytes; the client reads only part
                # of it before signalling, so this send() is expected to be
                # blocked when the signal arrives.
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        # Install a SIGUSR1 handler for the duration of the test; the
        # previous handler is restored on cleanup.
        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        # The server runs in this (main) thread; the client thread sends
        # the signal here to interrupt the server's send().
        main_thread = threading.get_ident()

        def run_client():
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                nonlocal response1
                response1 = reader.readline()
                s.sendall(b'client response\n')

                # Consume only the first 100 bytes of the big chunk.
                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=float(1)):
                        break
                nonlocal received2
                # Drain the remainder until the server closes the connection.
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        # The interrupted write must still have sent every byte.
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
465
466
467class MiscTestCase(unittest.TestCase):
468
469    def test_all(self):
470        # objects defined in the module should be in __all__
471        expected = []
472        for name in dir(socketserver):
473            if not name.startswith('_'):
474                mod_object = getattr(socketserver, name)
475                if getattr(mod_object, '__module__', None) == 'socketserver':
476                    expected.append(name)
477        self.assertCountEqual(socketserver.__all__, expected)
478
479    def test_shutdown_request_called_if_verify_request_false(self):
480        # Issue #26309: BaseServer should call shutdown_request even if
481        # verify_request is False
482
483        class MyServer(socketserver.TCPServer):
484            def verify_request(self, request, client_address):
485                return False
486
487            shutdown_called = 0
488            def shutdown_request(self, request):
489                self.shutdown_called += 1
490                socketserver.TCPServer.shutdown_request(self, request)
491
492        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
493        s = socket.socket(server.address_family, socket.SOCK_STREAM)
494        s.connect(server.server_address)
495        s.close()
496        server.handle_request()
497        self.assertEqual(server.shutdown_called, 1)
498        server.server_close()
499
500    def test_threads_reaped(self):
501        """
502        In #37193, users reported a memory leak
503        due to the saving of every request thread. Ensure that
504        not all threads are kept forever.
505        """
506        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
507            pass
508
509        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
510        for n in range(10):
511            with socket.create_connection(server.server_address):
512                server.handle_request()
513        self.assertLess(len(server._threads), 10)
514        server.server_close()
515
516
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
519