# Tests for aiohttp/http_writer.py
import array
from unittest import mock

import pytest

from aiohttp import http
from aiohttp.test_utils import make_mocked_coro


def _written(write: mock.Mock) -> bytes:
    """Concatenate the first positional argument of every call to *write*.

    Tests that replace ``transport.write`` with a bare ``mock.Mock`` use this
    to reassemble, in call order, everything the writer sent to the transport.
    """
    return b"".join(args[0] for _name, args, _kwargs in write.mock_calls)


@pytest.fixture
def buf() -> bytearray:
    """Buffer that collects every chunk written to the ``transport`` fixture."""
    return bytearray()


@pytest.fixture
def transport(buf: bytearray) -> mock.Mock:
    """Mock transport whose ``write`` appends into ``buf`` and never closes."""
    transport = mock.Mock()
    transport.write.side_effect = buf.extend
    transport.is_closing.return_value = False
    return transport


@pytest.fixture
def protocol(loop, transport: mock.Mock) -> mock.Mock:
    """Mock protocol bound to ``transport`` with an awaitable ``_drain_helper``.

    ``loop`` is not used directly; presumably it is requested so the event-loop
    fixture is active for these tests -- confirm before removing it.
    """
    protocol = mock.Mock(transport=transport)
    protocol._drain_helper = make_mocked_coro()
    return protocol


def test_payloadwriter_properties(transport, protocol, loop) -> None:
    """StreamWriter exposes its protocol and that protocol's transport."""
    writer = http.StreamWriter(protocol, loop)
    assert writer.protocol == protocol
    assert writer.transport == transport


async def test_write_payload_eof(transport, protocol, loop) -> None:
    """Without chunking, written data reaches the transport unframed."""
    write = transport.write = mock.Mock()
    msg = http.StreamWriter(protocol, loop)

    await msg.write(b"data1")
    await msg.write(b"data2")
    await msg.write_eof()

    content = _written(write)
    assert b"data1data2" == content.split(b"\r\n\r\n", 1)[-1]


async def test_write_payload_chunked(buf, protocol, transport, loop) -> None:
    """A chunked write is framed and EOF appends the zero-length terminator."""
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()
    await msg.write(b"data")
    await msg.write_eof()

    assert b"4\r\ndata\r\n0\r\n\r\n" == buf


async def test_write_payload_chunked_multiple(buf, protocol, transport, loop) -> None:
    """Each write becomes its own chunk."""
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()
    await msg.write(b"data1")
    await msg.write(b"data2")
    await msg.write_eof()

    assert b"5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n" == buf


async def test_write_payload_length(protocol, transport, loop) -> None:
    """Setting ``length`` truncates the output to that many bytes."""
    write = transport.write = mock.Mock()

    msg = http.StreamWriter(protocol, loop)
    msg.length = 2
    await msg.write(b"d")
    await msg.write(b"ata")
    await msg.write_eof()

    content = _written(write)
    assert b"da" == content.split(b"\r\n\r\n", 1)[-1]


async def test_write_payload_chunked_filter(protocol, transport, loop) -> None:
    """Chunk framing for two small writes ends the transport output."""
    write = transport.write = mock.Mock()

    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()
    await msg.write(b"da")
    await msg.write(b"ta")
    await msg.write_eof()

    content = _written(write)
    assert content.endswith(b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n")


async def test_write_payload_chunked_filter_multiple_chunks(
    protocol, transport, loop
) -> None:
    """Many small writes are each framed as a separate chunk."""
    write = transport.write = mock.Mock()
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()
    await msg.write(b"da")
    await msg.write(b"ta")
    await msg.write(b"1d")
    await msg.write(b"at")
    await msg.write(b"a2")
    await msg.write_eof()

    content = _written(write)
    assert content.endswith(
        b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n" b"2\r\na2\r\n0\r\n\r\n"
    )


async def test_write_payload_deflate_compression(protocol, transport, loop) -> None:
    """deflate compression yields the expected stream and no empty writes."""
    COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b"
    write = transport.write = mock.Mock()
    msg = http.StreamWriter(protocol, loop)
    msg.enable_compression("deflate")
    await msg.write(b"data")
    await msg.write_eof()

    chunks = [args[0] for _name, args, _kwargs in write.mock_calls]
    # No call to transport.write may carry an empty chunk.
    assert all(chunks)
    content = b"".join(chunks)
    assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1]


