1"""
2Test suite for socketserver.
3"""
4
5import contextlib
6import io
7import os
8import select
9import signal
10import socket
11import tempfile
12import threading
13import unittest
14import socketserver
15
16import test.support
17from test.support import reap_children, verbose
18from test.support import os_helper
19from test.support import socket_helper
20from test.support import threading_helper
21
22
# Every test in this module opens real sockets: bail out of the whole module
# unless the "network" test resource is enabled.
test.support.requires("network")

TEST_STR = b"hello world\n"  # payload each client sends and expects echoed
HOST = socket_helper.HOST

# Platform capabilities, and skip decorators for tests that depend on them.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = hasattr(os, "fork")
requires_unix_sockets = unittest.skipUnless(
    HAVE_UNIX_SOCKETS, 'requires Unix sockets')
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
33
def signal_alarm(n):
    """Arm a SIGALRM watchdog for *n* seconds (``n == 0`` cancels it).

    Platforms without signal.alarm() (e.g. Windows) silently get a no-op.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
38
# Capture the genuine select.select() at import time, so tests that later
# mock select.select cannot interfere with receive().
_real_select = select.select

def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Return up to *n* bytes read from *sock*.

    Waits at most *timeout* seconds for the socket to become readable and
    raises RuntimeError if it does not.
    """
    readable, _writable, _exceptional = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
48
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    # socketserver ships no forking Unix-domain servers, so compose them here.
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        """UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        """UnixDatagramServer combined with ForkingMixIn."""
57
58
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Run the with-body while an unrelated child process exists.

    Forks a child that exits immediately with status 72.  On leaving the
    with-block (normally or via an exception) the child is reaped with
    test.support.wait_process(), which also checks its exit status.  This
    tests that a custom child process is not waited on by the forking
    servers (Issue 1540386).

    *testcase* is currently unused; it is kept for call compatibility.
    """
    pid = os.fork()
    if pid == 0:
        # In the child: exit at once.  Don't raise an exception; it would be
        # caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Always reap the child, whatever happened in the with-body; any
        # exception from the body propagates unchanged.
        test.support.wait_process(pid, exitcode=72)
72
73
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        # AF_UNIX socket paths handed out by pickaddr(); removed in tearDown().
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Best-effort removal: a recorded path may never have been bound.
        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return a bindable address for address family *proto*.

        AF_INET gets ``(HOST, 0)`` so the OS picks a free port.  Any other
        family is treated as AF_UNIX and gets a fresh temporary path, which
        is recorded in self.test_files so tearDown() can remove it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.  mktemp() only picks an unused
        # name; binding the server socket creates the file at that path.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a *svrcls* server on *addr* whose handler echoes one line.

        The handler (derived from *hdlrbase*) reads one line from the client
        and writes it back.  The server's handle_error() closes the request
        and re-raises the exception being handled.  Skips the test if
        creating the server raises PermissionError.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @threading_helper.reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve from a background thread and run *testfunc* against it 3 times.

        *testfunc* is called as ``testfunc(address_family, address)``.  The
        server is always shut down, its thread joined and its socket closed,
        even when *testfunc* fails.  Without that, a failing client check
        would leave a thread in serve_forever() and an open socket behind.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose: print("server running")
        try:
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
        finally:
            if verbose: print("waiting for server")
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and expect it echoed back."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            # An unbound AF_UNIX datagram client has no address to reply
            # to, so give it one.
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @threading_helper.reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_close_immediately(self):
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), lambda: None)
        server.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
302
303
class ErrorHandlerTest(unittest.TestCase):
    """Check which request-handler exceptions reach handle_error().

    Ordinary exceptions raised by the handler must be passed to the
    server's handle_error(); exiting exceptions such as SystemExit and
    KeyboardInterrupt must not be."""

    def tearDown(self):
        # Each test appends to the TESTFN log; remove it for the next test.
        os_helper.unlink(os_helper.TESTFN)

    def check_result(self, handled):
        """Assert the log holds one handler call, followed by one
        handle_error() entry exactly when *handled* is true."""
        expected = 'Handler called\n'
        if handled:
            expected += 'Error handled\n'
        with open(os_helper.TESTFN) as log:
            contents = log.read()
        self.assertEqual(contents, expected)

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        with threading_helper.catch_threading_exception() as cm:
            ThreadingErrorTestServer(SystemExit)
            self.check_result(handled=False)
            # Inspect cm while still inside the with block.
            self.assertIs(cm.exc_type, SystemExit)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)
