1import ssl
2import time
3from pathlib import Path
4from typing import Union
5
6import pytest
7
8from OpenSSL import SSL
9from mitmproxy import certs, connection
10from mitmproxy.addons import tlsconfig
11from mitmproxy.proxy import context
12from mitmproxy.proxy.layers import modes, tls
13from mitmproxy.test import taddons
14from test.mitmproxy.proxy.layers import test_tls
15
16
def test_alpn_select_callback():
    """Check what `tlsconfig.alpn_select_callback` returns for various `AppData` values."""
    conn = SSL.Connection(SSL.Context(SSL.SSLv23_METHOD))
    no_overlap = SSL.NO_OVERLAPPING_PROTOCOLS

    def attach(server_alpn, client_alpn):
        # Replace the connection's app data; `http2` is enabled in every scenario below.
        conn.set_app_data(
            tlsconfig.AppData(server_alpn=server_alpn, http2=True, client_alpn=client_alpn)
        )

    def pick(*offers):
        # Run the callback against the currently attached app data.
        return tlsconfig.alpn_select_callback(conn, list(offers))

    # An ALPN set explicitly by addons (`client.alpn`) is respected...
    attach(server_alpn=b"h2", client_alpn=b"qux")
    assert pick(b"http/1.1", b"qux", b"h2") == b"qux"
    # ...including an empty value, which yields no selection at all.
    attach(server_alpn=b"h2", client_alpn=b"")
    assert pick(b"http/1.1", b"qux", b"h2") == no_overlap

    # Without an explicit client ALPN we try to mirror the server connection's ALPN.
    attach(server_alpn=b"h2", client_alpn=None)
    assert pick(b"http/1.1", b"qux", b"h2") == b"h2"

    # Without any server ALPN we respect the client's preferred HTTP ALPN.
    attach(server_alpn=None, client_alpn=None)
    assert pick(b"qux", b"http/1.1", b"h2") == b"http/1.1"
    assert pick(b"qux", b"h2", b"http/1.1") == b"h2"

    # No overlap between the offers and what we can speak.
    assert pick(b"qux", b"quux") == no_overlap

    # We don't select an ALPN if the server refused to select one.
    attach(server_alpn=b"", client_alpn=None)
    assert pick(b"http/1.1") == no_overlap
