1import socket
2import ssl
3from contextlib import ExitStack
4from threading import Thread
5from typing import NoReturn
6
7import pytest
8from trustme import CA
9
10from anyio import (
11    BrokenResourceError, EndOfStream, Event, connect_tcp, create_task_group, create_tcp_listener)
12from anyio.abc import AnyByteStream, SocketAttribute, SocketStream
13from anyio.streams.tls import TLSAttribute, TLSListener, TLSStream
14
# Module-level mark: every test in this file is marked with the anyio pytest marker so the
# async test functions below are run by the anyio plugin.
pytestmark = pytest.mark.anyio
16
17
class TestTLSStream:
    """Tests for :class:`TLSStream` against a blocking :mod:`ssl` server running in a thread."""

    @staticmethod
    def _create_server_socket(server_context: ssl.SSLContext, *,
                              suppress_ragged_eofs: bool) -> ssl.SSLSocket:
        """
        Create a TLS server socket listening on an ephemeral loopback port.

        The socket gets a one second timeout so that ``accept()`` in the server thread gives up
        instead of blocking forever if the client side of the test never connects.

        :param server_context: the SSL context used to wrap the listening socket
        :param suppress_ragged_eofs: passed through to :meth:`ssl.SSLContext.wrap_socket`
        :return: the bound, listening socket (the caller is responsible for closing it)

        """
        server_sock = server_context.wrap_socket(socket.socket(), server_side=True,
                                                 suppress_ragged_eofs=suppress_ragged_eofs)
        try:
            server_sock.settimeout(1)
            server_sock.bind(('127.0.0.1', 0))
            server_sock.listen()
        except BaseException:
            server_sock.close()
            raise

        return server_sock

    async def test_send_receive(self, server_context: ssl.SSLContext,
                                client_context: ssl.SSLContext) -> None:
        """Data sent through the TLS stream is echoed back reversed by the server."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            with conn:
                conn.settimeout(1)
                data = conn.recv(10)
                conn.send(data[::-1])

        server_sock = self._create_server_socket(server_context, suppress_ragged_eofs=False)
        server_thread = Thread(target=serve_sync)
        server_thread.start()
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                await wrapper.send(b'hello')
                response = await wrapper.receive()
        finally:
            # Always reap the server thread and the listening socket, even if the client failed
            server_thread.join()
            server_sock.close()

        assert response == b'olleh'

    @pytest.mark.skipif(not ssl.HAS_ALPN, reason='ALPN support not available')
    async def test_extra_attributes(self, server_context: ssl.SSLContext,
                                    client_context: ssl.SSLContext) -> None:
        """TLSStream exposes its transport's socket attributes plus the TLS specific ones."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            with conn:
                conn.settimeout(1)
                conn.recv(1)

        # Negotiate ALPN so that TLSAttribute.alpn_protocol has a known value to check
        server_context.set_alpn_protocols(['h2'])
        client_context.set_alpn_protocols(['h2'])

        server_sock = self._create_server_socket(server_context, suppress_ragged_eofs=True)
        server_thread = Thread(target=serve_sync)
        server_thread.start()
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context,
                                               standard_compatible=False)
                async with wrapper:
                    # Every public SocketAttribute must match the transport stream's value
                    for name, attribute in SocketAttribute.__dict__.items():
                        if not name.startswith('_'):
                            assert wrapper.extra(attribute) == stream.extra(attribute)

                    assert wrapper.extra(TLSAttribute.alpn_protocol) == 'h2'
                    assert isinstance(
                        wrapper.extra(TLSAttribute.channel_binding_tls_unique), bytes)
                    assert isinstance(wrapper.extra(TLSAttribute.cipher), tuple)
                    assert isinstance(wrapper.extra(TLSAttribute.peer_certificate), dict)
                    assert isinstance(wrapper.extra(TLSAttribute.peer_certificate_binary),
                                      bytes)
                    assert wrapper.extra(TLSAttribute.server_side) is False
                    assert isinstance(wrapper.extra(TLSAttribute.shared_ciphers), list)
                    assert isinstance(wrapper.extra(TLSAttribute.ssl_object), ssl.SSLObject)
                    assert wrapper.extra(TLSAttribute.standard_compatible) is False
                    assert wrapper.extra(TLSAttribute.tls_version).startswith('TLSv')
                    # Unblock the server thread, which is waiting for a single byte
                    await wrapper.send(b'\x00')
        finally:
            server_thread.join()
            server_sock.close()

    async def test_unwrap(self, server_context: ssl.SSLContext,
                          client_context: ssl.SSLContext) -> None:
        """After unwrap(), data the server sends following its TLS shutdown arrives in clear."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.settimeout(1)
            conn.send(b'encrypted')
            unencrypted = conn.unwrap()
            unencrypted.send(b'unencrypted')
            unencrypted.close()

        server_sock = self._create_server_socket(server_context, suppress_ragged_eofs=False)
        server_thread = Thread(target=serve_sync)
        server_thread.start()
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                msg1 = await wrapper.receive()
                unwrapped_stream, msg2 = await wrapper.unwrap()
                # The plaintext payload may or may not have been returned by unwrap() already
                if msg2 != b'unencrypted':
                    msg2 += await unwrapped_stream.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert msg1 == b'encrypted'
        assert msg2 == b'unencrypted'

    @pytest.mark.skipif(not ssl.HAS_ALPN, reason='ALPN support not available')
    async def test_alpn_negotiation(self, server_context: ssl.SSLContext,
                                    client_context: ssl.SSLContext) -> None:
        """Both peers agree on the only ALPN protocol they have in common."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.settimeout(1)
            selected_alpn_protocol = conn.selected_alpn_protocol()
            assert selected_alpn_protocol is not None
            conn.send(selected_alpn_protocol.encode())
            conn.close()

        server_context.set_alpn_protocols(['dummy1', 'dummy2'])
        client_context.set_alpn_protocols(['dummy2', 'dummy3'])

        server_sock = self._create_server_socket(server_context, suppress_ragged_eofs=False)
        server_thread = Thread(target=serve_sync)
        server_thread.start()
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context)
                assert wrapper.extra(TLSAttribute.alpn_protocol) == 'dummy2'
                server_alpn_protocol = await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_alpn_protocol == b'dummy2'

    @pytest.mark.parametrize('server_compatible, client_compatible', [
        pytest.param(True, True, id='both_standard'),
        pytest.param(True, False, id='server_standard'),
        pytest.param(False, True, id='client_standard'),
        pytest.param(False, False, id='neither_standard')
    ])
    async def test_ragged_eofs(self, server_context: ssl.SSLContext,
                               client_context: ssl.SSLContext, server_compatible: bool,
                               client_compatible: bool) -> None:
        """Check which side reports an error for each combination of TLS closure strictness."""
        server_exc = None

        def serve_sync() -> None:
            nonlocal server_exc
            conn, _ = server_sock.accept()
            try:
                conn.settimeout(1)
                conn.sendall(b'hello')
                if server_compatible:
                    # Perform the TLS shutdown handshake instead of just closing the socket
                    conn.unwrap()
            except BaseException as exc:
                server_exc = exc
            finally:
                conn.close()

        # A standard compatible client talking to a server that skips the TLS shutdown is
        # expected to fail with BrokenResourceError
        client_cm = ExitStack()
        if client_compatible and not server_compatible:
            client_cm = pytest.raises(BrokenResourceError)

        server_sock = self._create_server_socket(server_context,
                                                 suppress_ragged_eofs=not server_compatible)
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()
        try:
            stream = await connect_tcp(*server_sock.getsockname())
            wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                           ssl_context=client_context,
                                           standard_compatible=client_compatible)
            with client_cm:
                assert await wrapper.receive() == b'hello'
                await wrapper.aclose()
        finally:
            server_thread.join()
            server_sock.close()

        # Only a strict server facing a lax client should see an error, and it must not be a
        # timeout; every other combination must leave the server without an exception
        if not client_compatible and server_compatible:
            assert isinstance(server_exc, OSError)
            assert not isinstance(server_exc, socket.timeout)
        else:
            assert server_exc is None

    async def test_ragged_eof_on_receive(self, server_context: ssl.SSLContext,
                                         client_context: ssl.SSLContext) -> None:
        """A non-standard-compatible client turns a ragged EOF into EndOfStream."""
        server_exc = None

        def serve_sync() -> None:
            nonlocal server_exc
            conn, _ = server_sock.accept()
            try:
                conn.settimeout(1)
                conn.sendall(b'hello')
            except BaseException as exc:
                server_exc = exc
            finally:
                conn.close()

        server_sock = self._create_server_socket(server_context, suppress_ragged_eofs=True)
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream, hostname='localhost',
                                               ssl_context=client_context,
                                               standard_compatible=False)
                assert await wrapper.receive() == b'hello'
                with pytest.raises(EndOfStream):
                    await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_exc is None

    async def test_receive_send_after_eof(self, server_context: ssl.SSLContext,
                                          client_context: ssl.SSLContext) -> None:
        """After the server's TLS shutdown, receive() raises EndOfStream."""
        # NOTE(review): despite the name, only the receive side is exercised here
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            with conn:
                # Without a timeout the accepted socket is blocking and unwrap() could hang
                conn.settimeout(1)
                conn.sendall(b'hello')
                conn.unwrap()

        server_sock = self._create_server_socket(server_context, suppress_ragged_eofs=False)
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()
        try:
            stream = await connect_tcp(*server_sock.getsockname())
            async with await TLSStream.wrap(stream, hostname='localhost',
                                            ssl_context=client_context) as wrapper:
                assert await wrapper.receive() == b'hello'
                with pytest.raises(EndOfStream):
                    await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

    @pytest.mark.parametrize('force_tlsv12', [
        pytest.param(False, marks=[pytest.mark.skipif(not getattr(ssl, 'HAS_TLSv1_3', False),
                                                      reason='No TLS 1.3 support')]),
        pytest.param(True)
    ], ids=['tlsv13', 'tlsv12'])
    async def test_send_eof_not_implemented(self, server_context: ssl.SSLContext,
                                            ca: CA, force_tlsv12: bool) -> None:
        """send_eof() raises NotImplementedError with a TLS version dependent message."""
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            with conn:
                # Without a timeout the accepted socket is blocking and unwrap() could hang
                conn.settimeout(1)
                conn.sendall(b'hello')
                conn.unwrap()

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        ca.configure_trust(client_context)
        if force_tlsv12:
            expected_pattern = r'send_eof\(\) requires at least TLSv1.3'
            if hasattr(ssl, 'TLSVersion'):
                client_context.maximum_version = ssl.TLSVersion.TLSv1_2
            else:  # Python 3.6
                client_context.options |= ssl.OP_NO_TLSv1_3
        else:
            expected_pattern = r'send_eof\(\) has not yet been implemented for TLS streams'

        server_sock = self._create_server_socket(server_context, suppress_ragged_eofs=False)
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()
        try:
            stream = await connect_tcp(*server_sock.getsockname())
            async with await TLSStream.wrap(stream, hostname='localhost',
                                            ssl_context=client_context) as wrapper:
                assert await wrapper.receive() == b'hello'
                with pytest.raises(NotImplementedError, match=expected_pattern):
                    await wrapper.send_eof()
        finally:
            server_thread.join()
            server_sock.close()