346
347
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that serves exactly one request with BadHandler.

    All the work happens in the constructor.  It binds, opens and closes one
    client connection, handles that single request, closes the listening
    socket and then calls wait_done().  handle_error() appends to
    os_helper.TESTFN so tests can see whether it was invoked.
    """

    def __init__(self, exception):
        # BadHandler raises self.server.exception from its handle().
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        try:
            # Connect and close straight away; handle_request() then
            # accepts the pending connection and runs the handler once.
            with socket.create_connection(self.server_address):
                pass
            self.handle_request()
        finally:
            # Close the listening socket even if connecting or handling
            # fails, so it is never leaked.
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(os_helper.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses that handle requests asynchronously."""
        pass
366
367
368class BadHandler(socketserver.BaseRequestHandler):
369    def handle(self):
370        with open(os_helper.TESTFN, 'a') as log:
371            log.write('Handler called\n')
372        raise self.server.exception('Test error')
373
374
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant of BaseErrorTestServer.

    The constructor blocks in wait_done() until shutdown_request() has run
    for the request.
    """

    def __init__(self, *args, **kwargs):
        # Must exist before BaseErrorTestServer.__init__ serves the request.
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        # Release wait_done(): the request has been shut down.
        self.done.set()

    def wait_done(self):
        self.done.wait()
387
388
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        """BaseErrorTestServer combined with ForkingMixIn."""
392
393
class SocketWriterTest(unittest.TestCase):
    """Tests for the wfile object StreamRequestHandler gives to handlers."""

    def test_basics(self):
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                # Stash what the handler saw so the test can inspect it
                # once handle_request() returns.
                srv = self.server
                srv.wfile = self.wfile
                srv.wfile_fileno = self.wfile.fileno()
                srv.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        # Connect and close at once; the pending connection is enough for
        # handle_request() to run the handler.
        with socket.socket(server.address_family, socket.SOCK_STREAM,
                           socket.IPPROTO_TCP) as client:
            client.connect(server.server_address)
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')
        big_size = test.support.SOCK_MAX_SIZE

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                srv = self.server
                srv.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                srv.received = self.rfile.readline()
                srv.sent2 = self.wfile.write(b'\0' * big_size)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            nonlocal response1, received2
            sock = socket.socket(server.address_family, socket.SOCK_STREAM,
                                 socket.IPPROTO_TCP)
            with sock, sock.makefile('rb') as reader:
                sock.connect(server.server_address)
                response1 = reader.readline()
                sock.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal until the
                # handler has observed one, in case an earlier signal is
                # delivered at an inconvenient moment.
                pthread_kill(main_thread, signal.SIGUSR1)
                while not interrupted.wait(timeout=1.0):
                    pthread_kill(main_thread, signal.SIGUSR1)
                received2 = len(reader.read())

        client_thread = threading.Thread(target=run_client)
        client_thread.start()
        server.handle_request()
        client_thread.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
469
470
class MiscTestCase(unittest.TestCase):
    """Module-level checks and regressions for BaseServer internals."""

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if not name.startswith('_'):
                mod_object = getattr(socketserver, name)
                if getattr(mod_object, '__module__', None) == 'socketserver':
                    expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the listening socket even if the assertion below fails.
        self.addCleanup(server.server_close)
        # The client socket is closed even if connect() fails.
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)

    def test_threads_reaped(self):
        """
        In #37193, users reported a memory leak
        due to the saving of every request thread. Ensure that
        not all threads are kept forever.
        """
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the server (and its sockets) even if the assertion fails.
        self.addCleanup(server.server_close)
        for n in range(10):
            with socket.create_connection(server.server_address):
                server.handle_request()
        self.assertLess(len(server._threads), 10)
519
520
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
523