1# -*- coding: utf-8 -*-
2#
3# Copyright: (c) 2016, Yanis Guenane <yanis+ansible@guenane.org>
4# Copyright: (c) 2020, Felix Fontein <felix@fontein.de>
5# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
6
7from __future__ import absolute_import, division, print_function
8__metaclass__ = type
9
10
11import abc
12import binascii
13import traceback
14
15from distutils.version import LooseVersion
16
17from ansible.module_utils import six
18from ansible.module_utils.basic import missing_required_lib
19from ansible.module_utils.common.text.converters import to_bytes, to_text
20
21from ansible_collections.community.crypto.plugins.module_utils.crypto.basic import (
22    OpenSSLObjectError,
23    OpenSSLBadPassphraseError,
24)
25
26from ansible_collections.community.crypto.plugins.module_utils.crypto.support import (
27    load_privatekey,
28    load_certificate_request,
29    parse_name_field,
30    select_message_digest,
31)
32
33from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_support import (
34    cryptography_get_basic_constraints,
35    cryptography_get_name,
36    cryptography_name_to_oid,
37    cryptography_key_needs_digest_for_signing,
38    cryptography_parse_key_usage_params,
39    cryptography_parse_relative_distinguished_name,
40)
41
42from ansible_collections.community.crypto.plugins.module_utils.crypto.cryptography_crl import (
43    REVOCATION_REASON_MAP,
44)
45
46from ansible_collections.community.crypto.plugins.module_utils.crypto.pyopenssl_support import (
47    pyopenssl_normalize_name_attribute,
48    pyopenssl_parse_name_constraints,
49)
50
51from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.csr_info import (
52    get_csr_info,
53)
54
55from ansible_collections.community.crypto.plugins.module_utils.crypto.module_backends.common import ArgumentSpec
56
57
# Minimum library versions. They are not referenced in this part of the file;
# presumably the backend selection code elsewhere checks against them.
MINIMAL_PYOPENSSL_VERSION = '0.15'
MINIMAL_CRYPTOGRAPHY_VERSION = '1.3'

# pyOpenSSL is optional. On import failure the traceback is kept in
# PYOPENSSL_IMP_ERR (presumably for a missing_required_lib() error elsewhere).
PYOPENSSL_IMP_ERR = None
try:
    import OpenSSL
    from OpenSSL import crypto
    # NOTE(review): LooseVersion comes from distutils, which is deprecated
    # (PEP 632) and no longer available in Python 3.12+.
    PYOPENSSL_VERSION = LooseVersion(OpenSSL.__version__)
except ImportError:
    PYOPENSSL_IMP_ERR = traceback.format_exc()
    PYOPENSSL_FOUND = False
else:
    PYOPENSSL_FOUND = True
    # How the OCSP Must Staple (TLS Feature) extension is spelled depends on
    # the libssl version pyOpenSSL is linked against.
    if OpenSSL.SSL.OPENSSL_VERSION_NUMBER >= 0x10100000:
        # OpenSSL 1.1.0 or newer
        OPENSSL_MUST_STAPLE_NAME = b"tlsfeature"
        OPENSSL_MUST_STAPLE_VALUE = b"status_request"
    else:
        # OpenSSL 1.0.x or older: raw OID and the DER encoding
        # SEQUENCE { INTEGER 5 } (bytes 30 03 02 01 05).
        OPENSSL_MUST_STAPLE_NAME = b"1.3.6.1.5.5.7.1.24"
        OPENSSL_MUST_STAPLE_VALUE = b"DER:30:03:02:01:05"

# cryptography is optional as well; same pattern as for pyOpenSSL above.
CRYPTOGRAPHY_IMP_ERR = None
try:
    import cryptography
    import cryptography.x509
    import cryptography.x509.oid
    import cryptography.exceptions
    import cryptography.hazmat.backends
    import cryptography.hazmat.primitives.serialization
    import cryptography.hazmat.primitives.hashes
    CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__)
except ImportError:
    CRYPTOGRAPHY_IMP_ERR = traceback.format_exc()
    CRYPTOGRAPHY_FOUND = False
else:
    CRYPTOGRAPHY_FOUND = True
    # Raw OID and DER value of OCSP Must Staple. They are used as a fallback
    # when cryptography.x509.TLSFeature is not available.
    CRYPTOGRAPHY_MUST_STAPLE_NAME = cryptography.x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.1.24")
    CRYPTOGRAPHY_MUST_STAPLE_VALUE = b"\x30\x03\x02\x01\x05"
97
98
class CertificateSigningRequestError(OpenSSLObjectError):
    """Error raised by the CSR backends of this module.

    Examples include unparsable key identifiers, a wrong private key
    passphrase, invalid subject fields and unsupported digests.
    """
