1# -*- coding: utf-8 -*-
2"""
3hyper/tls
4~~~~~~~~~
5
6Contains the TLS/SSL logic for use in hyper.
7"""
8import os.path as path
9from .common.exceptions import MissingCertFile
10from .compat import ignore_missing, ssl
11
12
# Protocol tokens offered during TLS negotiation (passed to both NPN and ALPN
# by init_context). 'h2' comes first; the 'h2-NN' strings appear to be
# HTTP/2 draft-version tokens kept for older peers.
NPN_PROTOCOL = 'h2'
H2_NPN_PROTOCOLS = [
    NPN_PROTOCOL,
    'h2-16',
    'h2-15',
    'h2-14',
]

# Everything we are willing to negotiate, in the order it is offered: the
# HTTP/2 tokens first, then HTTP/1.1 as the fallback. This is a separate list
# object from H2_NPN_PROTOCOLS.
SUPPORTED_NPN_PROTOCOLS = list(H2_NPN_PROTOCOLS)
SUPPORTED_NPN_PROTOCOLS.append('http/1.1')

# Token used for HTTP/2 without TLS.
H2C_PROTOCOL = 'h2c'

# Shared default SSLContext. It starts out unset and is built lazily by
# wrap_socket() the first time a caller does not supply its own context.
_context = None

# Default CA bundle, expected to sit next to this module.
cert_loc = path.join(path.dirname(__file__), 'certs.pem')
25
26
def wrap_socket(sock, server_hostname, ssl_context=None, force_proto=None):
    """
    Wrap ``sock`` in TLS and report which application protocol was agreed.

    :param sock: The connected socket to wrap.
    :param server_hostname: Hostname sent via SNI. When the context has
        ``check_hostname`` enabled, the peer certificate is also checked
        against it.
    :param ssl_context: (optional) The ``SSLContext`` to use. When omitted, a
        module-wide default built by :func:`init_context` is created on first
        use and reused by later calls.
    :param force_proto: (optional) A protocol name to report instead of the
        one negotiated via ALPN/NPN.
    :returns: A ``(ssl_sock, proto)`` tuple. ``proto`` is ``force_proto`` if
        given, otherwise the ALPN-selected protocol, otherwise the
        NPN-selected protocol, otherwise ``None``.
    """

    global _context

    if ssl_context is not None:
        # If an SSLContext is provided then use it instead of the default.
        _ssl_context = ssl_context
    else:
        # Create the shared default SSLContext once, then reuse it.
        if _context is None:  # pragma: no cover
            _context = init_context()
        _ssl_context = _context

    # The spec requires SNI support.
    ssl_sock = _ssl_context.wrap_socket(sock, server_hostname=server_hostname)

    # Setting SSLContext.check_hostname to True only verifies that the
    # post-handshake servername matches that of the certificate. We also need
    # to check that it matches the requested one.
    if _ssl_context.check_hostname:  # pragma: no cover
        # Look the helpers up defensively. The stdlib removed
        # ``ssl.match_hostname`` in Python 3.12 and never had
        # ``verify_hostname``, so unguarded lookups raised AttributeError for
        # every connection on 3.12+.
        match_hostname = getattr(ssl, 'match_hostname', None)
        verify_hostname = getattr(ssl, 'verify_hostname', None)
        if match_hostname is not None:
            try:
                match_hostname(ssl_sock.getpeercert(), server_hostname)
            except AttributeError:
                # e.g. a pyopenssl-backed socket without getpeercert().
                if verify_hostname is None:
                    raise
                verify_hostname(ssl_sock, server_hostname)  # pyopenssl
        elif verify_hostname is not None:
            verify_hostname(ssl_sock, server_hostname)  # pyopenssl
        # Otherwise we are on a stdlib ssl without match_hostname (3.12+).
        # With check_hostname enabled, OpenSSL has already verified
        # server_hostname during the handshake above.

    # Allow for the protocol to be forced externally.
    proto = force_proto

    # ALPN is newer, so we prefer it over NPN. The odds of us getting
    # different answers is pretty low, but let's be sure.
    with ignore_missing():
        if proto is None:
            proto = ssl_sock.selected_alpn_protocol()

    with ignore_missing():
        if proto is None:
            proto = ssl_sock.selected_npn_protocol()

    return (ssl_sock, proto)
69
70
def init_context(cert_path=None, cert=None, cert_password=None):
    """
    Create a new ``SSLContext`` that is correctly set up for an HTTP/2
    connection. This SSL context object can be customized and passed as a
    parameter to the :class:`HTTPConnection <hyper.HTTPConnection>` class.
    Provide your own certificate file in case you don't want to use hyper's
    default certificate. The path to the certificate can be absolute or
    relative to your working directory.

    :param cert_path: (optional) The path to the certificate file of
        "certification authority" (CA) certificates.
    :param cert: (optional) If a string, the path to an SSL client cert file
        (.pem). If a tuple, a ('cert', 'key') pair.
        The certfile string must be the path to a single file in PEM format
        containing the certificate as well as any number of CA certificates
        needed to establish the certificate's authenticity. The keyfile
        string, if present, must point to a file containing the private key.
        Otherwise the private key will be taken from certfile as well.
    :param cert_password: (optional) The password argument may be a function to
        call to get the password for decrypting the private key. It will only
        be called if the private key is encrypted and a password is necessary.
        It will be called with no arguments, and it should return a string,
        bytes, or bytearray. If the return value is a string it will be
        encoded as UTF-8 before using it to decrypt the key. Alternatively a
        string, bytes, or bytearray value may be supplied directly as the
        password argument. It will be ignored if the private key is not
        encrypted and no password is needed.
    :returns: An ``SSLContext`` correctly set up for HTTP/2.
    :raises MissingCertFile: If the CA certificate file (``cert_path``, or the
        bundled default) does not exist.
    """
    cafile = cert_path or cert_loc
    if not cafile or not path.exists(cafile):
        err_msg = ("No certificate found at " + str(cafile) + ". Either " +
                   "ensure the default certs.pem file is included in the " +
                   "distribution or provide a custom certificate when " +
                   "creating the connection.")
        raise MissingCertFile(err_msg)

    # PROTOCOL_TLS_CLIENT is the non-deprecated client-side constant. Fall
    # back to PROTOCOL_SSLv23 for ssl implementations that lack it.
    protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', None)
    if protocol is None:
        protocol = ssl.PROTOCOL_SSLv23

    context = ssl.SSLContext(protocol)
    context.set_default_verify_paths()
    context.load_verify_locations(cafile=cafile)
    context.verify_mode = ssl.CERT_REQUIRED
    context.check_hostname = True

    with ignore_missing():
        context.set_npn_protocols(SUPPORTED_NPN_PROTOCOLS)

    with ignore_missing():
        context.set_alpn_protocols(SUPPORTED_NPN_PROTOCOLS)

    # required by the spec
    context.options |= ssl.OP_NO_COMPRESSION

    if cert is not None:
        # Do not assign to the name ``basestring`` here. Doing so made it a
        # function local, so on Python 2 the lookup always failed and
        # ``unicode`` paths were mistaken for (cert, key) tuples.
        try:
            string_types = (basestring,)  # noqa: F821 (Python 2 only)
        except NameError:
            string_types = (str, bytes)
        if isinstance(cert, string_types):
            context.load_cert_chain(cert, password=cert_password)
        else:
            context.load_cert_chain(cert[0], cert[1], cert_password)

    return context
134