1"""Crypto utilities."""
2import binascii
3import contextlib
4import ipaddress
5import logging
6import os
7import re
8import socket
9from typing import Any
10from typing import Callable
11from typing import List
12from typing import Mapping
13from typing import Optional
14from typing import Set
15from typing import Tuple
16from typing import Union
17
18import josepy as jose
19from OpenSSL import crypto
20from OpenSSL import SSL
21
22from acme import errors
23
logger = logging.getLogger(__name__)

# Default SSL method selected here is the most compatible, while secure
# SSL method: TLSv1_METHOD is only compatible with TLSv1_METHOD, while
# SSLv23_METHOD is compatible with all other methods, including
# TLSv1_2_METHOD (read more at
# https://www.openssl.org/docs/ssl/SSLv23_method.html). Because it may
# also negotiate the legacy SSLv2/SSLv3 protocols, every context built
# from it should call "set_options" to disable them: SSLSocket does so
# (OP_NO_SSLv2 and OP_NO_SSLv3), but probe_sni does not -- keep that in
# mind in case it's used for things other than probing/serving!
_DEFAULT_SSL_METHOD = SSL.SSLv23_METHOD
34
35
class _DefaultCertSelection:
    """Certificate selector backed by a static server-name lookup table.

    Instances are callables usable as the ``cert_selection`` hook of
    `SSLSocket`: given a connection, they look up its SNI server name in
    ``certs`` and return the matching ``(key, cert)`` pair.
    """

    def __init__(self, certs: Mapping[bytes, Tuple[crypto.PKey, crypto.X509]]):
        # Mapping from SNI server name (bytes) to a (key, cert) pair.
        self.certs = certs

    def __call__(self, connection: SSL.Connection) -> Optional[Tuple[crypto.PKey, crypto.X509]]:
        """Return the pair registered for the connection's server name, or ``None``."""
        return self.certs.get(connection.get_servername())
