1"""Tests for transports.py."""
2
3import unittest
4from unittest import mock
5
6import asyncio
7from asyncio import transports
8
9
10class TransportTests(unittest.TestCase):
11
12    def test_ctor_extra_is_none(self):
13        transport = asyncio.Transport()
14        self.assertEqual(transport._extra, {})
15
16    def test_get_extra_info(self):
17        transport = asyncio.Transport({'extra': 'info'})
18        self.assertEqual('info', transport.get_extra_info('extra'))
19        self.assertIsNone(transport.get_extra_info('unknown'))
20
21        default = object()
22        self.assertIs(default, transport.get_extra_info('unknown', default))
23
24    def test_writelines(self):
25        writer = mock.Mock()
26
27        class MyTransport(asyncio.Transport):
28            def write(self, data):
29                writer(data)
30
31        transport = MyTransport()
32
33        transport.writelines([b'line1',
34                              bytearray(b'line2'),
35                              memoryview(b'line3')])
36        self.assertEqual(1, writer.call_count)
37        writer.assert_called_with(b'line1line2line3')
38
39    def test_not_implemented(self):
40        transport = asyncio.Transport()
41
42        self.assertRaises(NotImplementedError,
43                          transport.set_write_buffer_limits)
44        self.assertRaises(NotImplementedError, transport.get_write_buffer_size)
45        self.assertRaises(NotImplementedError, transport.write, 'data')
46        self.assertRaises(NotImplementedError, transport.write_eof)
47        self.assertRaises(NotImplementedError, transport.can_write_eof)
48        self.assertRaises(NotImplementedError, transport.pause_reading)
49        self.assertRaises(NotImplementedError, transport.resume_reading)
50        self.assertRaises(NotImplementedError, transport.is_reading)
51        self.assertRaises(NotImplementedError, transport.close)
52        self.assertRaises(NotImplementedError, transport.abort)
53
54    def test_dgram_not_implemented(self):
55        transport = asyncio.DatagramTransport()
56
57        self.assertRaises(NotImplementedError, transport.sendto, 'data')
58        self.assertRaises(NotImplementedError, transport.abort)
59
60    def test_subprocess_transport_not_implemented(self):
61        transport = asyncio.SubprocessTransport()
62
63        self.assertRaises(NotImplementedError, transport.get_pid)
64        self.assertRaises(NotImplementedError, transport.get_returncode)
65        self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1)
66        self.assertRaises(NotImplementedError, transport.send_signal, 1)
67        self.assertRaises(NotImplementedError, transport.terminate)
68        self.assertRaises(NotImplementedError, transport.kill)
69
70    def test_flowcontrol_mixin_set_write_limits(self):
71
72        class MyTransport(transports._FlowControlMixin,
73                          transports.Transport):
74
75            def get_write_buffer_size(self):
76                return 512
77
78        loop = mock.Mock()
79        transport = MyTransport(loop=loop)
80        transport._protocol = mock.Mock()
81
82        self.assertFalse(transport._protocol_paused)
83
84        with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
85            transport.set_write_buffer_limits(high=0, low=1)
86
87        transport.set_write_buffer_limits(high=1024, low=128)
88        self.assertFalse(transport._protocol_paused)
89        self.assertEqual(transport.get_write_buffer_limits(), (128, 1024))
90
91        transport.set_write_buffer_limits(high=256, low=128)
92        self.assertTrue(transport._protocol_paused)
93        self.assertEqual(transport.get_write_buffer_limits(), (128, 256))
94
95
if __name__ == '__main__':
    # Allow running this test module directly as a script; unittest.main()
    # collects and runs the TestCase classes defined in this module.
    unittest.main()
98