101
102
103# From the object called `module`, only the following properties are used:
104#
105#  - module.params[]
106#  - module.warn(msg: str)
107#  - module.fail_json(msg: str, **kwargs)
108
109
@six.add_metaclass(abc.ABCMeta)
class CertificateSigningRequestBackend(object):
    """Library-independent part of CSR handling.

    This class reads the module options, loads the private key on demand,
    keeps track of an existing CSR and produces the result/diff dictionary.
    Concrete backends implement ``generate_csr``, ``get_csr_data`` and
    ``_check_csr``.
    """

    def __init__(self, module, backend):
        params = module.params
        self.module = module
        self.backend = backend

        self.digest = params['digest']
        self.privatekey_path = params['privatekey_path']
        self.privatekey_content = params['privatekey_content']
        if self.privatekey_content is not None:
            # The content is handed to load_privatekey() as UTF-8 encoded bytes.
            self.privatekey_content = self.privatekey_content.encode('utf-8')
        self.privatekey_passphrase = params['privatekey_passphrase']
        self.version = params['version']

        # Extension options; attribute names follow the X.509 extension names.
        self.subjectAltName = params['subject_alt_name']
        self.subjectAltName_critical = params['subject_alt_name_critical']
        self.keyUsage = params['key_usage']
        self.keyUsage_critical = params['key_usage_critical']
        self.extendedKeyUsage = params['extended_key_usage']
        self.extendedKeyUsage_critical = params['extended_key_usage_critical']
        self.basicConstraints = params['basic_constraints']
        self.basicConstraints_critical = params['basic_constraints_critical']
        self.ocspMustStaple = params['ocsp_must_staple']
        self.ocspMustStaple_critical = params['ocsp_must_staple_critical']
        self.name_constraints_permitted = params['name_constraints_permitted'] or []
        self.name_constraints_excluded = params['name_constraints_excluded'] or []
        self.name_constraints_critical = params['name_constraints_critical']
        self.create_subject_key_identifier = params['create_subject_key_identifier']
        self.subject_key_identifier = params['subject_key_identifier']
        self.authority_key_identifier = params['authority_key_identifier']
        self.authority_cert_issuer = params['authority_cert_issuer']
        self.authority_cert_serial_number = params['authority_cert_serial_number']
        self.crl_distribution_points = params['crl_distribution_points']

        # Filled in later by generate_csr() and _ensure_private_key_loaded().
        self.csr = None
        self.privatekey = None

        if self.create_subject_key_identifier and self.subject_key_identifier is not None:
            module.fail_json(msg='subject_key_identifier cannot be specified if create_subject_key_identifier is true')

        # Dedicated name options first, then the entries of the generic 'subject' option.
        name_options = (
            ('C', 'country_name'),
            ('ST', 'state_or_province_name'),
            ('L', 'locality_name'),
            ('O', 'organization_name'),
            ('OU', 'organizational_unit_name'),
            ('CN', 'common_name'),
            ('emailAddress', 'email_address'),
        )
        subject = [(short_name, params[option]) for short_name, option in name_options]
        if params['subject']:
            subject = subject + parse_name_field(params['subject'])
        # Only entries with a (truthy) value end up in the subject.
        self.subject = [(item[0], item[1]) for item in subject if item[1]]

        # Without explicit SANs, the first Common Name may be used as a DNS SAN.
        self.using_common_name_for_san = False
        if not self.subjectAltName and params['use_common_name_for_san']:
            common_names = [value for key, value in self.subject if key in ('commonName', 'CN')]
            if common_names:
                self.subjectAltName = ['DNS:%s' % common_names[0]]
                self.using_common_name_for_san = True

        # Key identifiers are given as hex strings, optionally colon-separated.
        if self.subject_key_identifier is not None:
            try:
                self.subject_key_identifier = binascii.unhexlify(self.subject_key_identifier.replace(':', ''))
            except Exception as exc:
                raise CertificateSigningRequestError('Cannot parse subject_key_identifier: {0}'.format(exc))

        if self.authority_key_identifier is not None:
            try:
                self.authority_key_identifier = binascii.unhexlify(self.authority_key_identifier.replace(':', ''))
            except Exception as exc:
                raise CertificateSigningRequestError('Cannot parse authority_key_identifier: {0}'.format(exc))

        self.existing_csr = None
        self.existing_csr_bytes = None

        # Two separate (empty) dicts until an existing CSR is registered.
        self.diff_before = self._get_info(None)
        self.diff_after = self._get_info(None)

    def _get_info(self, data):
        """Return diff information for the CSR bytes ``data``.

        ``None`` yields an empty dict. If the CSR cannot be parsed, the result
        is ``{'can_parse_csr': False}``; otherwise it is the ``get_csr_info``
        result with ``can_parse_csr`` set to True.
        """
        if data is None:
            return dict()
        try:
            info = get_csr_info(
                self.module, self.backend, data, validate_signature=False, prefer_one_fingerprint=True)
            info['can_parse_csr'] = True
        except Exception:
            # Best effort only: diff output must never make the module fail.
            return dict(can_parse_csr=False)
        return info

    @abc.abstractmethod
    def generate_csr(self):
        """(Re-)Generate CSR."""

    @abc.abstractmethod
    def get_csr_data(self):
        """Return bytes for self.csr."""

    def set_existing(self, csr_bytes):
        """Set existing CSR bytes. None indicates that the CSR does not exist."""
        self.existing_csr_bytes = csr_bytes
        info = self._get_info(csr_bytes)
        # Before and after start out as the very same object.
        self.diff_before = info
        self.diff_after = info

    def has_existing(self):
        """Query whether an existing CSR is/has been there."""
        return self.existing_csr_bytes is not None

    def _ensure_private_key_loaded(self):
        """Load the provided private key into self.privatekey (only once).

        Raises CertificateSigningRequestError on a wrong passphrase.
        """
        if self.privatekey is None:
            try:
                self.privatekey = load_privatekey(
                    path=self.privatekey_path,
                    content=self.privatekey_content,
                    passphrase=self.privatekey_passphrase,
                    backend=self.backend,
                )
            except OpenSSLBadPassphraseError as exc:
                raise CertificateSigningRequestError(exc)

    @abc.abstractmethod
    def _check_csr(self):
        """Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""

    def needs_regeneration(self):
        """Return True when no usable existing CSR matches the requested parameters."""
        if self.existing_csr_bytes is None:
            return True
        try:
            self.existing_csr = load_certificate_request(None, content=self.existing_csr_bytes, backend=self.backend)
        except Exception:
            # An unreadable CSR is simply replaced.
            return True
        self._ensure_private_key_loaded()
        return not self._check_csr()

    def dump(self, include_csr):
        """Serialize the object into a dictionary.

        ``diff.before`` describes the existing CSR. ``diff.after`` describes
        the newly generated CSR if there is one, otherwise the existing CSR.
        With ``include_csr``, the CSR bytes are also returned decoded as
        UTF-8 under ``csr`` (``None`` if there are none).
        """
        # A freshly generated CSR takes precedence over the existing one.
        if self.csr is not None:
            csr_bytes = self.get_csr_data()
        else:
            csr_bytes = self.existing_csr_bytes
        self.diff_after = self._get_info(csr_bytes)

        result = {
            'privatekey': self.privatekey_path,
            'subject': self.subject,
            'subjectAltName': self.subjectAltName,
            'keyUsage': self.keyUsage,
            'extendedKeyUsage': self.extendedKeyUsage,
            'basicConstraints': self.basicConstraints,
            'ocspMustStaple': self.ocspMustStaple,
            'name_constraints_permitted': self.name_constraints_permitted,
            'name_constraints_excluded': self.name_constraints_excluded,
        }
        if include_csr:
            result['csr'] = csr_bytes.decode('utf-8') if csr_bytes else None
        result['diff'] = dict(before=self.diff_before, after=self.diff_after)
        return result
274
275
# Implementation using pyOpenSSL
class CertificateSigningRequestPyOpenSSLBackend(CertificateSigningRequestBackend):
    """CSR backend based on pyOpenSSL.

    These options are only supported by the cryptography backend and are
    rejected with ``module.fail_json``: create_subject_key_identifier,
    subject_key_identifier, authority_key_identifier, authority_cert_issuer,
    authority_cert_serial_number and crl_distribution_points.
    """

    def __init__(self, module):
        # Reject options this backend cannot implement before doing any work.
        for o in ('create_subject_key_identifier', ):
            if module.params[o]:
                module.fail_json(msg='You cannot use {0} with the pyOpenSSL backend!'.format(o))
        for o in ('subject_key_identifier', 'authority_key_identifier', 'authority_cert_issuer', 'authority_cert_serial_number', 'crl_distribution_points'):
            if module.params[o] is not None:
                module.fail_json(msg='You cannot use {0} with the pyOpenSSL backend!'.format(o))
        super(CertificateSigningRequestPyOpenSSLBackend, self).__init__(module, 'pyopenssl')

    def generate_csr(self):
        """(Re-)Generate CSR.

        Builds and signs an ``OpenSSL.crypto.X509Req`` from the module
        parameters and stores it in ``self.csr``.

        Raises:
            CertificateSigningRequestError: for an unknown subject field
                identifier, a subject value OpenSSL rejects, or Subject
                Alternative Names OpenSSL cannot parse.
        """
        self._ensure_private_key_loaded()

        req = crypto.X509Req()
        # set_version() receives the option value minus one.
        req.set_version(self.version - 1)
        subject = req.get_subject()
        for entry in self.subject:
            if entry[1] is not None:
                # Workaround for https://github.com/pyca/pyopenssl/issues/165
                nid = OpenSSL._util.lib.OBJ_txt2nid(to_bytes(entry[0]))
                if nid == 0:
                    raise CertificateSigningRequestError('Unknown subject field identifier "{0}"'.format(entry[0]))
                res = OpenSSL._util.lib.X509_NAME_add_entry_by_NID(subject._name, nid, OpenSSL._util.lib.MBSTRING_UTF8, to_bytes(entry[1]), -1, -1, 0)
                if res == 0:
                    raise CertificateSigningRequestError('Invalid value for subject field identifier "{0}": {1}'.format(entry[0], entry[1]))

        extensions = []
        if self.subjectAltName:
            altnames = ', '.join(self.subjectAltName)
            try:
                extensions.append(crypto.X509Extension(b"subjectAltName", self.subjectAltName_critical, altnames.encode('ascii')))
            except OpenSSL.crypto.Error as e:
                raise CertificateSigningRequestError(
                    'Error while parsing Subject Alternative Names {0} (check for missing type prefix, such as "DNS:"!): {1}'.format(
                        ', '.join(["{0}".format(san) for san in self.subjectAltName]), str(e)
                    )
                )

        if self.keyUsage:
            usages = ', '.join(self.keyUsage)
            extensions.append(crypto.X509Extension(b"keyUsage", self.keyUsage_critical, usages.encode('ascii')))

        if self.extendedKeyUsage:
            usages = ', '.join(self.extendedKeyUsage)
            extensions.append(crypto.X509Extension(b"extendedKeyUsage", self.extendedKeyUsage_critical, usages.encode('ascii')))

        if self.basicConstraints:
            usages = ', '.join(self.basicConstraints)
            extensions.append(crypto.X509Extension(b"basicConstraints", self.basicConstraints_critical, usages.encode('ascii')))

        if self.name_constraints_permitted or self.name_constraints_excluded:
            usages = ', '.join(
                ['permitted;{0}'.format(name) for name in self.name_constraints_permitted] +
                ['excluded;{0}'.format(name) for name in self.name_constraints_excluded]
            )
            extensions.append(crypto.X509Extension(b"nameConstraints", self.name_constraints_critical, usages.encode('ascii')))

        if self.ocspMustStaple:
            # Name/value depend on the libssl version (see module-level constants).
            extensions.append(crypto.X509Extension(OPENSSL_MUST_STAPLE_NAME, self.ocspMustStaple_critical, OPENSSL_MUST_STAPLE_VALUE))

        if extensions:
            req.add_extensions(extensions)

        req.set_pubkey(self.privatekey)
        req.sign(self.privatekey, self.digest)
        self.csr = req

    def get_csr_data(self):
        """Return bytes for self.csr."""
        return crypto.dump_certificate_request(crypto.FILETYPE_PEM, self.csr)

    def _check_csr(self):
        """Return True if self.existing_csr matches all requested parameters.

        Compares the subject, the supported extensions (including their
        criticality) and the signature against self.privatekey. Assumes
        self.existing_csr and self.privatekey have been populated.
        """
        def _check_subject(csr):
            subject = [(OpenSSL._util.lib.OBJ_txt2nid(to_bytes(sub[0])), to_bytes(sub[1])) for sub in self.subject]
            current_subject = [(OpenSSL._util.lib.OBJ_txt2nid(to_bytes(sub[0])), to_bytes(sub[1])) for sub in csr.get_subject().get_components()]
            if not set(subject) == set(current_subject):
                return False

            return True

        def _check_subjectAltName(extensions):
            # '' stands in for a missing extension; to_text('') yields no names.
            altnames_ext = next((ext for ext in extensions if ext.get_short_name() == b'subjectAltName'), '')
            altnames = [pyopenssl_normalize_name_attribute(altname.strip()) for altname in
                        to_text(altnames_ext, errors='surrogate_or_strict').split(',') if altname.strip()]
            if self.subjectAltName:
                if (set(altnames) != set([pyopenssl_normalize_name_attribute(to_text(name)) for name in self.subjectAltName]) or
                        altnames_ext.get_critical() != self.subjectAltName_critical):
                    return False
            else:
                if altnames:
                    return False

            return True

        def _check_keyUsage_(extensions, extName, expected, critical):
            # Generic comparison for extensions whose printed values are object
            # names (e.g. extendedKeyUsage): compare the sets of NIDs.
            usages_ext = [ext for ext in extensions if ext.get_short_name() == extName]
            if (not usages_ext and expected) or (usages_ext and not expected):
                return False
            elif not usages_ext and not expected:
                return True
            else:
                current = [OpenSSL._util.lib.OBJ_txt2nid(to_bytes(usage.strip())) for usage in str(usages_ext[0]).split(',')]
                expected = [OpenSSL._util.lib.OBJ_txt2nid(to_bytes(usage)) for usage in expected]
                return set(current) == set(expected) and usages_ext[0].get_critical() == critical

        def _check_keyUsage(extensions):
            usages_ext = [ext for ext in extensions if ext.get_short_name() == b'keyUsage']
            if (not usages_ext and self.keyUsage) or (usages_ext and not self.keyUsage):
                return False
            elif not usages_ext and not self.keyUsage:
                return True
            else:
                # OpenSSL._util.lib.OBJ_txt2nid() always returns 0 for all keyUsage values
                # (since keyUsage has a fixed bitfield for these values and is not extensible).
                # Therefore, we create an extension for the wanted values, and compare the
                # data of the extensions (which is the serialized bitfield).
                expected_ext = crypto.X509Extension(b"keyUsage", False, ', '.join(self.keyUsage).encode('ascii'))
                return usages_ext[0].get_data() == expected_ext.get_data() and usages_ext[0].get_critical() == self.keyUsage_critical

        def _check_extendedKeyUsage(extensions):
            return _check_keyUsage_(extensions, b'extendedKeyUsage', self.extendedKeyUsage, self.extendedKeyUsage_critical)

        def _check_basicConstraints(extensions):
            bc_ext = [ext for ext in extensions if ext.get_short_name() == b'basicConstraints']
            if not bc_ext or not self.basicConstraints:
                # Only consistent if the extension is neither present nor requested.
                return not bc_ext and not self.basicConstraints
            # Values such as "CA:TRUE" or "pathlen:0" are not object names, so
            # OBJ_txt2nid() maps all of them to 0 and a NID-based comparison
            # would treat any two basicConstraints values as equal. Instead,
            # encode the requested values and compare the DER data.
            expected_ext = crypto.X509Extension(b"basicConstraints", False, ', '.join(self.basicConstraints).encode('ascii'))
            return bc_ext[0].get_data() == expected_ext.get_data() and bc_ext[0].get_critical() == self.basicConstraints_critical

        def _check_nameConstraints(extensions):
            nc_ext = next((ext for ext in extensions if ext.get_short_name() == b'nameConstraints'), '')
            permitted, excluded = pyopenssl_parse_name_constraints(nc_ext)
            if self.name_constraints_permitted or self.name_constraints_excluded:
                if set(permitted) != set([pyopenssl_normalize_name_attribute(to_text(name)) for name in self.name_constraints_permitted]):
                    return False
                if set(excluded) != set([pyopenssl_normalize_name_attribute(to_text(name)) for name in self.name_constraints_excluded]):
                    return False
                if nc_ext.get_critical() != self.name_constraints_critical:
                    return False
            else:
                if permitted or excluded:
                    return False

            return True

        def _check_ocspMustStaple(extensions):
            oms_ext = [ext for ext in extensions if to_bytes(ext.get_short_name()) == OPENSSL_MUST_STAPLE_NAME and to_bytes(ext) == OPENSSL_MUST_STAPLE_VALUE]
            if OpenSSL.SSL.OPENSSL_VERSION_NUMBER < 0x10100000:
                # Older versions of libssl don't know about OCSP Must Staple
                oms_ext.extend([ext for ext in extensions if ext.get_short_name() == b'UNDEF' and ext.get_data() == b'\x30\x03\x02\x01\x05'])
            if self.ocspMustStaple:
                return len(oms_ext) > 0 and oms_ext[0].get_critical() == self.ocspMustStaple_critical
            else:
                return len(oms_ext) == 0

        def _check_extensions(csr):
            extensions = csr.get_extensions()
            return (_check_subjectAltName(extensions) and _check_keyUsage(extensions) and
                    _check_extendedKeyUsage(extensions) and _check_basicConstraints(extensions) and
                    _check_ocspMustStaple(extensions) and _check_nameConstraints(extensions))

        def _check_signature(csr):
            # X509Req.verify() raises on mismatch instead of returning False.
            try:
                return csr.verify(self.privatekey)
            except crypto.Error:
                return False

        return _check_subject(self.existing_csr) and _check_extensions(self.existing_csr) and _check_signature(self.existing_csr)
442
443
def parse_crl_distribution_points(module, crl_distribution_points):
    """Turn the ``crl_distribution_points`` option into DistributionPoint objects.

    Each entry is a dict with the keys ``full_name``, ``relative_name``,
    ``crl_issuer`` and ``reasons``; ``None`` values are passed through as
    ``None``. ``module`` is part of the signature but not used here.

    Raises:
        OpenSSLObjectError: if an entry cannot be parsed. The message names
            the (0-based) index of the offending entry.
    """
    points = []
    for index, spec in enumerate(crl_distribution_points):
        try:
            full_name = None
            if spec['full_name'] is not None:
                full_name = [cryptography_get_name(name, 'full name') for name in spec['full_name']]

            relative_name = None
            if spec['relative_name'] is not None:
                try:
                    relative_name = cryptography_parse_relative_distinguished_name(spec['relative_name'])
                except Exception:
                    # If cryptography's version is < 1.6, the error is probably caused by that
                    if CRYPTOGRAPHY_VERSION < LooseVersion('1.6'):
                        raise OpenSSLObjectError('Cannot specify relative_name for cryptography < 1.6')
                    raise

            crl_issuer = None
            if spec['crl_issuer'] is not None:
                crl_issuer = [cryptography_get_name(name, 'CRL issuer') for name in spec['crl_issuer']]

            reasons = None
            if spec['reasons'] is not None:
                reasons = frozenset(REVOCATION_REASON_MAP[reason] for reason in spec['reasons'])

            points.append(cryptography.x509.DistributionPoint(
                full_name=full_name,
                relative_name=relative_name,
                crl_issuer=crl_issuer,
                reasons=reasons,
            ))
        except OpenSSLObjectError as e:
            raise OpenSSLObjectError('Error while parsing CRL distribution point #{index}: {error}'.format(index=index, error=e))
    return points
475
476
# Implementation using cryptography
478class CertificateSigningRequestCryptographyBackend(CertificateSigningRequestBackend):
479    def __init__(self, module):
480        super(CertificateSigningRequestCryptographyBackend, self).__init__(module, 'cryptography')
481        self.cryptography_backend = cryptography.hazmat.backends.default_backend()
482        if self.version != 1:
483            module.warn('The cryptography backend only supports version 1. (The only valid value according to RFC 2986.)')
484
485        if self.crl_distribution_points:
486            self.crl_distribution_points = parse_crl_distribution_points(module, self.crl_distribution_points)
487
488    def generate_csr(self):
489        """(Re-)Generate CSR."""
490        self._ensure_private_key_loaded()
491
492        csr = cryptography.x509.CertificateSigningRequestBuilder()
493        try:
494            csr = csr.subject_name(cryptography.x509.Name([
495                cryptography.x509.NameAttribute(cryptography_name_to_oid(entry[0]), to_text(entry[1])) for entry in self.subject
496            ]))
497        except ValueError as e:
498            raise CertificateSigningRequestError(e)
499
500        if self.subjectAltName:
501            csr = csr.add_extension(cryptography.x509.SubjectAlternativeName([
502                cryptography_get_name(name) for name in self.subjectAltName
503            ]), critical=self.subjectAltName_critical)
504
505        if self.keyUsage:
506            params = cryptography_parse_key_usage_params(self.keyUsage)
507            csr = csr.add_extension(cryptography.x509.KeyUsage(**params), critical=self.keyUsage_critical)
508
509        if self.extendedKeyUsage:
510            usages = [cryptography_name_to_oid(usage) for usage in self.extendedKeyUsage]
511            csr = csr.add_extension(cryptography.x509.ExtendedKeyUsage(usages), critical=self.extendedKeyUsage_critical)
512
513        if self.basicConstraints:
514            params = {}
515            ca, path_length = cryptography_get_basic_constraints(self.basicConstraints)
516            csr = csr.add_extension(cryptography.x509.BasicConstraints(ca, path_length), critical=self.basicConstraints_critical)
517
518        if self.ocspMustStaple:
519            try:
520                # This only works with cryptography >= 2.1
521                csr = csr.add_extension(cryptography.x509.TLSFeature([cryptography.x509.TLSFeatureType.status_request]), critical=self.ocspMustStaple_critical)
522            except AttributeError as dummy:
523                csr = csr.add_extension(
524                    cryptography.x509.UnrecognizedExtension(CRYPTOGRAPHY_MUST_STAPLE_NAME, CRYPTOGRAPHY_MUST_STAPLE_VALUE),
525                    critical=self.ocspMustStaple_critical
526                )
527
528        if self.name_constraints_permitted or self.name_constraints_excluded:
529            try:
530                csr = csr.add_extension(cryptography.x509.NameConstraints(
531                    [cryptography_get_name(name, 'name constraints permitted') for name in self.name_constraints_permitted],
532                    [cryptography_get_name(name, 'name constraints excluded') for name in self.name_constraints_excluded],
533                ), critical=self.name_constraints_critical)
534            except TypeError as e:
535                raise OpenSSLObjectError('Error while parsing name constraint: {0}'.format(e))
536
537        if self.create_subject_key_identifier:
538            csr = csr.add_extension(
539                cryptography.x509.SubjectKeyIdentifier.from_public_key(self.privatekey.public_key()),
540                critical=False
541            )
542        elif self.subject_key_identifier is not None:
543            csr = csr.add_extension(cryptography.x509.SubjectKeyIdentifier(self.subject_key_identifier), critical=False)
544
545        if self.authority_key_identifier is not None or self.authority_cert_issuer is not None or self.authority_cert_serial_number is not None:
546            issuers = None
547            if self.authority_cert_issuer is not None:
548                issuers = [cryptography_get_name(n, 'authority cert issuer') for n in self.authority_cert_issuer]
549            csr = csr.add_extension(
550                cryptography.x509.AuthorityKeyIdentifier(self.authority_key_identifier, issuers, self.authority_cert_serial_number),
551                critical=False
552            )
553
554        if self.crl_distribution_points:
555            csr = csr.add_extension(
556                cryptography.x509.CRLDistributionPoints(self.crl_distribution_points),
557                critical=False
558            )
559
560        digest = None
561        if cryptography_key_needs_digest_for_signing(self.privatekey):
562            digest = select_message_digest(self.digest)
563            if digest is None:
564                raise CertificateSigningRequestError('Unsupported digest "{0}"'.format(self.digest))
565        try:
566            self.csr = csr.sign(self.privatekey, digest, self.cryptography_backend)
567        except TypeError as e:
568            if str(e) == 'Algorithm must be a registered hash algorithm.' and digest is None:
569                self.module.fail_json(msg='Signing with Ed25519 and Ed448 keys requires cryptography 2.8 or newer.')
570            raise
571        except UnicodeError as e:
572            # This catches IDNAErrors, which happens when a bad name is passed as a SAN
573            # (https://github.com/ansible-collections/community.crypto/issues/105).
574            # For older cryptography versions, this is handled by idna, which raises
575            # an idna.core.IDNAError. Later versions of cryptography deprecated and stopped
576            # requiring idna, whence we cannot easily handle this error. Fortunately, in
577            # most versions of idna, IDNAError extends UnicodeError. There is only version
578            # 2.3 where it extends Exception instead (see
579            # https://github.com/kjd/idna/commit/ebefacd3134d0f5da4745878620a6a1cba86d130
580            # and then
581            # https://github.com/kjd/idna/commit/ea03c7b5db7d2a99af082e0239da2b68aeea702a).
582            msg = 'Error while creating CSR: {0}\n'.format(e)
583            if self.using_common_name_for_san:
584                self.module.fail_json(msg=msg + 'This is probably caused because the Common Name is used as a SAN.'
585                                      ' Specifying use_common_name_for_san=false might fix this.')
586            self.module.fail_json(msg=msg + 'This is probably caused by an invalid Subject Alternative DNS Name.')
587
588    def get_csr_data(self):
589        """Return bytes for self.csr."""
590        return self.csr.public_bytes(cryptography.hazmat.primitives.serialization.Encoding.PEM)
591
592    def _check_csr(self):
593        """Check whether provided parameters, assuming self.existing_csr and self.privatekey have been populated."""
594        def _check_subject(csr):
595            subject = [(cryptography_name_to_oid(entry[0]), to_text(entry[1])) for entry in self.subject]
596            current_subject = [(sub.oid, sub.value) for sub in csr.subject]
597            return set(subject) == set(current_subject)
598
599        def _find_extension(extensions, exttype):
600            return next(
601                (ext for ext in extensions if isinstance(ext.value, exttype)),
602                None
603            )
604
605        def _check_subjectAltName(extensions):
606            current_altnames_ext = _find_extension(extensions, cryptography.x509.SubjectAlternativeName)
607            current_altnames = [to_text(altname) for altname in current_altnames_ext.value] if current_altnames_ext else []
608            altnames = [to_text(cryptography_get_name(altname)) for altname in self.subjectAltName] if self.subjectAltName else []
609            if set(altnames) != set(current_altnames):
610                return False
611            if altnames:
612                if current_altnames_ext.critical != self.subjectAltName_critical:
613                    return False
614            return True
615
616        def _check_keyUsage(extensions):
617            current_keyusage_ext = _find_extension(extensions, cryptography.x509.KeyUsage)
618            if not self.keyUsage:
619                return current_keyusage_ext is None
620            elif current_keyusage_ext is None:
621                return False
622            params = cryptography_parse_key_usage_params(self.keyUsage)
623            for param in params:
624                if getattr(current_keyusage_ext.value, '_' + param) != params[param]:
625                    return False
626            if current_keyusage_ext.critical != self.keyUsage_critical:
627                return False
628            return True
629
630        def _check_extenededKeyUsage(extensions):
631            current_usages_ext = _find_extension(extensions, cryptography.x509.ExtendedKeyUsage)
632            current_usages = [str(usage) for usage in current_usages_ext.value] if current_usages_ext else []
633            usages = [str(cryptography_name_to_oid(usage)) for usage in self.extendedKeyUsage] if self.extendedKeyUsage else []
634            if set(current_usages) != set(usages):
635                return False
636            if usages:
637                if current_usages_ext.critical != self.extendedKeyUsage_critical:
638                    return False
639            return True
640
641        def _check_basicConstraints(extensions):
642            bc_ext = _find_extension(extensions, cryptography.x509.BasicConstraints)
643            current_ca = bc_ext.value.ca if bc_ext else False
644            current_path_length = bc_ext.value.path_length if bc_ext else None
645            ca, path_length = cryptography_get_basic_constraints(self.basicConstraints)
646            # Check CA flag
647            if ca != current_ca:
648                return False
649            # Check path length
650            if path_length != current_path_length:
651                return False
652            # Check criticality
653            if self.basicConstraints:
654                return bc_ext is not None and bc_ext.critical == self.basicConstraints_critical
655            else:
656                return bc_ext is None
657
658        def _check_ocspMustStaple(extensions):
659            try:
660                # This only works with cryptography >= 2.1
661                tlsfeature_ext = _find_extension(extensions, cryptography.x509.TLSFeature)
662                has_tlsfeature = True
663            except AttributeError as dummy:
664                tlsfeature_ext = next(
665                    (ext for ext in extensions if ext.value.oid == CRYPTOGRAPHY_MUST_STAPLE_NAME),
666                    None
667                )
668                has_tlsfeature = False
669            if self.ocspMustStaple:
670                if not tlsfeature_ext or tlsfeature_ext.critical != self.ocspMustStaple_critical:
671                    return False
672                if has_tlsfeature:
673                    return cryptography.x509.TLSFeatureType.status_request in tlsfeature_ext.value
674                else:
675                    return tlsfeature_ext.value.value == CRYPTOGRAPHY_MUST_STAPLE_VALUE
676            else:
677                return tlsfeature_ext is None
678
679        def _check_nameConstraints(extensions):
680            current_nc_ext = _find_extension(extensions, cryptography.x509.NameConstraints)
681            current_nc_perm = [to_text(altname) for altname in current_nc_ext.value.permitted_subtrees] if current_nc_ext else []
682            current_nc_excl = [to_text(altname) for altname in current_nc_ext.value.excluded_subtrees] if current_nc_ext else []
683            nc_perm = [to_text(cryptography_get_name(altname, 'name constraints permitted')) for altname in self.name_constraints_permitted]
684            nc_excl = [to_text(cryptography_get_name(altname, 'name constraints excluded')) for altname in self.name_constraints_excluded]
685            if set(nc_perm) != set(current_nc_perm) or set(nc_excl) != set(current_nc_excl):
686                return False
687            if nc_perm or nc_excl:
688                if current_nc_ext.critical != self.name_constraints_critical:
689                    return False
690            return True
691
692        def _check_subject_key_identifier(extensions):
693            ext = _find_extension(extensions, cryptography.x509.SubjectKeyIdentifier)
694            if self.create_subject_key_identifier or self.subject_key_identifier is not None:
695                if not ext or ext.critical:
696                    return False
697                if self.create_subject_key_identifier:
698                    digest = cryptography.x509.SubjectKeyIdentifier.from_public_key(self.privatekey.public_key()).digest
699                    return ext.value.digest == digest
700                else:
701                    return ext.value.digest == self.subject_key_identifier
702            else:
703                return ext is None
704
705        def _check_authority_key_identifier(extensions):
706            ext = _find_extension(extensions, cryptography.x509.AuthorityKeyIdentifier)
707            if self.authority_key_identifier is not None or self.authority_cert_issuer is not None or self.authority_cert_serial_number is not None:
708                if not ext or ext.critical:
709                    return False
710                aci = None
711                csr_aci = None
712                if self.authority_cert_issuer is not None:
713                    aci = [to_text(cryptography_get_name(n, 'authority cert issuer')) for n in self.authority_cert_issuer]
714                if ext.value.authority_cert_issuer is not None:
715                    csr_aci = [to_text(n) for n in ext.value.authority_cert_issuer]
716                return (ext.value.key_identifier == self.authority_key_identifier
717                        and csr_aci == aci
718                        and ext.value.authority_cert_serial_number == self.authority_cert_serial_number)
719            else:
720                return ext is None
721
722        def _check_crl_distribution_points(extensions):
723            ext = _find_extension(extensions, cryptography.x509.CRLDistributionPoints)
724            if self.crl_distribution_points is None:
725                return ext is None
726            if not ext:
727                return False
728            return list(ext.value) == self.crl_distribution_points
729
730        def _check_extensions(csr):
731            extensions = csr.extensions
732            return (_check_subjectAltName(extensions) and _check_keyUsage(extensions) and
733                    _check_extenededKeyUsage(extensions) and _check_basicConstraints(extensions) and
734                    _check_ocspMustStaple(extensions) and _check_subject_key_identifier(extensions) and
735                    _check_authority_key_identifier(extensions) and _check_nameConstraints(extensions) and
736                    _check_crl_distribution_points(extensions))
737
738        def _check_signature(csr):
739            if not csr.is_signature_valid:
740                return False
741            # To check whether public key of CSR belongs to private key,
742            # encode both public keys and compare PEMs.
743            key_a = csr.public_key().public_bytes(
744                cryptography.hazmat.primitives.serialization.Encoding.PEM,
745                cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo
746            )
747            key_b = self.privatekey.public_key().public_bytes(
748                cryptography.hazmat.primitives.serialization.Encoding.PEM,
749                cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo
750            )
751            return key_a == key_b
752
753        return _check_subject(self.existing_csr) and _check_extensions(self.existing_csr) and _check_signature(self.existing_csr)
754
755
def select_backend(module, backend):
    """Choose and instantiate the CSR backend.

    :param module: the AnsibleModule; its ``version`` parameter is checked and
        its ``fail_json``/``deprecate`` methods are used for reporting.
    :param backend: ``'auto'``, ``'cryptography'`` or ``'pyopenssl'``. With
        ``'auto'``, cryptography is preferred over pyOpenSSL.
    :returns: a ``(backend_name, backend_instance)`` tuple.
    :raises Exception: if ``backend`` is not one of the supported values.
    """
    if module.params['version'] != 1:
        module.deprecate('The version option will only support allowed values from community.crypto 2.0.0 on. '
                         'Currently, only the value 1 is allowed by RFC 2986',
                         version='2.0.0', collection_name='community.crypto')

    if backend == 'auto':
        # Detection what is possible
        can_use_cryptography = CRYPTOGRAPHY_FOUND and CRYPTOGRAPHY_VERSION >= LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION)
        can_use_pyopenssl = PYOPENSSL_FOUND and PYOPENSSL_VERSION >= LooseVersion(MINIMAL_PYOPENSSL_VERSION)

        # First try cryptography, then pyOpenSSL
        if can_use_cryptography:
            backend = 'cryptography'
        elif can_use_pyopenssl:
            backend = 'pyopenssl'

        # Success?
        if backend == 'auto':
            module.fail_json(msg=("Can't detect any of the required Python libraries "
                                  "cryptography (>= {0}) or PyOpenSSL (>= {1})").format(
                                      MINIMAL_CRYPTOGRAPHY_VERSION,
                                      MINIMAL_PYOPENSSL_VERSION))

    if backend == 'pyopenssl':
        if not PYOPENSSL_FOUND:
            module.fail_json(msg=missing_required_lib('pyOpenSSL >= {0}'.format(MINIMAL_PYOPENSSL_VERSION)),
                             exception=PYOPENSSL_IMP_ERR)
        # X509Req.get_extensions is used as a feature probe for the minimum pyOpenSSL version.
        try:
            getattr(crypto.X509Req, 'get_extensions')
        except AttributeError:
            module.fail_json(msg='You need to have PyOpenSSL>=0.15 to generate CSRs')

        module.deprecate('The module is using the PyOpenSSL backend. This backend has been deprecated',
                         version='2.0.0', collection_name='community.crypto')
        return backend, CertificateSigningRequestPyOpenSSLBackend(module)
    elif backend == 'cryptography':
        if not CRYPTOGRAPHY_FOUND:
            module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)),
                             exception=CRYPTOGRAPHY_IMP_ERR)
        elif CRYPTOGRAPHY_VERSION < LooseVersion(MINIMAL_CRYPTOGRAPHY_VERSION):
            # An explicitly selected backend must meet the same minimum version
            # that auto-detection requires. As in the auto-detection expression,
            # CRYPTOGRAPHY_VERSION is only consulted when CRYPTOGRAPHY_FOUND is true.
            module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)))
        return backend, CertificateSigningRequestCryptographyBackend(module)
    else:
        raise Exception('Unsupported value for backend: {0}'.format(backend))
799
800
def get_csr_argument_spec():
    """Build the ArgumentSpec shared by the CSR-generating modules.

    The spec covers private key input, the subject fields, the X.509
    extension options, CRL distribution points and backend selection, plus
    the cross-option constraints (the authority cert issuer/serial pair, and
    exactly one private key source).
    """
    # Permitted values for the ``reasons`` list of a CRL distribution point.
    crl_reason_choices = [
        'key_compromise',
        'ca_compromise',
        'affiliation_changed',
        'superseded',
        'cessation_of_operation',
        'certificate_hold',
        'privilege_withdrawn',
        'aa_compromise',
    ]

    # Sub-options accepted by each ``crl_distribution_points`` entry.
    crl_point_options = dict(
        full_name=dict(type='list', elements='str'),
        relative_name=dict(type='list', elements='str'),
        crl_issuer=dict(type='list', elements='str'),
        reasons=dict(type='list', elements='str', choices=crl_reason_choices),
    )

    options = dict(
        digest=dict(type='str', default='sha256'),
        privatekey_path=dict(type='path'),
        privatekey_content=dict(type='str', no_log=True),
        privatekey_passphrase=dict(type='str', no_log=True),
        version=dict(type='int', default=1),
        subject=dict(type='dict'),
        country_name=dict(type='str', aliases=['C', 'countryName']),
        state_or_province_name=dict(type='str', aliases=['ST', 'stateOrProvinceName']),
        locality_name=dict(type='str', aliases=['L', 'localityName']),
        organization_name=dict(type='str', aliases=['O', 'organizationName']),
        organizational_unit_name=dict(type='str', aliases=['OU', 'organizationalUnitName']),
        common_name=dict(type='str', aliases=['CN', 'commonName']),
        email_address=dict(type='str', aliases=['E', 'emailAddress']),
        subject_alt_name=dict(type='list', elements='str', aliases=['subjectAltName']),
        subject_alt_name_critical=dict(type='bool', default=False, aliases=['subjectAltName_critical']),
        use_common_name_for_san=dict(type='bool', default=True, aliases=['useCommonNameForSAN']),
        key_usage=dict(type='list', elements='str', aliases=['keyUsage']),
        key_usage_critical=dict(type='bool', default=False, aliases=['keyUsage_critical']),
        extended_key_usage=dict(type='list', elements='str', aliases=['extKeyUsage', 'extendedKeyUsage']),
        extended_key_usage_critical=dict(type='bool', default=False, aliases=['extKeyUsage_critical', 'extendedKeyUsage_critical']),
        basic_constraints=dict(type='list', elements='str', aliases=['basicConstraints']),
        basic_constraints_critical=dict(type='bool', default=False, aliases=['basicConstraints_critical']),
        ocsp_must_staple=dict(type='bool', default=False, aliases=['ocspMustStaple']),
        ocsp_must_staple_critical=dict(type='bool', default=False, aliases=['ocspMustStaple_critical']),
        name_constraints_permitted=dict(type='list', elements='str'),
        name_constraints_excluded=dict(type='list', elements='str'),
        name_constraints_critical=dict(type='bool', default=False),
        create_subject_key_identifier=dict(type='bool', default=False),
        subject_key_identifier=dict(type='str'),
        authority_key_identifier=dict(type='str'),
        authority_cert_issuer=dict(type='list', elements='str'),
        authority_cert_serial_number=dict(type='int'),
        crl_distribution_points=dict(
            type='list',
            elements='dict',
            options=crl_point_options,
            mutually_exclusive=[('full_name', 'relative_name')],
        ),
        select_crypto_backend=dict(type='str', default='auto', choices=['auto', 'cryptography', 'pyopenssl']),
    )

    return ArgumentSpec(
        argument_spec=options,
        required_together=[['authority_cert_issuer', 'authority_cert_serial_number']],
        mutually_exclusive=[['privatekey_path', 'privatekey_content']],
        required_one_of=[['privatekey_path', 'privatekey_content']],
    )
868