43
44
class SSLSocket:  # pylint: disable=too-few-public-methods
    """SSL wrapper for sockets.

    Connections accepted through this wrapper are turned into TLS server
    connections; the certificate served on each one is chosen from the SNI
    server name sent by the client.

    :ivar socket sock: Original wrapped socket.
    :ivar method: See `OpenSSL.SSL.Context` for allowed values.
    :ivar alpn_selection: Hook to select negotiated ALPN protocol for
        connection.
    :ivar cert_selection: Hook to select certificate for connection. It is
        called with the `OpenSSL.SSL.Connection` and returns a
        ``(OpenSSL.crypto.PKey, OpenSSL.crypto.X509)`` pair, or ``None`` if
        no certificate should be served (the handshake then fails).

    Constructor arguments ``certs`` and ``cert_selection`` are mutually
    exclusive and exactly one of them must be given: ``certs`` is a mapping
    from domain names (`bytes`) to ``(OpenSSL.crypto.PKey,
    OpenSSL.crypto.X509)`` pairs, used to build a default ``cert_selection``.

    """
    def __init__(self, sock: socket.socket,
                 certs: Optional[Mapping[bytes, Tuple[crypto.PKey, crypto.X509]]] = None,
                 method: int = _DEFAULT_SSL_METHOD,
                 alpn_selection: Optional[Callable[[SSL.Connection, List[bytes]], bytes]] = None,
                 cert_selection: Optional[Callable[[SSL.Connection],
                                                   Optional[Tuple[crypto.PKey,
                                                                  crypto.X509]]]] = None
                 ) -> None:
        self.sock = sock
        self.alpn_selection = alpn_selection
        self.method = method
        if not cert_selection and not certs:
            raise ValueError("Neither cert_selection or certs specified.")
        if cert_selection and certs:
            raise ValueError("Both cert_selection and certs specified.")
        selection: Callable[[SSL.Connection], Optional[Tuple[crypto.PKey, crypto.X509]]]
        if cert_selection is None:
            selection = _DefaultCertSelection(certs if certs else {})
        else:
            selection = cert_selection
        self.cert_selection = selection

    def __getattr__(self, name: str) -> Any:
        # Anything not defined here is delegated to the wrapped socket.
        return getattr(self.sock, name)

    def _make_context(self) -> SSL.Context:
        """Create a server context with SSLv2/SSLv3 disabled and ALPN hooked up."""
        context = SSL.Context(self.method)
        context.set_options(SSL.OP_NO_SSLv2)
        context.set_options(SSL.OP_NO_SSLv3)
        if self.alpn_selection is not None:
            context.set_alpn_select_callback(self.alpn_selection)
        return context

    def _pick_certificate_cb(self, connection: SSL.Connection) -> None:
        """SNI certificate callback.

        This method will set a new OpenSSL context object for this
        connection when an incoming connection provides an SNI name
        (in order to serve the appropriate certificate, if any).

        :param connection: The TLS connection object on which the SNI
            extension was received.
        :type connection: :class:`OpenSSL.Connection`

        """
        pair = self.cert_selection(connection)
        if pair is None:
            # Leave the certificate-less initial context in place; the
            # handshake will then fail and `accept` reports it.
            logger.debug("Certificate selection for server name %s failed, dropping SSL",
                         connection.get_servername())
            return
        key, cert = pair
        new_context = self._make_context()
        new_context.use_privatekey(key)
        new_context.use_certificate(cert)
        connection.set_context(new_context)

    class FakeConnection:
        """Fake OpenSSL.SSL.Connection."""

        # pylint: disable=missing-function-docstring

        def __init__(self, connection: SSL.Connection) -> None:
            self._wrapped = connection

        def __getattr__(self, name: str) -> Any:
            return getattr(self._wrapped, name)

        def shutdown(self, *unused_args: Any) -> bool:
            # OpenSSL.SSL.Connection.shutdown doesn't accept any args
            return self._wrapped.shutdown()

    def accept(self) -> Tuple[FakeConnection, Any]:
        """Accept a connection and perform the server side of the TLS handshake.

        :returns: ``(connection, address)`` where ``connection`` wraps the
            established `OpenSSL.SSL.Connection`.
        :raises socket.error: If the TLS handshake fails. The accepted
            socket is closed before any exception propagates.

        """
        sock, addr = self.sock.accept()

        try:
            context = self._make_context()
            context.set_tlsext_servername_callback(self._pick_certificate_cb)

            ssl_sock = self.FakeConnection(SSL.Connection(context, sock))
            ssl_sock.set_accept_state()

            logger.debug("Performing handshake with %s", addr)
            try:
                ssl_sock.do_handshake()
            except SSL.Error as error:
                # _pick_certificate_cb might have returned without
                # creating SSL context (wrong server name)
                raise socket.error(error) from error
        except BaseException:
            # Don't leak the freshly accepted socket on any failure.
            sock.close()
            raise

        return ssl_sock, addr
147
148
def probe_sni(name: bytes, host: bytes, port: int = 443, timeout: int = 300,  # pylint: disable=too-many-arguments
              method: int = _DEFAULT_SSL_METHOD, source_address: Tuple[str, int] = ('', 0),
              alpn_protocols: Optional[List[bytes]] = None) -> crypto.X509:
    """Probe SNI server for SSL certificate.

    :param bytes name: Byte string to send as the server name in the
        client hello message.
    :param bytes host: Host to connect to.
    :param int port: Port to connect to.
    :param int timeout: Passed to `OpenSSL.SSL.Context.set_timeout`, i.e.
        the SSL session timeout in seconds. It is *not* applied to the
        socket; connecting and the handshake use Python's global default
        socket timeout (`socket.getdefaulttimeout`).
    :param method: See `OpenSSL.SSL.Context` for allowed values.
    :param tuple source_address: Enables multi-path probing (selection
        of source interface). See `socket.create_connection` for more
        info.
    :param alpn_protocols: Protocols to request using ALPN.
    :type alpn_protocols: `list` of `bytes`

    :raises acme.errors.Error: In case of any problems.

    :returns: SSL certificate presented by the server.
    :rtype: OpenSSL.crypto.X509

    """
    context = SSL.Context(method)
    context.set_timeout(timeout)

    socket_kwargs = {'source_address': source_address}

    try:
        logger.debug(
            "Attempting to connect to %s:%d%s.", host, port,
            " from {0}:{1}".format(
                source_address[0],
                source_address[1]
            ) if any(source_address) else ""
        )
        socket_tuple: Tuple[bytes, int] = (host, port)
        sock = socket.create_connection(socket_tuple, **socket_kwargs)  # type: ignore[arg-type]
    except socket.error as error:
        raise errors.Error(error) from error

    with contextlib.closing(sock) as client:
        client_ssl = SSL.Connection(context, client)
        client_ssl.set_connect_state()
        client_ssl.set_tlsext_host_name(name)  # pyOpenSSL>=0.13
        if alpn_protocols is not None:
            client_ssl.set_alpn_protos(alpn_protocols)
        try:
            client_ssl.do_handshake()
            client_ssl.shutdown()
        except SSL.Error as error:
            raise errors.Error(error) from error
    # The peer certificate stays available after the socket is closed.
    return client_ssl.get_peer_certificate()