async def test_write_payload_deflate_and_chunked(
    buf, protocol, transport, loop
) -> None:
    """Compressed output is what gets framed into chunks."""
    msg = http.StreamWriter(protocol, loop)
    msg.enable_compression("deflate")
    msg.enable_chunking()

    await msg.write(b"da")
    await msg.write(b"ta")
    await msg.write_eof()

    expected = b"2\r\nx\x9c\r\n" b"a\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n" b"0\r\n\r\n"
    assert expected == buf


async def test_write_payload_bytes_memoryview(buf, protocol, transport, loop) -> None:
    """A one-dimensional byte memoryview is written verbatim."""
    msg = http.StreamWriter(protocol, loop)

    mv = memoryview(b"abcd")

    await msg.write(mv)
    await msg.write_eof()

    expected = b"abcd"
    assert expected == buf


async def test_write_payload_short_ints_memoryview(
    buf, protocol, transport, loop
) -> None:
    """A memoryview over multi-byte items is sent as its raw bytes."""
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()

    ints = array.array("H", [65, 66, 67])
    payload = memoryview(ints)

    await msg.write(payload)
    await msg.write_eof()

    # Expect the view's underlying bytes unchanged, i.e. in native byte order
    # (what array.tobytes() returns). Accepting either endianness regardless
    # of platform would let a byte-swapping writer pass unnoticed.
    expected = b"6\r\n" + ints.tobytes() + b"\r\n" b"0\r\n\r\n"
    assert expected == buf


async def test_write_payload_2d_shape_memoryview(
    buf, protocol, transport, loop
) -> None:
    """A multi-dimensional memoryview is flattened to its bytes."""
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()

    mv = memoryview(b"ABCDEF")
    payload = mv.cast("c", [3, 2])

    await msg.write(payload)
    await msg.write_eof()

    expected = b"6\r\n" b"ABCDEF\r\n" b"0\r\n\r\n"
    assert expected == buf


async def test_write_payload_slicing_long_memoryview(
    buf, protocol, transport, loop
) -> None:
    """``length`` truncation also applies to multi-dimensional memoryviews."""
    msg = http.StreamWriter(protocol, loop)
    msg.length = 4

    mv = memoryview(b"ABCDEF")
    payload = mv.cast("c", [3, 2])

    await msg.write(payload)
    await msg.write_eof()

    expected = b"ABCD"
    assert expected == buf


async def test_write_drain(protocol, transport, loop) -> None:
    """``drain=False`` skips draining; ``drain=True`` drains and resets the buffer."""
    msg = http.StreamWriter(protocol, loop)
    msg.drain = make_mocked_coro()
    await msg.write(b"1" * (64 * 1024 * 2), drain=False)
    assert not msg.drain.called

    await msg.write(b"1", drain=True)
    assert msg.drain.called
    assert msg.buffer_size == 0


async def test_write_calls_callback(protocol, transport, loop) -> None:
    """``on_chunk_sent`` is called with each written chunk."""
    on_chunk_sent = make_mocked_coro()
    msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent)
    chunk = b"1"
    await msg.write(chunk)
    assert on_chunk_sent.called
    assert on_chunk_sent.call_args == mock.call(chunk)


async def test_write_eof_calls_callback(protocol, transport, loop) -> None:
    """``write_eof(chunk=...)`` also reports the chunk to ``on_chunk_sent``."""
    on_chunk_sent = make_mocked_coro()
    msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent)
    chunk = b"1"
    await msg.write_eof(chunk=chunk)
    assert on_chunk_sent.called
    assert on_chunk_sent.call_args == mock.call(chunk)


async def test_write_to_closing_transport(protocol, transport, loop) -> None:
    """Writing once the transport is closing raises ConnectionResetError."""
    msg = http.StreamWriter(protocol, loop)

    await msg.write(b"Before closing")
    transport.is_closing.return_value = True

    with pytest.raises(ConnectionResetError):
        await msg.write(b"After closing")


async def test_drain(protocol, transport, loop) -> None:
    """``drain()`` delegates to the protocol's ``_drain_helper``."""
    msg = http.StreamWriter(protocol, loop)
    await msg.drain()
    assert protocol._drain_helper.called


async def test_drain_no_transport(protocol, transport, loop) -> None:
    """``drain()`` does not call ``_drain_helper`` when there is no transport."""
    msg = http.StreamWriter(protocol, loop)
    msg._protocol.transport = None
    await msg.drain()
    assert not protocol._drain_helper.called