42
43
# Directory that contains this test module.
here = Path(__file__).parent
45
46
class TestTlsConfig:
    """Tests for the `tlsconfig.TlsConfig` addon."""

    @staticmethod
    def _client_context(options) -> "context.Context":
        """Return a fresh proxy context for a fixed dummy client, bound to `options`."""
        dummy_client = connection.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329)
        return context.Context(dummy_client, options)

    @staticmethod
    def _upstream_endpoints(addon, ctx):
        """
        Let `addon` prepare TLS for `ctx.server`.

        Returns a `(proxy_side, fake_upstream)` pair: the connection object the addon stored
        on the TLS start data, and a server-side `SSLTest` to handshake against.
        """
        start = tls.TlsStartData(ctx.server, context=ctx)
        addon.tls_start_server(start)
        return start.ssl_conn, test_tls.SSLTest(server_side=True)

    def test_configure(self, tdata):
        """The `certs` option rejects unusable entries and populates the cert store for valid ones."""
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            # A spec that points to a missing file is refused.
            with pytest.raises(Exception, match="file does not exist"):
                tctx.configure(addon, certs=["*=nonexistent"])

            # The `.key` file is not accepted as a certificate.
            with pytest.raises(Exception, match="Invalid certificate format"):
                tctx.configure(
                    addon,
                    certs=[tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.key")],
                )

            # Neither failed attempt left anything behind; a valid PEM does.
            # `addon.certstore` is looked up anew each time on purpose.
            assert not addon.certstore.certs
            tctx.configure(
                addon,
                certs=[tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.pem")],
            )
            assert addon.certstore.certs

    def test_get_cert(self, tdata):
        """Test that we generate a certificate matching the connection's context."""
        addon = tlsconfig.TlsConfig()

        def load_cert(relative_path):
            # Parse a PEM certificate from the test data directory.
            pem = Path(tdata.path(relative_path)).read_bytes()
            return certs.Cert.from_pem(pem)

        with taddons.context(addon) as tctx:
            addon.configure(["confdir"])

            ctx = self._client_context(tctx.options)

            # Edge case first: We don't have _any_ idea about the server nor is there a SNI,
            # so we just return our local IP as subject.
            assert addon.get_cert(ctx).cert.cn == "127.0.0.1"

            # Here we have an existing server connection...
            ctx.server.address = ("server-address.example", 443)
            ctx.server.certificate_list = [
                load_cert("mitmproxy/net/data/verificationcerts/trusted-leaf.crt")
            ]
            generated = addon.get_cert(ctx).cert
            assert generated.cn == "example.mitmproxy.org"
            assert generated.altnames == ["example.mitmproxy.org", "server-address.example"]

            # And now we also incorporate SNI.
            ctx.client.sni = "sni.example"
            generated = addon.get_cert(ctx).cert
            assert generated.altnames == ["example.mitmproxy.org", "sni.example"]

            # An upstream certificate with an invalid subject must not make us raise.
            ctx.server.certificate_list = [load_cert("mitmproxy/data/invalid-subject.pem")]
            assert addon.get_cert(ctx)

    def test_tls_clienthello(self):
        """Run the `tls_clienthello` hook once; server-first TLS must stay disabled."""
        # only really testing for coverage here, there's no point in mirroring the individual conditions
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            hello = tls.ClientHelloData(self._client_context(tctx.options))
            addon.tls_clienthello(hello)
            assert not hello.establish_server_tls_first

    def do_handshake(
            self,
            tssl_client: Union[test_tls.SSLTest, SSL.Connection],
            tssl_server: Union[test_tls.SSLTest, SSL.Connection]
    ) -> bool:
        """
        Drive a TLS handshake between two in-memory endpoints by shuttling their BIO data.

        Returns True once both sides have finished; any error raised by the final
        handshake steps propagates to the caller.
        """
        want_read = (ssl.SSLWantReadError, SSL.WantReadError)
        max_read = 65536

        # ClientHello
        with pytest.raises(want_read):
            tssl_client.do_handshake()
        client_hello = tssl_client.bio_read(max_read)
        tssl_server.bio_write(client_hello)

        # ServerHello
        with pytest.raises(want_read):
            tssl_server.do_handshake()
        server_hello = tssl_server.bio_read(max_read)
        tssl_client.bio_write(server_hello)

        # done
        tssl_client.do_handshake()
        client_finish = tssl_client.bio_read(max_read)
        tssl_server.bio_write(client_finish)
        tssl_server.do_handshake()

        return True

    def test_tls_start_client(self, tdata):
        """The connection set up by `tls_start_client` handshakes and serves the expected SAN."""
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            addon.configure(["confdir"])
            tctx.configure(
                addon,
                certs=[tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.pem")],
                ciphers_client="ECDHE-ECDSA-AES128-GCM-SHA256",
            )
            ctx = self._client_context(tctx.options)

            start = tls.TlsStartData(ctx.client, context=ctx)
            addon.tls_start_client(start)
            proxy_side = start.ssl_conn
            fake_client = test_tls.SSLTest()
            assert self.do_handshake(fake_client, proxy_side)

            peer_cert = fake_client.obj.getpeercert()
            assert peer_cert["subjectAltName"] == (("DNS", "example.mitmproxy.org"),)

    def test_tls_start_server_verify_failed(self):
        """With no trusted CA configured here, the upstream handshake fails verification."""
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            ctx = self._client_context(tctx.options)
            ctx.client.alpn_offers = [b"h2"]
            ctx.client.cipher_list = ["TLS_AES_256_GCM_SHA384", "ECDHE-RSA-AES128-SHA"]
            ctx.server.address = ("example.mitmproxy.org", 443)

            proxy_side, fake_upstream = self._upstream_endpoints(addon, ctx)
            with pytest.raises(SSL.Error, match="certificate verify failed"):
                assert self.do_handshake(proxy_side, fake_upstream)

    def test_tls_start_server_verify_ok(self, tdata):
        """With the test root configured as trusted CA, the upstream handshake succeeds."""
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            ctx = self._client_context(tctx.options)
            ctx.server.address = ("example.mitmproxy.org", 443)
            trusted_root = tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.crt")
            tctx.configure(addon, ssl_verify_upstream_trusted_ca=trusted_root)

            proxy_side, fake_upstream = self._upstream_endpoints(addon, ctx)
            assert self.do_handshake(proxy_side, fake_upstream)

    def test_tls_start_server_insecure(self):
        """With `ssl_insecure=True` and no trusted CA, the upstream handshake still succeeds."""
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            ctx = self._client_context(tctx.options)
            ctx.server.address = ("example.mitmproxy.org", 443)

            tctx.configure(
                addon,
                ssl_verify_upstream_trusted_ca=None,
                ssl_insecure=True,
                http2=False,
                ciphers_server="ALL",
            )
            proxy_side, fake_upstream = self._upstream_endpoints(addon, ctx)
            assert self.do_handshake(proxy_side, fake_upstream)

    def test_alpn_selection(self):
        """Check which ALPN offers `tls_start_server` stores on the server connection."""
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            ctx = self._client_context(tctx.options)
            ctx.server.address = ("example.mitmproxy.org", 443)
            start = tls.TlsStartData(ctx.server, context=ctx)

            def upstream_offers(http2, client_offers):
                # Reset the server offers, rerun the hook and report what it chose.
                tctx.configure(addon, http2=http2)
                ctx.client.alpn_offers = client_offers
                ctx.server.alpn_offers = None
                addon.tls_start_server(start)
                return ctx.server.alpn_offers

            http_and_foo = tls.HTTP_ALPNS + (b"foo",)
            assert upstream_offers(True, http_and_foo) == http_and_foo
            assert upstream_offers(False, http_and_foo) == tls.HTTP1_ALPNS + (b"foo",)
            assert upstream_offers(True, []) == []
            assert upstream_offers(False, []) == []

            ctx.client.timestamp_tls_setup = time.time()
            # make sure that we don't upgrade h1 to h2,
            # see comment in tlsconfig.py
            assert upstream_offers(True, []) == []

    def test_no_h2_proxy(self, tdata):
        """Do not negotiate h2 on the client<->proxy connection in secure web proxy mode,
        https://github.com/mitmproxy/mitmproxy/issues/4689"""

        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            tctx.configure(
                addon,
                certs=[tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.pem")],
            )

            ctx = self._client_context(tctx.options)
            # mock up something that looks like a secure web proxy.
            ctx.layers = [modes.HttpProxy(ctx), 123]

            start = tls.TlsStartData(ctx.client, context=ctx)
            addon.tls_start_client(start)
            app_data = start.ssl_conn.get_app_data()
            assert app_data["client_alpn"] == b"http/1.1"

    @pytest.mark.parametrize(
        "client_certs",
        [
            "mitmproxy/net/data/verificationcerts/trusted-leaf.pem",
            "mitmproxy/net/data/verificationcerts/",
        ],
    )
    def test_client_cert_file(self, tdata, client_certs):
        """With `client_certs` set to either parametrized path, upstream receives a peer certificate."""
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            ctx = self._client_context(tctx.options)
            ctx.server.address = ("example.mitmproxy.org", 443)
            tctx.configure(
                addon,
                client_certs=tdata.path(client_certs),
                ssl_verify_upstream_trusted_ca=tdata.path(
                    "mitmproxy/net/data/verificationcerts/trusted-root.crt"
                ),
            )

            proxy_side, fake_upstream = self._upstream_endpoints(addon, ctx)

            assert self.do_handshake(proxy_side, fake_upstream)
            assert fake_upstream.obj.getpeercert()

    @pytest.mark.asyncio
    async def test_ca_expired(self, monkeypatch):
        """A warning is logged when certificates report themselves as expired."""

        def always_expired(self):
            return True

        monkeypatch.setattr(certs.Cert, "has_expired", always_expired)
        addon = tlsconfig.TlsConfig()
        with taddons.context(addon) as tctx:
            addon.configure(["confdir"])
            await tctx.master.await_log("The mitmproxy certificate authority has expired", "warn")
262