202
203
def make_csr(private_key_pem: bytes, domains: Optional[Union[Set[str], List[str]]] = None,
             must_staple: bool = False,
             ipaddrs: Optional[List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]] = None
             ) -> bytes:
    """Generate a CSR containing domains or IPs as subjectAltNames.

    :param buffer private_key_pem: Private key, in PEM PKCS#8 format.
    :param list domains: List of DNS names to include in subjectAltNames of CSR.
    :param bool must_staple: Whether to include the TLS Feature extension (aka
        OCSP Must Staple: https://tools.ietf.org/html/rfc7633).
    :param list ipaddrs: List of IP addresses (`ipaddress.IPv4Address` or
        `ipaddress.IPv6Address`) to include in subjectAltNames of CSR.
        Parameters are ordered this way for backward compatibility with
        callers passing them positionally.

    :raises ValueError: If neither ``domains`` nor ``ipaddrs`` is non-empty.

    :returns: buffer PEM-encoded Certificate Signing Request.
    """
    private_key = crypto.load_privatekey(
        crypto.FILETYPE_PEM, private_key_pem)
    csr = crypto.X509Req()
    # Treat missing lists as empty so both can be iterated uniformly.
    domains = [] if domains is None else domains
    ipaddrs = [] if ipaddrs is None else ipaddrs
    if not domains and not ipaddrs:
        raise ValueError("At least one of domains or ipaddrs parameter need to be not empty")
    sanlist = ['DNS:' + address for address in domains]
    sanlist.extend('IP:' + ip.exploded for ip in ipaddrs)
    # The extension value must be ASCII.
    san_string = ', '.join(sanlist).encode('ascii')
    # IP SANs must end up as octet strings; OpenSSL's parsing of the
    # "IP:" entries in this string takes care of that conversion.
    extensions = [
        crypto.X509Extension(
            b'subjectAltName',
            critical=False,
            value=san_string
        ),
    ]
    if must_staple:
        # TLS Feature extension (RFC 7633) holding SEQUENCE { INTEGER 5 },
        # i.e. status_request.
        extensions.append(crypto.X509Extension(
            b"1.3.6.1.5.5.7.1.24",
            critical=False,
            value=b"DER:30:03:02:01:05"))
    csr.add_extensions(extensions)
    csr.set_pubkey(private_key)
    # RFC 2986 section 4.1 only defines version 0 for a CSR.
    csr.set_version(0)
    csr.sign(private_key, 'sha256')
    return crypto.dump_certificate_request(
        crypto.FILETYPE_PEM, csr)
