1"""
2TLS with SNI_-support for Python 2. Follow these instructions if you would
3like to verify TLS certificates in Python 2. Note, the default libraries do
4*not* do certificate checking; you need to do additional work to validate
5certificates yourself.
6
7This needs the following packages installed:
8
9* `pyOpenSSL`_ (tested with 16.0.0)
10* `cryptography`_ (minimum 1.3.4, from pyopenssl)
11* `idna`_ (minimum 2.0, from cryptography)
12
Because pyOpenSSL depends on cryptography, which in turn depends on idna, using
all three directly here still adds relatively few required packages.
15
16You can install them with the following command:
17
18.. code-block:: bash
19
20    $ python -m pip install pyopenssl cryptography idna
21
22To activate certificate checking, call
23:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code
24before you begin making HTTP requests. This can be done in a ``sitecustomize``
25module, or at any other time before your application begins using ``urllib3``,
26like this:
27
28.. code-block:: python
29
30    try:
31        import urllib3.contrib.pyopenssl
32        urllib3.contrib.pyopenssl.inject_into_urllib3()
33    except ImportError:
34        pass
35
36Now you can use :mod:`urllib3` as you normally would, and it will support SNI
37when the required modules are installed.
38
39Activating this module also has the positive side effect of disabling SSL/TLS
40compression in Python 2 (see `CRIME attack`_).
41
42.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication
43.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit)
44.. _pyopenssl: https://www.pyopenssl.org
45.. _cryptography: https://cryptography.io
46.. _idna: https://github.com/kjd/idna
47"""
48from __future__ import absolute_import
49
50import OpenSSL.SSL
51from cryptography import x509
52from cryptography.hazmat.backends.openssl import backend as openssl_backend
53from cryptography.hazmat.backends.openssl.x509 import _Certificate
54
55try:
56    from cryptography.x509 import UnsupportedExtension
57except ImportError:
58    # UnsupportedExtension is gone in cryptography >= 2.1.0
59    class UnsupportedExtension(Exception):
60        pass
61
62
63from io import BytesIO
64from socket import error as SocketError
65from socket import timeout
66
67try:  # Platform-specific: Python 2
68    from socket import _fileobject
69except ImportError:  # Platform-specific: Python 3
70    _fileobject = None
71    from ..packages.backports.makefile import backport_makefile
72
73import logging
74import ssl
75import sys
76
77from .. import util
78from ..packages import six
79from ..util.ssl_ import PROTOCOL_TLS_CLIENT
80
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]

# SNI always works.
HAS_SNI = True

# Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = {
    util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,
    PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD,
    ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}

# The remaining protocol versions are only mapped when both the stdlib ``ssl``
# module and pyOpenSSL expose them on this installation.
for _stdlib_attr, _openssl_attr in (
    ("PROTOCOL_SSLv3", "SSLv3_METHOD"),
    ("PROTOCOL_TLSv1_1", "TLSv1_1_METHOD"),
    ("PROTOCOL_TLSv1_2", "TLSv1_2_METHOD"),
):
    if hasattr(ssl, _stdlib_attr) and hasattr(OpenSSL.SSL, _openssl_attr):
        _openssl_versions[getattr(ssl, _stdlib_attr)] = getattr(
            OpenSSL.SSL, _openssl_attr
        )
del _stdlib_attr, _openssl_attr