307
308
class TestTLSListener:
    """Tests for :class:`TLSListener` handshake error handling."""

    async def test_handshake_fail(self, server_context: ssl.SSLContext) -> None:
        """
        A client that connects and disconnects without doing a TLS handshake must be routed to
        ``handle_handshake_error()`` with a ``BrokenResourceError`` and never reach the handler.

        """
        handshake_done = Event()
        handshake_exc = None

        def unreachable_handler(stream: object) -> NoReturn:  # type: ignore[misc]
            pytest.fail('This function should never be called in this scenario')

        class CustomTLSListener(TLSListener):
            @staticmethod
            async def handle_handshake_error(exc: BaseException,
                                             stream: AnyByteStream) -> None:
                nonlocal handshake_exc
                # Run the default implementation first, then record what it was given
                await TLSListener.handle_handshake_error(exc, stream)
                assert isinstance(stream, SocketStream)
                handshake_exc = exc
                handshake_done.set()

        tcp_listener = await create_tcp_listener(local_host='127.0.0.1')
        tls_listener = CustomTLSListener(tcp_listener, server_context)
        async with tls_listener, create_task_group() as tg:
            tg.start_soon(tls_listener.serve, unreachable_handler)

            # Open a plain TCP connection and drop it straight away
            client_sock = socket.socket()
            client_sock.connect(tcp_listener.extra(SocketAttribute.local_address))
            client_sock.close()

            await handshake_done.wait()
            tg.cancel_scope.cancel()

        assert isinstance(handshake_exc, BrokenResourceError)
338