256
257
def _pyopenssl_cert_or_req_all_names(loaded_cert_or_req: Union[crypto.X509, crypto.X509Req]
                                     ) -> List[str]:
    """Return the subject CN followed by the DNS subjectAltNames.

    Despite the name, only DNS names are returned; IP address SANs and other
    identifier types are ignored. When the subject has a CN it comes first
    and is not repeated among the SAN entries that follow.

    :param loaded_cert_or_req: Certificate or CSR.
    :returns: list of names.
    """
    cn = loaded_cert_or_req.get_subject().CN
    dns_sans = _pyopenssl_cert_or_req_san(loaded_cert_or_req)

    if cn is not None:
        return [cn, *(san for san in dns_sans if san != cn)]
    return dns_sans
267
268
def _pyopenssl_cert_or_req_san(cert_or_req: Union[crypto.X509, crypto.X509Req]) -> List[str]:
    """Get Subject Alternative Names from certificate or CSR using pyOpenSSL.

    .. todo:: Implement directly in PyOpenSSL!

    .. note:: Although this is `acme` internal API, it is used by
        `letsencrypt`.

    :param cert_or_req: Certificate or CSR.
    :type cert_or_req: `OpenSSL.crypto.X509` or `OpenSSL.crypto.X509Req`.

    :returns: A list of Subject Alternative Names that are DNS names.
    :rtype: `list` of `unicode`

    """
    # Prefix used for DNS entries in the PyOpenSSL certificate/CSR text dump.
    prefix = "DNS:"

    sans_parts = _pyopenssl_extract_san_list_raw(cert_or_req)

    # Strip the prefix instead of splitting on ":" so an entry containing a
    # further colon is not truncated (mirrors _pyopenssl_cert_or_req_san_ip).
    return [part[len(prefix):]
            for part in sans_parts if part.startswith(prefix)]
294
295
def _pyopenssl_cert_or_req_san_ip(cert_or_req: Union[crypto.X509, crypto.X509Req]) -> List[str]:
    """Collect the IP-address Subject Alternative Names of a cert or CSR.

    :param cert_or_req: Certificate or CSR.
    :type cert_or_req: `OpenSSL.crypto.X509` or `OpenSSL.crypto.X509Req`.

    :returns: The IP addresses, exactly as spelled in OpenSSL's text dump.
    :rtype: `list` of `unicode` (strings, not `ipaddress` objects)

    """
    # OpenSSL's text dump labels IP SANs as "IP Address:<address>"; the
    # address itself may contain colons (IPv6), so slice rather than split.
    label = "IP Address:"
    skip = len(label)
    return [entry[skip:]
            for entry in _pyopenssl_extract_san_list_raw(cert_or_req)
            if entry.startswith(label)]
314
315
def _pyopenssl_extract_san_list_raw(cert_or_req: Union[crypto.X509, crypto.X509Req]) -> List[str]:
    """Return the raw subjectAltName entries of a certificate or CSR.

    The object is dumped in OpenSSL's human-readable text format and decoded
    as UTF-8. The text following the "X509v3 Subject Alternative Name"
    header (up to the end of that line) is then split into entries.

    :param cert_or_req: Certificate or CSR.
    :type cert_or_req: `OpenSSL.crypto.X509` or `OpenSSL.crypto.X509Req`.

    :returns: raw SAN strings such as ``"DNS:example.com"``; empty when the
        object has no subjectAltName extension.
    :rtype: `list` of `unicode`

    """
    # The text dump is parsed because pyOpenSSL < 0.17's
    # `_subjectAltNameString` could not render IP addresses.
    if isinstance(cert_or_req, crypto.X509):
        dumper = crypto.dump_certificate
    else:
        dumper = crypto.dump_certificate_request
    text = dumper(crypto.FILETYPE_TEXT, cert_or_req).decode('utf-8')

    # Only the first SAN extension is found; RFC 5280 disallows repeating
    # an extension type, so multiple SAN extensions are unsupported.
    match = re.search(r"X509v3 Subject Alternative Name:(?: critical)?\s*(.*)", text)
    if match is None:
        return []
    # Assumes no individual SAN contains the ", " separator.
    return match.group(1).split(", ")
