"""Tests for :mod:`anyio.streams.tls` (``TLSStream`` and ``TLSListener``)."""
import socket
import ssl
from contextlib import ExitStack
from threading import Thread
from typing import NoReturn

import pytest
from trustme import CA

from anyio import (
    BrokenResourceError, EndOfStream, Event, connect_tcp, create_task_group, create_tcp_listener)
from anyio.abc import AnyByteStream, SocketAttribute, SocketStream
from anyio.streams.tls import TLSAttribute, TLSListener, TLSStream

# Apply the anyio marker to every (async) test in this module.
pytestmark = pytest.mark.anyio


class TestTLSStream:
    """Tests for :class:`TLSStream`.

    Each test serves a blocking, ``ssl``-wrapped server socket from a background thread and
    talks to it through an anyio TCP stream wrapped in :class:`TLSStream`. The
    ``server_context``, ``client_context`` and ``ca`` fixtures are not defined in this module
    (presumably they come from a conftest -- confirm there).
    """

    async def test_send_receive(self, server_context: ssl.SSLContext,
                                client_context: ssl.SSLContext) -> None:
        """Bytes sent over the TLS wrapper reach the server and its reply comes back."""
        def serve_sync() -> None:
            # Echo the received bytes back in reverse order.
            conn, addr = server_sock.accept()
            conn.settimeout(1)
            data = conn.recv(10)
            conn.send(data[::-1])
            conn.close()

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        # Reap the server thread and close the listening socket even if the client side fails.
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                await wrapper.send(b'hello')
                response = await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert response == b'olleh'

    async def test_extra_attributes(self, server_context: ssl.SSLContext,
                                    client_context: ssl.SSLContext) -> None:
        """The wrapper forwards the socket's extra attributes and adds the TLS-specific ones."""
        def serve_sync() -> None:
            # Keep the connection open until the client's single byte arrives.
            conn, addr = server_sock.accept()
            with conn:
                conn.settimeout(1)
                conn.recv(1)

        server_context.set_alpn_protocols(['h2'])
        client_context.set_alpn_protocols(['h2'])

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=True)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context,
                                               standard_compatible=False)
                async with wrapper:
                    # Every public SocketAttribute must be passed through unchanged.
                    for name, attribute in SocketAttribute.__dict__.items():
                        if not name.startswith('_'):
                            assert wrapper.extra(attribute) == stream.extra(attribute)

                    assert wrapper.extra(TLSAttribute.alpn_protocol) == 'h2'
                    assert isinstance(wrapper.extra(TLSAttribute.channel_binding_tls_unique),
                                      bytes)
                    assert isinstance(wrapper.extra(TLSAttribute.cipher), tuple)
                    assert isinstance(wrapper.extra(TLSAttribute.peer_certificate), dict)
                    assert isinstance(wrapper.extra(TLSAttribute.peer_certificate_binary),
                                      bytes)
                    assert wrapper.extra(TLSAttribute.server_side) is False
                    assert isinstance(wrapper.extra(TLSAttribute.shared_ciphers), list)
                    assert isinstance(wrapper.extra(TLSAttribute.ssl_object), ssl.SSLObject)
                    assert wrapper.extra(TLSAttribute.standard_compatible) is False
                    assert wrapper.extra(TLSAttribute.tls_version).startswith('TLSv')
                    await wrapper.send(b'\x00')
        finally:
            server_thread.join()
            server_sock.close()

    async def test_unwrap(self, server_context: ssl.SSLContext,
                          client_context: ssl.SSLContext) -> None:
        """``unwrap()`` ends the TLS session and hands back the raw transport stream.

        The server sends one encrypted message, performs a TLS shutdown and then sends a
        plaintext message over the same connection.
        """
        def serve_sync() -> None:
            conn, addr = server_sock.accept()
            conn.settimeout(1)
            conn.send(b'encrypted')
            unencrypted = conn.unwrap()
            unencrypted.send(b'unencrypted')
            unencrypted.close()

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                msg1 = await wrapper.receive()
                unwrapped_stream, msg2 = await wrapper.unwrap()
                # The bytes returned by unwrap() may hold only part (or none) of the plaintext
                # message; read the remainder from the raw stream if needed.
                if msg2 != b'unencrypted':
                    msg2 += await unwrapped_stream.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert msg1 == b'encrypted'
        assert msg2 == b'unencrypted'

    @pytest.mark.skipif(not ssl.HAS_ALPN, reason='ALPN support not available')
    async def test_alpn_negotiation(self, server_context: ssl.SSLContext,
                                    client_context: ssl.SSLContext) -> None:
        """Client and server settle on the only ALPN protocol both offer (``dummy2``)."""
        def serve_sync() -> None:
            # Report the protocol selected on the server side back to the client.
            conn, addr = server_sock.accept()
            conn.settimeout(1)
            selected_alpn_protocol = conn.selected_alpn_protocol()
            assert selected_alpn_protocol is not None
            conn.send(selected_alpn_protocol.encode())
            conn.close()

        server_context.set_alpn_protocols(['dummy1', 'dummy2'])
        client_context.set_alpn_protocols(['dummy2', 'dummy3'])

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                assert wrapper.extra(TLSAttribute.alpn_protocol) == 'dummy2'
                server_alpn_protocol = await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_alpn_protocol == b'dummy2'

    @pytest.mark.parametrize('server_compatible, client_compatible', [
        pytest.param(True, True, id='both_standard'),
        pytest.param(True, False, id='server_standard'),
        pytest.param(False, True, id='client_standard'),
        pytest.param(False, False, id='neither_standard')
    ])
    async def test_ragged_eofs(self, server_context: ssl.SSLContext,
                               client_context: ssl.SSLContext, server_compatible: bool,
                               client_compatible: bool) -> None:
        """Connection shutdown for each combination of standard-compatible peers.

        The server sends ``hello`` and, when standard compatible, calls ``unwrap()`` before
        closing; otherwise it just closes the socket. Expectations:

        * standard-compatible client, non-compatible server: the client side raises
          :exc:`BrokenResourceError`;
        * non-compatible client, standard-compatible server: the server side fails with an
          ``OSError`` that is not a timeout;
        * any other combination: the server side completes without an exception.
        """
        server_exc = None

        def serve_sync() -> None:
            nonlocal server_exc
            conn, addr = server_sock.accept()
            try:
                conn.settimeout(1)
                conn.sendall(b'hello')
                if server_compatible:
                    conn.unwrap()
            except BaseException as exc:
                # Captured here because an exception raised in a thread would not fail the test.
                server_exc = exc
            finally:
                conn.close()

        # ExitStack() serves as a no-op context manager; contextlib.nullcontext() would need
        # Python 3.7+, and this file still handles Python 3.6.
        client_cm = ExitStack()
        if client_compatible and not server_compatible:
            client_cm = pytest.raises(BrokenResourceError)

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=not server_compatible)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        stream = await connect_tcp(*server_sock.getsockname())
        wrapper = await TLSStream.wrap(stream, hostname='localhost', ssl_context=client_context,
                                       standard_compatible=client_compatible)
        with client_cm:
            assert await wrapper.receive() == b'hello'
            await wrapper.aclose()

        server_thread.join()
        server_sock.close()
        if not client_compatible and server_compatible:
            assert isinstance(server_exc, OSError)
            assert not isinstance(server_exc, socket.timeout)
        else:
            assert server_exc is None

    async def test_ragged_eof_on_receive(self, server_context: ssl.SSLContext,
                                         client_context: ssl.SSLContext) -> None:
        """A non-standard-compatible client reports an abrupt server close as EndOfStream."""
        server_exc = None

        def serve_sync() -> None:
            nonlocal server_exc
            conn, addr = server_sock.accept()
            try:
                conn.settimeout(1)
                conn.sendall(b'hello')
            except BaseException as exc:
                server_exc = exc
            finally:
                # Closed without a TLS shutdown (no unwrap()).
                conn.close()

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=True)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context,
                                               standard_compatible=False)
                assert await wrapper.receive() == b'hello'
                with pytest.raises(EndOfStream):
                    await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_exc is None

    async def test_receive_send_after_eof(self, server_context: ssl.SSLContext,
                                          client_context: ssl.SSLContext) -> None:
        """After the server's TLS shutdown, a further ``receive()`` raises EndOfStream."""
        def serve_sync() -> None:
            conn, addr = server_sock.accept()
            # Bound server-side I/O, as the other tests do, so a stalled client cannot leave
            # this thread blocked indefinitely.
            conn.settimeout(1)
            conn.sendall(b'hello')
            conn.unwrap()
            conn.close()

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        stream = await connect_tcp(*server_sock.getsockname())
        async with await TLSStream.wrap(stream, hostname='localhost',
                                        ssl_context=client_context) as wrapper:
            assert await wrapper.receive() == b'hello'
            with pytest.raises(EndOfStream):
                await wrapper.receive()

        server_thread.join()
        server_sock.close()

    @pytest.mark.parametrize('force_tlsv12', [
        pytest.param(False, marks=[pytest.mark.skipif(not getattr(ssl, 'HAS_TLSv1_3', False),
                                                      reason='No TLS 1.3 support')]),
        pytest.param(True)
    ], ids=['tlsv13', 'tlsv12'])
    async def test_send_eof_not_implemented(self, server_context: ssl.SSLContext,
                                            ca: CA, force_tlsv12: bool) -> None:
        """``send_eof()`` raises NotImplementedError with a TLS-version-dependent message.

        With the client capped at TLS 1.2 the message says TLS 1.3 is required; otherwise it
        says ``send_eof()`` has not been implemented for TLS streams.
        """
        def serve_sync() -> None:
            conn, addr = server_sock.accept()
            # Bound server-side I/O, as the other tests do, so a stalled client cannot leave
            # this thread blocked indefinitely.
            conn.settimeout(1)
            conn.sendall(b'hello')
            conn.unwrap()
            conn.close()

        # A locally built client context (not the client_context fixture) trusting the test
        # CA, optionally capped at TLS 1.2.
        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        ca.configure_trust(client_context)
        if force_tlsv12:
            expected_pattern = r'send_eof\(\) requires at least TLSv1.3'
            if hasattr(ssl, 'TLSVersion'):
                client_context.maximum_version = ssl.TLSVersion.TLSv1_2
            else:  # Python 3.6
                client_context.options |= ssl.OP_NO_TLSv1_3
        else:
            expected_pattern = r'send_eof\(\) has not yet been implemented for TLS streams'

        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(('127.0.0.1', 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        stream = await connect_tcp(*server_sock.getsockname())
        async with await TLSStream.wrap(stream, hostname='localhost',
                                        ssl_context=client_context) as wrapper:
            assert await wrapper.receive() == b'hello'
            with pytest.raises(NotImplementedError) as exc:
                await wrapper.send_eof()

            exc.match(expected_pattern)

        server_thread.join()
        server_sock.close()


class TestTLSListener:
    """Tests for :class:`TLSListener`."""

    async def test_handshake_fail(self, server_context: ssl.SSLContext) -> None:
        """A client that disconnects without a TLS handshake is routed to the error hook.

        The handler given to ``serve()`` must never run; instead ``handle_handshake_error()``
        receives a :exc:`BrokenResourceError` together with the underlying socket stream.
        """
        def handler(stream: object) -> NoReturn:  # type: ignore[misc]
            pytest.fail('This function should never be called in this scenario')

        exception = None

        class CustomTLSListener(TLSListener):
            @staticmethod
            async def handle_handshake_error(exc: BaseException,
                                             stream: AnyByteStream) -> None:
                # Run the default handling, then record the exception and wake the test.
                nonlocal exception
                await TLSListener.handle_handshake_error(exc, stream)
                assert isinstance(stream, SocketStream)
                exception = exc
                event.set()

        event = Event()
        listener = await create_tcp_listener(local_host='127.0.0.1')
        tls_listener = CustomTLSListener(listener, server_context)
        async with tls_listener, create_task_group() as tg:
            tg.start_soon(tls_listener.serve, handler)
            # Connect and disconnect without ever starting a TLS handshake; the context
            # manager also closes the socket if connect() raises.
            with socket.socket() as sock:
                sock.connect(listener.extra(SocketAttribute.local_address))

            await event.wait()
            tg.cancel_scope.cancel()

        assert isinstance(exception, BrokenResourceError)