# stdlib ``CERT_*`` verify modes -> OpenSSL ``VERIFY_*`` flag combinations.
_stdlib_to_openssl_verify = {
    ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
    ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
    ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
    | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
# ...and the inverse, used when reading the mode back from the context.
_openssl_to_stdlib_verify = {
    openssl_flags: stdlib_mode
    for stdlib_mode, openssl_flags in _stdlib_to_openssl_verify.items()
}

# OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16 * 1024

# The unpatched values, kept so extract_from_urllib3() can restore them.
orig_util_HAS_SNI = util.HAS_SNI
orig_util_SSLContext = util.ssl_.SSLContext


log = logging.getLogger(__name__)
119
120
def inject_into_urllib3():
    """Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.

    The dependency check runs first, so if it raises ``ImportError`` nothing
    has been patched yet.
    """
    _validate_dependencies_met()

    # The same three names are overridden on both urllib3.util and
    # urllib3.util.ssl_.
    for target in (util, util.ssl_):
        target.SSLContext = PyOpenSSLContext
        target.HAS_SNI = HAS_SNI
        target.IS_PYOPENSSL = True
132
133
def extract_from_urllib3():
    """Undo monkey-patching by :func:`inject_into_urllib3`.

    Restores the ``SSLContext`` and ``HAS_SNI`` values captured when this
    module was imported and clears the ``IS_PYOPENSSL`` flag.
    """
    for target in (util, util.ssl_):
        target.SSLContext = orig_util_SSLContext
        target.HAS_SNI = orig_util_HAS_SNI
        target.IS_PYOPENSSL = False
143
144
def _validate_dependencies_met():
    """
    Check that PyOpenSSL's package-level dependencies are recent enough.

    Raises ``ImportError`` when ``cryptography`` lacks
    ``Extensions.get_extension_for_class`` or when pyOpenSSL's ``X509``
    objects lack the ``_x509`` attribute.
    """
    # Method added in `cryptography==1.1`; not available in older versions
    from cryptography.x509.extensions import Extensions

    has_extension_lookup = (
        getattr(Extensions, "get_extension_for_class", None) is not None
    )
    if not has_extension_lookup:
        raise ImportError(
            "'cryptography' module missing required functionality.  "
            "Try upgrading to v1.3.4 or newer."
        )

    # pyOpenSSL 0.14 and above bind OpenSSL through cryptography; only those
    # versions expose the ``_x509`` attribute. A throwaway certificate object
    # is created purely to probe for it (named so it does not shadow the
    # module-level ``x509`` import).
    from OpenSSL.crypto import X509

    probe = X509()
    if getattr(probe, "_x509", None) is None:
        raise ImportError(
            "'pyOpenSSL' module missing required functionality. "
            "Try upgrading to v0.14 or newer."
        )
169
170
171def _dnsname_to_stdlib(name):
172    """
173    Converts a dNSName SubjectAlternativeName field to the form used by the
174    standard library on the given Python version.
175
176    Cryptography produces a dNSName as a unicode string that was idna-decoded
177    from ASCII bytes. We need to idna-encode that string to get it back, and
178    then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib
179    uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8).
180
181    If the name cannot be idna-encoded then we return None signalling that
182    the name given should be skipped.
183    """
184
185    def idna_encode(name):
186        """
187        Borrowed wholesale from the Python Cryptography Project. It turns out
188        that we can't just safely call `idna.encode`: it can explode for
189        wildcard names. This avoids that problem.
190        """
191        import idna
192
193        try:
194            for prefix in [u"*.", u"."]:
195                if name.startswith(prefix):
196                    name = name[len(prefix) :]
197                    return prefix.encode("ascii") + idna.encode(name)
198            return idna.encode(name)
199        except idna.core.IDNAError:
200            return None
201
202    # Don't send IPv6 addresses through the IDNA encoder.
203    if ":" in name:
204        return name
205
206    name = idna_encode(name)
207    if name is None:
208        return None
209    elif sys.version_info >= (3, 0):
210        name = name.decode("utf-8")
211    return name
212
213
def get_subj_alt_name(peer_cert):
    """
    Return the subject alternative names of a PyOpenSSL certificate.

    The result is a list of ``("DNS", name)`` and ``("IP Address", address)``
    tuples with string values. An empty list is returned when the certificate
    has no SAN extension, or when reading the extension fails with one of the
    handled errors (which is logged as a warning).
    """
    # cryptography offers a much better API for walking extensions.
    if not hasattr(peer_cert, "to_cryptography"):
        # Private API, but it works on the PyOpenSSL releases that predate
        # ``X509.to_cryptography()``.
        cert = _Certificate(openssl_backend, peer_cert._x509)
    else:
        cert = peer_cert.to_cryptography()

    # Let cryptography locate the SAN extension (faster than a Python loop).
    try:
        san = cert.extensions.get_extension_for_class(
            x509.SubjectAlternativeName
        ).value
    except x509.ExtensionNotFound:
        return []
    except (
        x509.DuplicateExtension,
        UnsupportedExtension,
        x509.UnsupportedGeneralNameType,
        UnicodeError,
    ) as e:
        # A low-quality certificate is treated as having no SAN field.
        log.warning(
            "A problem was encountered with the certificate that prevented "
            "urllib3 from finding the SubjectAlternativeName field. This can "
            "affect certificate validation. The error was %s",
            e,
        )
        return []

    names = []
    # DNS names get the same idna round-trip the standard library applies;
    # names that cannot be idna-encoded come back as None and are skipped.
    for raw_dns in san.get_values_for_type(x509.DNSName):
        stdlib_dns = _dnsname_to_stdlib(raw_dns)
        if stdlib_dns is not None:
            names.append(("DNS", stdlib_dns))
    # match_hostname wants IP addresses as strings.
    for address in san.get_values_for_type(x509.IPAddress):
        names.append(("IP Address", str(address)))
    return names
266
267
class WrappedSocket(object):
    """API-compatibility wrapper for Python OpenSSL's Connection-class.

    Exposes the socket-like methods urllib3 uses (recv, recv_into, sendall,
    close, getpeercert, ...) on top of a pyOpenSSL ``Connection``, converting
    pyOpenSSL exceptions into ``socket``/``ssl`` exceptions.

    Note: _makefile_refs, _drop() and _reuse() are needed for the garbage
    collector of pypy.
    """

    def __init__(self, connection, socket, suppress_ragged_eofs=True):
        # ``connection`` is the pyOpenSSL Connection; ``socket`` is the plain
        # socket beneath it, used for fileno(), timeouts and readiness waits.
        self.connection = connection
        self.socket = socket
        # When true, a SysCallError of (-1, "Unexpected EOF") is reported as
        # end-of-stream instead of being raised.
        self.suppress_ragged_eofs = suppress_ragged_eofs
        # Outstanding makefile() users; close() only closes at zero.
        self._makefile_refs = 0
        self._closed = False

    def fileno(self):
        return self.socket.fileno()

    # Copy-pasted from Python 3.5 source code
    def _decref_socketios(self):
        if self._makefile_refs > 0:
            self._makefile_refs -= 1
        if self._closed:
            self.close()

    def _peer_sent_shutdown(self):
        """Return True if the RECEIVED_SHUTDOWN bit is set on the connection."""
        # get_shutdown() is a bitmask of SENT_SHUTDOWN and RECEIVED_SHUTDOWN.
        # Test the bit rather than comparing with ``==`` so a connection on
        # which we have also sent our own shutdown still counts as cleanly
        # closed by the peer.
        return bool(self.connection.get_shutdown() & OpenSSL.SSL.RECEIVED_SHUTDOWN)

    def recv(self, *args, **kwargs):
        """Read decrypted bytes from the connection.

        Returns b"" on a clean TLS shutdown from the peer, or on a ragged EOF
        when ``suppress_ragged_eofs`` is set. Raises ``socket.timeout`` when
        waiting for readability fails, ``socket.error`` for other system-call
        errors and ``ssl.SSLError`` for any other pyOpenSSL error.
        """
        # Loop instead of recursing on WantReadError so that many consecutive
        # "need more data" wake-ups cannot exhaust the stack.
        while True:
            try:
                return self.connection.recv(*args, **kwargs)
            except OpenSSL.SSL.SysCallError as e:
                if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
                    return b""
                raise SocketError(str(e))
            except OpenSSL.SSL.ZeroReturnError:
                if self._peer_sent_shutdown():
                    return b""
                raise
            except OpenSSL.SSL.WantReadError:
                if not util.wait_for_read(self.socket, self.socket.gettimeout()):
                    raise timeout("The read operation timed out")
                # Readable again: retry the read.
            # TLS 1.3 post-handshake authentication
            except OpenSSL.SSL.Error as e:
                raise ssl.SSLError("read error: %r" % e)

    def recv_into(self, *args, **kwargs):
        """Read decrypted bytes into a caller-supplied buffer.

        Same error handling as :meth:`recv`, but end-of-stream is reported as
        a return value of 0 instead of b"".
        """
        while True:
            try:
                return self.connection.recv_into(*args, **kwargs)
            except OpenSSL.SSL.SysCallError as e:
                if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
                    return 0
                raise SocketError(str(e))
            except OpenSSL.SSL.ZeroReturnError:
                if self._peer_sent_shutdown():
                    return 0
                raise
            except OpenSSL.SSL.WantReadError:
                if not util.wait_for_read(self.socket, self.socket.gettimeout()):
                    raise timeout("The read operation timed out")
                # Readable again: retry the read.
            # TLS 1.3 post-handshake authentication
            except OpenSSL.SSL.Error as e:
                raise ssl.SSLError("read error: %r" % e)

    def settimeout(self, timeout):
        return self.socket.settimeout(timeout)

    def _send_until_done(self, data):
        """Send ``data`` once, waiting out WantWriteError; return bytes sent."""
        while True:
            try:
                return self.connection.send(data)
            except OpenSSL.SSL.WantWriteError:
                if not util.wait_for_write(self.socket, self.socket.gettimeout()):
                    raise timeout()
                continue
            except OpenSSL.SSL.SysCallError as e:
                raise SocketError(str(e))

    def sendall(self, data):
        """Send all of ``data`` in chunks of at most SSL_WRITE_BLOCKSIZE."""
        total_sent = 0
        while total_sent < len(data):
            sent = self._send_until_done(
                data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
            )
            total_sent += sent

    def shutdown(self):
        # FIXME rethrow compatible exceptions should we ever use this
        self.connection.shutdown()

    def close(self):
        """Close the connection once no makefile() users remain.

        While makefile() references are outstanding, this only drops one of
        them. pyOpenSSL errors raised while closing are swallowed.
        """
        if self._makefile_refs < 1:
            try:
                self._closed = True
                return self.connection.close()
            except OpenSSL.SSL.Error:
                return
        else:
            self._makefile_refs -= 1

    def getpeercert(self, binary_form=False):
        """Return the peer certificate.

        Returns the DER bytes when ``binary_form`` is true. Otherwise returns
        a dict with ``subject`` (common name only) and ``subjectAltName``
        entries. Returns the falsy value from pyOpenSSL when there is no peer
        certificate.
        """
        # Named ``cert`` so it does not shadow the module-level ``x509``.
        cert = self.connection.get_peer_certificate()

        if not cert:
            return cert

        if binary_form:
            return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert)

        return {
            "subject": ((("commonName", cert.get_subject().CN),),),
            "subjectAltName": get_subj_alt_name(cert),
        }

    def version(self):
        return self.connection.get_protocol_version_name()

    def _reuse(self):
        self._makefile_refs += 1

    def _drop(self):
        if self._makefile_refs < 1:
            self.close()
        else:
            self._makefile_refs -= 1
