1# Tests for aiohttp/http_writer.py
2import array
3from unittest import mock
4
5import pytest
6
7from aiohttp import http
8from aiohttp.test_utils import make_mocked_coro
9
10
@pytest.fixture
def buf():
    """Provide an empty byte buffer shared with the ``transport`` fixture.

    The mock transport appends every written chunk here, so tests can assert
    on the exact bytes a writer emitted.
    """
    collected = bytearray()
    return collected
14
15
@pytest.fixture
def transport(buf):
    """Provide an open mock transport that records writes into ``buf``.

    ``transport.write(chunk)`` extends ``buf`` with ``chunk``, and
    ``transport.is_closing()`` reports ``False`` until a test overrides it.
    """
    fake = mock.Mock()

    def _record(data):
        buf.extend(data)

    fake.configure_mock(
        **{
            "write.side_effect": _record,
            "is_closing.return_value": False,
        }
    )
    return fake
26
27
@pytest.fixture
def protocol(loop, transport):
    """Provide a mock protocol exposing ``transport`` and a mocked drain helper.

    ``_drain_helper`` is a ``make_mocked_coro`` mock, so tests can check
    whether it was called. ``loop`` is not used in the body; presumably it is
    requested so the event-loop fixture is active (confirm).
    """
    return mock.Mock(transport=transport, _drain_helper=make_mocked_coro())
33
34
def test_payloadwriter_properties(transport, protocol, loop) -> None:
    """The writer exposes the wrapped protocol and that protocol's transport."""
    writer = http.StreamWriter(protocol, loop)
    assert protocol == writer.protocol
    assert transport == writer.transport
39
40
async def test_write_payload_eof(transport, protocol, loop) -> None:
    """Plain (non-chunked, uncompressed) writes reach the transport verbatim.

    This test never writes headers, so the concatenation of every
    ``transport.write`` argument must equal the payload exactly, and
    ``write_eof`` must not add any bytes of its own.
    """
    write = transport.write = mock.Mock()
    msg = http.StreamWriter(protocol, loop)

    await msg.write(b"data1")
    await msg.write(b"data2")
    await msg.write_eof()

    # call_args_list entries unpack to (args, kwargs); args[0] is the chunk.
    content = b"".join(args[0] for args, _ in write.call_args_list)
    # Compare the whole stream; splitting on a blank line would silently
    # discard any unexpected prefix.
    assert b"data1data2" == content
51
52
async def test_write_payload_chunked(buf, protocol, transport, loop) -> None:
    """A single chunked write yields one framed chunk plus the terminator."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_chunking()

    await writer.write(b"data")
    await writer.write_eof()

    expected = b"4\r\ndata\r\n0\r\n\r\n"
    assert buf == expected
60
61
async def test_write_payload_chunked_multiple(buf, protocol, transport, loop) -> None:
    """Each chunked write is framed as its own chunk, in order."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_chunking()

    for piece in (b"data1", b"data2"):
        await writer.write(piece)
    await writer.write_eof()

    expected = b"5\r\ndata1\r\n" + b"5\r\ndata2\r\n" + b"0\r\n\r\n"
    assert buf == expected
70
71
async def test_write_payload_length(protocol, transport, loop) -> None:
    """A declared ``length`` truncates the body once that many bytes are written.

    This test never writes headers, so the transport must receive exactly the
    first ``length`` payload bytes and nothing else.
    """
    write = transport.write = mock.Mock()

    msg = http.StreamWriter(protocol, loop)
    msg.length = 2
    await msg.write(b"d")
    # Only one more byte fits within the declared length; the rest is dropped.
    await msg.write(b"ata")
    await msg.write_eof()

    # call_args_list entries unpack to (args, kwargs); args[0] is the chunk.
    content = b"".join(args[0] for args, _ in write.call_args_list)
    assert b"da" == content
83
84
async def test_write_payload_chunked_filter(protocol, transport, loop) -> None:
    """Each chunked write becomes its own chunk, followed by the terminator."""
    write = transport.write = mock.Mock()

    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()
    await msg.write(b"da")
    await msg.write(b"ta")
    await msg.write_eof()

    # call_args_list entries unpack to (args, kwargs); args[0] is the chunk.
    content = b"".join(args[0] for args, _ in write.call_args_list)
    # Compare the whole stream: ``endswith`` would also accept stray bytes
    # written before the first chunk.
    assert content == b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n"
96
97
async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport, loop):
    """Several small chunked writes are each framed individually, in order."""
    write = transport.write = mock.Mock()
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()
    await msg.write(b"da")
    await msg.write(b"ta")
    await msg.write(b"1d")
    await msg.write(b"at")
    await msg.write(b"a2")
    await msg.write_eof()

    # call_args_list entries unpack to (args, kwargs); args[0] is the chunk.
    content = b"".join(args[0] for args, _ in write.call_args_list)
    # Compare the whole stream: ``endswith`` would also accept stray bytes
    # written before the first chunk.
    assert content == (
        b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n" b"2\r\na2\r\n0\r\n\r\n"
    )
