1import asyncio
2import unittest
3
4from test.test_asyncio import functional as func_tests
5
6
class ReceiveStuffProto(asyncio.BufferedProtocol):
    """BufferedProtocol that forwards every received chunk to a callback.

    Args:
        cb: called with a ``bytearray`` holding the bytes of each chunk the
            transport wrote into the buffer.
        con_lost_fut: future resolved when the connection is lost -- with
            ``None`` on a clean close, or with the exception passed to
            ``connection_lost()`` otherwise.
    """

    def __init__(self, cb, con_lost_fut):
        self.cb = cb
        self.con_lost_fut = con_lost_fut

    def get_buffer(self, sizehint):
        # sizehint is ignored: a fresh, small fixed-size buffer is handed out
        # on every call, so a large payload is delivered over many
        # get_buffer()/buffer_updated() rounds.
        self.buffer = bytearray(100)
        return self.buffer

    def buffer_updated(self, nbytes):
        # Only the first nbytes bytes were written by the transport; slicing
        # a bytearray yields a new bytearray containing exactly those bytes.
        self.cb(self.buffer[:nbytes])

    def connection_lost(self, exc):
        # The task awaiting con_lost_fut may have been cancelled (for example
        # by a wait_for() timeout), which cancels the awaited future too.
        # Setting a result then would raise InvalidStateError inside the
        # transport callback. A future that is done but NOT cancelled still
        # raises below, so a duplicate connection_lost() stays detectable.
        if self.con_lost_fut.cancelled():
            return
        if exc is None:
            self.con_lost_fut.set_result(None)
        else:
            self.con_lost_fut.set_exception(exc)
24
25
class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
    """Event-loop-agnostic tests for asyncio.BufferedProtocol.

    Concrete subclasses also inherit unittest.TestCase and override
    new_loop() to select the event loop implementation under test.
    """

    def new_loop(self):
        # Abstract here: each concrete subclass supplies its own loop.
        raise NotImplementedError

    def test_buffered_proto_create_connection(self):

        # 9216-byte payload: far larger than ReceiveStuffProto's 100-byte
        # buffer, so it arrives over many buffer_updated() calls.
        NOISE = b'12345678+' * 1024

        async def client(addr):
            data = b''

            def on_buf(buf):
                nonlocal data
                data += buf
                if data == NOISE:
                    # Whole payload received: send the single byte the
                    # server is waiting for so it closes the connection.
                    tr.write(b'1')

            conn_lost_fut = self.loop.create_future()

            tr, _ = await self.loop.create_connection(
                lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr)

            await conn_lost_fut

        async def on_server_client(reader, writer):
            writer.write(NOISE)
            await reader.readexactly(1)
            writer.close()
            await writer.wait_closed()

        srv = self.loop.run_until_complete(
            asyncio.start_server(
                on_server_client, '127.0.0.1', 0))

        try:
            addr = srv.sockets[0].getsockname()
            # No loop= argument: it was deprecated in Python 3.8 and removed
            # in 3.10. wait_for() runs inside run_until_complete(), so it
            # uses self.loop as the running loop.
            self.loop.run_until_complete(
                asyncio.wait_for(client(addr), 5))
        finally:
            # Release the listening socket even if the client failed or
            # timed out.
            srv.close()
        self.loop.run_until_complete(srv.wait_closed())
67
68
class BufferedProtocolSelectorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the shared buffered-protocol tests on a SelectorEventLoop."""

    def new_loop(self):
        selector_loop = asyncio.SelectorEventLoop()
        return selector_loop
74
75
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class BufferedProtocolProactorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the shared buffered-protocol tests on a ProactorEventLoop.

    Skipped wherever asyncio does not provide ProactorEventLoop.
    """

    def new_loop(self):
        proactor_loop = asyncio.ProactorEventLoop()
        return proactor_loop
82
83
if __name__ == '__main__':
    # Executed directly as a script: hand control to unittest's
    # command-line runner for the TestCase classes in this module.
    unittest.main()
86