1import asyncio
2import unittest
3
4from test.test_asyncio import functional as func_tests
5
6
def tearDownModule():
    """Reset the global asyncio event loop policy after this module runs.

    Passing None reinstalls asyncio's default policy, so no policy
    customization remains in effect for whatever runs afterwards.
    """
    asyncio.set_event_loop_policy(policy=None)
9
10
class ReceiveStuffProto(asyncio.BufferedProtocol):
    """BufferedProtocol that forwards every received chunk to a callback.

    Args:
        cb: called with a bytearray holding the bytes of each chunk the
            transport wrote into the receive buffer.
        con_lost_fut: future resolved when the connection is lost: with
            None on a clean close, or with the exception otherwise.
    """

    def __init__(self, cb, con_lost_fut):
        self.cb = cb
        self.con_lost_fut = con_lost_fut
        # Replaced by get_buffer() before each read.
        self.buffer = None

    def get_buffer(self, sizehint):
        # sizehint is ignored: a fresh 100-byte buffer is handed out each
        # time, so each read delivers at most 100 bytes.
        self.buffer = bytearray(100)
        return self.buffer

    def buffer_updated(self, nbytes):
        # Pass on only the prefix the transport actually filled.
        self.cb(self.buffer[:nbytes])

    def connection_lost(self, exc):
        # The future may already be done, e.g. cancelled because the task
        # awaiting it timed out. Resolving it again would raise
        # InvalidStateError inside an event-loop callback and obscure the
        # original failure.
        if self.con_lost_fut.done():
            return
        if exc is None:
            self.con_lost_fut.set_result(None)
        else:
            self.con_lost_fut.set_exception(exc)
28
29
class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
    """Buffered-protocol tests shared by every concrete event loop type.

    Subclasses must override new_loop() to return the loop under test.
    """

    def new_loop(self):
        # Abstract hook: concrete subclasses supply the loop implementation.
        raise NotImplementedError

    def test_buffered_proto_create_connection(self):
        """Receive a multi-chunk payload through a BufferedProtocol client.

        The server sends NOISE and then waits for a single acknowledgement
        byte before closing. The client writes that byte only once it has
        reassembled the complete payload. The test therefore finishes within
        the timeout only if every byte arrived via buffer_updated().
        """
        NOISE = b'12345678+' * 1024

        async def client(addr):
            data = b''

            def on_buf(buf):
                nonlocal data
                data += buf
                # Acknowledge once the whole payload has been received.
                if data == NOISE:
                    tr.write(b'1')

            conn_lost_fut = self.loop.create_future()

            tr, _ = await self.loop.create_connection(
                lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr)

            # Resolved by ReceiveStuffProto.connection_lost().
            await conn_lost_fut

        async def on_server_client(reader, writer):
            writer.write(NOISE)
            await reader.readexactly(1)
            writer.close()
            await writer.wait_closed()

        srv = self.loop.run_until_complete(
            asyncio.start_server(
                on_server_client, '127.0.0.1', 0))

        try:
            addr = srv.sockets[0].getsockname()
            self.loop.run_until_complete(
                asyncio.wait_for(client(addr), 5))
        finally:
            # Stop listening even if the client failed or timed out, so the
            # listening socket is not leaked on the failure path.
            srv.close()

        # Only reached on success. After a failure a connection may still be
        # open, and waiting for the server to close could then block.
        self.loop.run_until_complete(srv.wait_closed())
71
72
class BufferedProtocolSelectorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the shared buffered-protocol tests on a SelectorEventLoop."""

    def new_loop(self):
        selector_loop = asyncio.SelectorEventLoop()
        return selector_loop
78
79
@unittest.skipIf(not hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class BufferedProtocolProactorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the shared buffered-protocol tests on a ProactorEventLoop.

    Skipped wherever asyncio does not provide ProactorEventLoop.
    """

    def new_loop(self):
        proactor_loop = asyncio.ProactorEventLoop()
        return proactor_loop
86
87
if __name__ == '__main__':
    # Executed directly as a script: discover and run this module's tests.
    unittest.main()
90