401
402
# Attach an interpreter-appropriate ``makefile`` to WrappedSocket: a wrapper
# around the stdlib ``socket._fileobject`` on Python 2, or the vendored
# ``backport_makefile`` on Python 3 (where ``_fileobject`` was set to None).
if _fileobject:  # Platform-specific: Python 2

    def makefile(self, mode, bufsize=-1):
        """Return a ``socket._fileobject`` wrapping this WrappedSocket.

        Increments ``_makefile_refs`` first, so WrappedSocket.close() only
        decrements that count (instead of closing the connection) while the
        returned file object holds a reference.
        """
        self._makefile_refs += 1
        return _fileobject(self, mode, bufsize, close=True)


else:  # Platform-specific: Python 3
    makefile = backport_makefile

WrappedSocket.makefile = makefile
414
415
class PyOpenSSLContext(object):
    """
    I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible
    for translating the interface of the standard library ``SSLContext`` object
    to calls into PyOpenSSL.
    """

    def __init__(self, protocol):
        # ``protocol`` must be a key of _openssl_versions (a stdlib/urllib3
        # PROTOCOL_* constant); anything else raises KeyError here.
        self.protocol = _openssl_versions[protocol]
        self._ctx = OpenSSL.SSL.Context(self.protocol)
        self._options = 0
        self.check_hostname = False

    @property
    def options(self):
        """The option bitmask most recently assigned through this wrapper."""
        return self._options

    @options.setter
    def options(self, value):
        # NOTE(review): pyOpenSSL documents Context.set_options as adding
        # options without clearing earlier ones, so assigning a smaller mask
        # may not remove bits from the underlying context — confirm.
        self._options = value
        self._ctx.set_options(value)

    @property
    def verify_mode(self):
        return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()]

    @verify_mode.setter
    def verify_mode(self, value):
        self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)

    def set_default_verify_paths(self):
        self._ctx.set_default_verify_paths()

    def set_ciphers(self, ciphers):
        if isinstance(ciphers, six.text_type):
            ciphers = ciphers.encode("utf-8")
        self._ctx.set_cipher_list(ciphers)

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        """Load trusted CA certificates, as ``ssl.SSLContext`` does.

        ``cafile`` and ``capath`` may be text or bytes. Text is UTF-8 encoded
        and bytes are passed through unchanged; calling ``.encode`` on bytes
        would fail on Python 3. Any pyOpenSSL error is re-raised as
        ``ssl.SSLError``.
        """
        if cafile is not None and not isinstance(cafile, bytes):
            cafile = cafile.encode("utf-8")
        if capath is not None and not isinstance(capath, bytes):
            capath = capath.encode("utf-8")
        try:
            self._ctx.load_verify_locations(cafile, capath)
            if cadata is not None:
                # NOTE(review): pyOpenSSL documents load_verify_locations as
                # taking a file *path*; a BytesIO here likely raises TypeError
                # instead of loading ``cadata``. Confirm this, and if so load
                # the certificates through get_cert_store() instead.
                self._ctx.load_verify_locations(BytesIO(cadata))
        except OpenSSL.SSL.Error as e:
            raise ssl.SSLError("unable to load trusted certificates: %r" % e)

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        """Load the client certificate chain and private key.

        The key is read from ``keyfile`` or, if that is not given, from
        ``certfile``. A text ``password`` is UTF-8 encoded before being
        handed to pyOpenSSL's password callback.
        """
        self._ctx.use_certificate_chain_file(certfile)
        if password is not None:
            if not isinstance(password, six.binary_type):
                password = password.encode("utf-8")
            self._ctx.set_passwd_cb(lambda *_: password)
        self._ctx.use_privatekey_file(keyfile or certfile)

    def set_alpn_protocols(self, protocols):
        protocols = [six.ensure_binary(p) for p in protocols]
        return self._ctx.set_alpn_protos(protocols)

    def wrap_socket(
        self,
        sock,
        server_side=False,
        do_handshake_on_connect=True,
        suppress_ragged_eofs=True,
        server_hostname=None,
    ):
        """Wrap ``sock`` in a client TLS connection and complete the handshake.

        ``server_side`` and ``do_handshake_on_connect`` are accepted for
        stdlib compatibility but ignored. The connection is always put in
        connect (client) state and handshaken here. ``suppress_ragged_eofs``
        is forwarded to the returned WrappedSocket.

        Raises ``socket.timeout`` when waiting for the socket fails and
        ``ssl.SSLError`` for any other handshake error.
        """
        cnx = OpenSSL.SSL.Connection(self._ctx, sock)

        if isinstance(server_hostname, six.text_type):  # Platform-specific: Python 3
            server_hostname = server_hostname.encode("utf-8")

        if server_hostname is not None:
            cnx.set_tlsext_host_name(server_hostname)

        cnx.set_connect_state()

        while True:
            try:
                cnx.do_handshake()
            except OpenSSL.SSL.WantReadError:
                if not util.wait_for_read(sock, sock.gettimeout()):
                    raise timeout("select timed out")
                continue
            except OpenSSL.SSL.WantWriteError:
                # Presumably possible when the socket is in non-blocking or
                # timeout mode and its send buffer is full. Wait for
                # writability instead of reporting a "bad handshake".
                if not util.wait_for_write(sock, sock.gettimeout()):
                    raise timeout("select timed out")
                continue
            except OpenSSL.SSL.Error as e:
                raise ssl.SSLError("bad handshake: %r" % e)
            break

        return WrappedSocket(cnx, sock, suppress_ragged_eofs=suppress_ragged_eofs)
508
509
510def _verify_callback(cnx, x509, err_no, err_depth, return_code):
511    return err_no == 0
512