112
113
async def test_write_payload_deflate_compression(protocol, transport, loop) -> None:
    """Non-chunked deflate output is exactly the zlib-compressed payload.

    This test never writes headers, so the concatenated transport writes
    must equal the compressed stream with nothing before it.
    """
    COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b"
    write = transport.write = mock.Mock()
    msg = http.StreamWriter(protocol, loop)
    msg.enable_compression("deflate")
    await msg.write(b"data")
    await msg.write_eof()

    # call_args_list entries unpack to (args, kwargs); args[0] is the chunk.
    chunks = [args[0] for args, _ in write.call_args_list]
    # No empty writes may reach the transport (e.g. when the compressor
    # buffers its input and produces no output for a call).
    assert all(chunks)
    content = b"".join(chunks)
    assert COMPRESSED == content
127
128
async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop):
    """Compressed output is chunk-framed; each non-empty compressed piece is a chunk."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()

    for piece in (b"da", b"ta"):
        await writer.write(piece)
    await writer.write_eof()

    expected = b"".join(
        (
            b"2\r\nx\x9c\r\n",  # zlib stream header
            b"a\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n",  # deflate data + adler32
            b"0\r\n\r\n",  # terminating zero-length chunk
        )
    )
    assert buf == expected
140
141
async def test_write_payload_bytes_memoryview(buf, protocol, transport, loop):
    """A memoryview over plain bytes is written through unchanged."""
    writer = http.StreamWriter(protocol, loop)

    await writer.write(memoryview(b"abcd"))
    await writer.write_eof()

    assert buf == b"abcd"
153
154
async def test_write_payload_short_ints_memoryview(buf, protocol, transport, loop):
    """A memoryview with multi-byte items is sent as raw bytes, sized in bytes.

    Three unsigned shorts are written as one chunk whose length counts bytes
    (6), not items (3). The bytes on the wire are the array's native
    representation, so the expectation is derived from ``tobytes()`` rather
    than accepting either byte order. Accepting either order would let a
    byte-swapping bug pass.
    """
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()

    values = array.array("H", [65, 66, 67])
    payload = memoryview(values)

    await msg.write(payload)
    await msg.write_eof()

    expected = b"6\r\n" + values.tobytes() + b"\r\n0\r\n\r\n"
    assert buf == expected
169
170
async def test_write_payload_2d_shape_memoryview(buf, protocol, transport, loop):
    """A 2-D memoryview is framed by its total byte size, not its first dimension."""
    writer = http.StreamWriter(protocol, loop)
    writer.enable_chunking()

    grid = memoryview(b"ABCDEF").cast("c", [3, 2])

    await writer.write(grid)
    await writer.write_eof()

    assert buf == b"6\r\n" + b"ABCDEF\r\n" + b"0\r\n\r\n"
183
184
async def test_write_payload_slicing_long_memoryview(buf, protocol, transport, loop):
    """Truncating a 2-D memoryview to ``length`` cuts at byte, not row, granularity."""
    writer = http.StreamWriter(protocol, loop)
    writer.length = 4

    grid = memoryview(b"ABCDEF").cast("c", [3, 2])

    await writer.write(grid)
    await writer.write_eof()

    assert buf == b"ABCD"
197
198
async def test_write_drain(protocol, transport, loop) -> None:
    """``drain=False`` suppresses draining; a later ``drain=True`` write drains.

    128 KiB is buffered without draining. The next write with ``drain=True``
    must call ``drain()`` and reset ``buffer_size`` to zero.
    """
    writer = http.StreamWriter(protocol, loop)
    writer.drain = make_mocked_coro()
    big_chunk = b"1" * (2 * 64 * 1024)

    await writer.write(big_chunk, drain=False)
    assert not writer.drain.called

    await writer.write(b"1", drain=True)
    assert writer.drain.called
    assert writer.buffer_size == 0
208
209
async def test_write_calls_callback(protocol, transport, loop) -> None:
    """``on_chunk_sent`` is invoked with the chunk passed to ``write``."""
    callback = make_mocked_coro()
    writer = http.StreamWriter(protocol, loop, on_chunk_sent=callback)
    payload = b"1"

    await writer.write(payload)

    assert callback.called
    assert mock.call(payload) == callback.call_args
217
218
async def test_write_eof_calls_callback(protocol, transport, loop) -> None:
    """``on_chunk_sent`` is invoked with the final chunk passed to ``write_eof``."""
    callback = make_mocked_coro()
    writer = http.StreamWriter(protocol, loop, on_chunk_sent=callback)
    payload = b"1"

    await writer.write_eof(chunk=payload)

    assert callback.called
    assert mock.call(payload) == callback.call_args
226
227
async def test_write_to_closing_transport(protocol, transport, loop) -> None:
    """Writing once the transport reports it is closing raises ConnectionResetError."""
    writer = http.StreamWriter(protocol, loop)

    await writer.write(b"Before closing")
    # Flip the mock transport into the "closing" state between writes.
    transport.is_closing.return_value = True

    with pytest.raises(ConnectionResetError):
        await writer.write(b"After closing")
236
237
async def test_drain(protocol, transport, loop) -> None:
    """``drain()`` delegates to the protocol's ``_drain_helper``."""
    writer = http.StreamWriter(protocol, loop)

    await writer.drain()

    assert protocol._drain_helper.called
242
243
async def test_drain_no_transport(protocol, transport, loop) -> None:
    """``drain()`` skips the drain helper when the protocol has no transport."""
    writer = http.StreamWriter(protocol, loop)
    # Detach the transport from the protocol the writer wraps.
    writer._protocol.transport = None

    await writer.drain()

    assert not protocol._drain_helper.called
249