1import asyncio
2import time
3import threading
4import unittest
5
6from test.support import socket_helper
7from test.test_asyncio import utils as test_utils
8from test.test_asyncio import functional as func_tests
9
10
def tearDownModule():
    """Restore the default asyncio event loop policy after this module's tests."""
    # Passing None to set_event_loop_policy() reinstates the default policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
13
14
class BaseStartServer(func_tests.FunctionalTestCaseMixin):
    """Event-loop-agnostic tests for ``asyncio.start_server()``.

    Concrete subclasses combine this mixin with ``unittest.TestCase`` and
    override :meth:`new_loop` to select the event loop under test.
    """

    def new_loop(self):
        """Return the event loop to run the tests on; subclasses override."""
        raise NotImplementedError

    def test_start_server_1(self):
        """Serve one TCP client via serve_forever(), then verify teardown.

        The server is created with ``start_serving=False`` and starts
        accepting only once ``serve_forever()`` runs inside ``main``.  The
        connection handler cancels ``main_task``, so running it must raise
        ``CancelledError`` and leave the server closed with no sockets.
        """
        HELLO_MSG = b'1' * 1024 * 5 + b'\n'

        def client(sock, addr):
            # Invoked through self.tcp_client (presumably on a helper thread,
            # since it blocks): poll for up to ~2s until serve_forever() has
            # started accepting connections.
            for _ in range(10):
                time.sleep(0.2)
                if srv.is_serving():
                    break
            else:
                raise RuntimeError('server did not start serving in time')

            sock.settimeout(2)
            sock.connect(addr)
            # sendall() rather than send(): a short write would leave the
            # server's readline() waiting for a newline that never arrives.
            sock.sendall(HELLO_MSG)
            sock.recv_all(1)
            sock.close()

        async def serve(reader, writer):
            # Read the client's line, cancel main_task so serve_forever()
            # exits, then reply with a single byte and close the connection.
            await reader.readline()
            main_task.cancel()
            writer.write(b'1')
            writer.close()
            await writer.wait_closed()

        async def main(srv):
            # Leaving `async with` closes the server once serve_forever()
            # is cancelled.
            async with srv:
                await srv.serve_forever()

        srv = self.loop.run_until_complete(asyncio.start_server(
            serve, socket_helper.HOSTv4, 0, start_serving=False))

        self.assertFalse(srv.is_serving())

        main_task = self.loop.create_task(main(srv))

        addr = srv.sockets[0].getsockname()
        with self.assertRaises(asyncio.CancelledError):
            with self.tcp_client(lambda sock: client(sock, addr)):
                self.loop.run_until_complete(main_task)

        # After the server is closed, no listening sockets or waiters remain.
        self.assertEqual(srv.sockets, ())

        self.assertIsNone(srv._sockets)
        self.assertIsNone(srv._waiters)
        self.assertFalse(srv.is_serving())

        # A closed server cannot be restarted.
        with self.assertRaisesRegex(RuntimeError, r'is closed'):
            self.loop.run_until_complete(srv.serve_forever())
68
69
class SelectorStartServerTests(BaseStartServer, unittest.TestCase):
    """Run the shared start_server tests on a SelectorEventLoop.

    Also covers ``asyncio.start_unix_server()`` where Unix sockets can be bound.
    """

    def new_loop(self):
        return asyncio.SelectorEventLoop()

    @socket_helper.skip_unless_bind_unix_socket
    def test_start_unix_server_1(self):
        """Serve one Unix-socket client after an explicit start_serving()."""
        HELLO_MSG = b'1' * 1024 * 5 + b'\n'
        # Set by main() once start_serving() has returned, so the client
        # knows it may connect.
        started = threading.Event()

        def client(sock, addr):
            sock.settimeout(2)
            # Fail fast with a clear error rather than connecting before
            # the server has started serving.
            if not started.wait(5):
                raise RuntimeError('server did not start serving in time')
            sock.connect(addr)
            # sendall() rather than send(): a short write would leave the
            # server's readline() waiting for a newline that never arrives.
            sock.sendall(HELLO_MSG)
            sock.recv_all(1)
            sock.close()

        async def serve(reader, writer):
            # Read the client's line, cancel main_task so serve_forever()
            # exits, then reply with a single byte and close the connection.
            await reader.readline()
            main_task.cancel()
            writer.write(b'1')
            writer.close()
            await writer.wait_closed()

        async def main(srv):
            # Leaving `async with` closes the server once serve_forever()
            # is cancelled.
            async with srv:
                self.assertFalse(srv.is_serving())
                await srv.start_serving()
                self.assertTrue(srv.is_serving())
                started.set()
                await srv.serve_forever()

        with test_utils.unix_socket_path() as addr:
            srv = self.loop.run_until_complete(asyncio.start_unix_server(
                serve, addr, start_serving=False))

            main_task = self.loop.create_task(main(srv))

            with self.assertRaises(asyncio.CancelledError):
                with self.unix_client(lambda sock: client(sock, addr)):
                    self.loop.run_until_complete(main_task)

            # After the server is closed, no sockets or waiters remain.
            self.assertEqual(srv.sockets, ())

            self.assertIsNone(srv._sockets)
            self.assertIsNone(srv._waiters)
            self.assertFalse(srv.is_serving())

            # A closed server cannot be restarted.
            with self.assertRaisesRegex(RuntimeError, r'is closed'):
                self.loop.run_until_complete(srv.serve_forever())
121
122
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartServerTests(BaseStartServer, unittest.TestCase):
    """Run the shared start_server tests on a ProactorEventLoop.

    Skipped wherever asyncio does not provide ProactorEventLoop.
    """

    def new_loop(self):
        loop_cls = asyncio.ProactorEventLoop
        return loop_cls()
128
129
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
132