345
346
def gen_ss_cert(key: crypto.PKey, domains: Optional[List[str]] = None,
                not_before: Optional[int] = None,
                validity: int = (7 * 24 * 60 * 60), force_san: bool = True,
                extensions: Optional[List[crypto.X509Extension]] = None,
                ips: Optional[List[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]] = None
                ) -> crypto.X509:
    """Generate new self-signed certificate.

    :param OpenSSL.crypto.PKey key: Key used both as the certificate's
        public key and to self-sign it (SHA-256).
    :param domains: DNS names for the certificate; the first one becomes
        the subject CN.
    :type domains: `list` of `unicode`
    :param int not_before: Seconds to adjust ``notBefore`` by, relative to
        now (``None`` means now).
    :param int validity: Seconds to adjust ``notAfter`` by, relative to
        now (defaults to one week).
    :param bool force_san: Always add the ``subjectAltName`` extension.
    :param extensions: List of additional extensions to include in the cert.
        The caller's list is copied, never modified.
    :type extensions: `list` of `OpenSSL.crypto.X509Extension`
    :param ips: IP addresses to include as ``IP:`` subjectAltNames.
    :type ips: `list` of (`ipaddress.IPv4Address` or `ipaddress.IPv6Address`)

    :raises AssertionError: If neither ``domains`` nor ``ips`` is given.

    :returns: The self-signed certificate.
    :rtype: `OpenSSL.crypto.X509`

    If more than one domain is provided, all of the domains are put into
    ``subjectAltName`` X.509 extension and first domain is set as the
    subject CN. If only one domain and no IP is provided no
    ``subjectAltName`` extension is used, unless `force_san` is ``True``.

    """
    if not domains and not ips:
        # Raised explicitly (not via an ``assert`` statement) so the check
        # survives ``python -O``; the exception type is kept for callers.
        raise AssertionError("Must provide one or more hostnames or IPs for the cert.")

    domains = [] if domains is None else domains
    ips = [] if ips is None else ips
    # Work on a copy: appending to the caller's list would leak our
    # extensions into it and duplicate them each time the list is reused.
    extensions = [] if extensions is None else list(extensions)

    cert = crypto.X509()
    cert.set_serial_number(int(binascii.hexlify(os.urandom(16)), 16))
    # Zero-based version field: 2 means an X.509 v3 certificate.
    cert.set_version(2)

    extensions.append(
        crypto.X509Extension(
            b"basicConstraints", True, b"CA:TRUE, pathlen:0"),
    )

    if domains:
        cert.get_subject().CN = domains[0]
    # TODO: what to put into cert.get_subject()?
    cert.set_issuer(cert.get_subject())

    sanlist = ['DNS:' + address for address in domains]
    sanlist.extend('IP:' + ip.exploded for ip in ips)
    san_string = ', '.join(sanlist).encode('ascii')
    if force_san or len(domains) > 1 or ips:
        extensions.append(crypto.X509Extension(
            b"subjectAltName",
            critical=False,
            value=san_string
        ))

    cert.add_extensions(extensions)

    cert.gmtime_adj_notBefore(0 if not_before is None else not_before)
    cert.gmtime_adj_notAfter(validity)

    cert.set_pubkey(key)
    cert.sign(key, "sha256")
    return cert
411
412
def dump_pyopenssl_chain(chain: List[crypto.X509], filetype: int = crypto.FILETYPE_PEM) -> bytes:
    """Serialize a certificate chain into a single bundle.

    :param list chain: List of `OpenSSL.crypto.X509` (or wrapped in
        :class:`josepy.util.ComparableX509`).
    :param int filetype: Encoding handed to `OpenSSL.crypto.dump_certificate`
        for every certificate (PEM by default).

    :returns: the concatenated certificates; ``b""`` for an empty chain
    :rtype: bytes

    """
    # XXX: an empty chain yields an empty bundle, which shuts up
    # RenewableCert, but might not be the best solution...
    unwrapped = (cert.wrapped if isinstance(cert, jose.ComparableX509) else cert
                 for cert in chain)
    # Plain concatenation assumes OpenSSL.crypto.dump_certificate ends each
    # certificate with a newline character.
    return b"".join(crypto.dump_certificate(filetype, cert) for cert in unwrapped)
434