1"""Tests for certbot.crypto_util."""
2import logging
3import unittest
4
5import certbot.util
6
7try:
8    import mock
9except ImportError:  # pragma: no cover
10    from unittest import mock
11import OpenSSL
12
13from certbot import errors
14from certbot import util
15from certbot.compat import filesystem
16from certbot.compat import os
17import certbot.tests.util as test_util
18
# Test vectors shared by the test cases in this module. Names ending in _PATH
# are produced by test_util.vector_path and are used by the tests wherever a
# file location is needed; the remaining names are produced by
# test_util.load_vector and are used as in-memory data.
RSA256_KEY = test_util.load_vector('rsa256_key.pem')
RSA256_KEY_PATH = test_util.vector_path('rsa256_key.pem')
RSA512_KEY = test_util.load_vector('rsa512_key.pem')
RSA2048_KEY_PATH = test_util.vector_path('rsa2048_key.pem')
CERT_PATH = test_util.vector_path('cert_512.pem')
CERT = test_util.load_vector('cert_512.pem')
SS_CERT_PATH = test_util.vector_path('cert_2048.pem')
SS_CERT = test_util.load_vector('cert_2048.pem')
P256_KEY = test_util.load_vector('nistp256_key.pem')
P256_CERT_PATH = test_util.vector_path('cert-nosans_nistp256.pem')
P256_CERT = test_util.load_vector('cert-nosans_nistp256.pem')
# CERT_LEAF is signed by CERT_ISSUER. CERT_ALT_ISSUER is a cross-sign of CERT_ISSUER.
CERT_LEAF = test_util.load_vector('cert_leaf.pem')
CERT_ISSUER = test_util.load_vector('cert_intermediate_1.pem')
CERT_ALT_ISSUER = test_util.load_vector('cert_intermediate_2.pem')
34
35
class GenerateKeyTest(test_util.TempDirTestCase):
    """Exercise certbot.crypto_util.generate_key inside a temporary directory."""

    def setUp(self):
        """Create a private working directory and silence log output."""
        super().setUp()

        self.workdir = os.path.join(self.tempdir, 'workdir')
        filesystem.mkdir(self.workdir, mode=0o700)

        # Suppress everything up to and including CRITICAL while the test runs.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        """Clean up the temporary directory and re-enable logging."""
        super().tearDown()

        logging.disable(logging.NOTSET)

    @classmethod
    def _call(cls, key_size, key_dir):
        """Invoke generate_key with the fixed key name and strict permissions."""
        from certbot.crypto_util import generate_key
        return generate_key(key_size, key_dir, 'key-certbot.pem', strict_permissions=True)

    @mock.patch('certbot.crypto_util.make_key')
    def test_success(self, mock_make):
        """The returned key carries the generated PEM and an existing file."""
        mock_make.return_value = b'key_pem'

        generated = self._call(1024, self.workdir)

        self.assertEqual(generated.pem, b'key_pem')
        self.assertIn('key-certbot.pem', generated.file)
        key_location = os.path.join(self.workdir, generated.file)
        self.assertTrue(os.path.exists(key_location))

    @mock.patch('certbot.crypto_util.make_key')
    def test_key_failure(self, mock_make):
        """A ValueError raised by make_key is propagated to the caller."""
        mock_make.side_effect = ValueError
        with self.assertRaises(ValueError):
            self._call(431, self.workdir)
68
69
class InitSaveKey(unittest.TestCase):
    """Check the deprecated certbot.crypto_util.init_save_key entry point."""

    @mock.patch("certbot.crypto_util.generate_key")
    @mock.patch("certbot.crypto_util.zope.component")
    def test_it(self, mock_zope, mock_generate):
        """init_save_key warns about deprecation and delegates to generate_key."""
        from certbot.crypto_util import init_save_key

        fake_config = mock.MagicMock(strict_permissions=True)
        mock_zope.getUtility.return_value = fake_config

        with self.assertWarns(DeprecationWarning):
            init_save_key(4096, "/some/path")

        expected_options = {
            "elliptic_curve": "secp256r1",
            "key_type": "rsa",
            "keyname": "key-certbot.pem",
            "strict_permissions": True,
        }
        mock_generate.assert_called_with(4096, "/some/path", **expected_options)
85
86
class GenerateCSRTest(test_util.TempDirTestCase):
    """Exercise certbot.crypto_util.generate_csr with CSR creation mocked out."""

    @mock.patch('acme.crypto_util.make_csr')
    @mock.patch('certbot.crypto_util.util.make_or_verify_dir')
    def test_it(self, unused_mock_verify, mock_csr):
        """The returned CSR holds the mocked PEM and a csr-certbot.pem file name."""
        from certbot.crypto_util import generate_csr

        mock_csr.return_value = b'csr_pem'
        fake_key = mock.Mock(pem='dummy_key')

        result = generate_csr(fake_key, 'example.com', self.tempdir,
                              strict_permissions=True)

        self.assertEqual(result.data, b'csr_pem')
        self.assertIn('csr-certbot.pem', result.file)
101
102
class InitSaveCsr(unittest.TestCase):
    """Check the deprecated certbot.crypto_util.init_save_csr entry point."""

    @mock.patch("certbot.crypto_util.generate_csr")
    @mock.patch("certbot.crypto_util.zope.component")
    def test_it(self, mock_zope, mock_generate):
        """init_save_csr warns about deprecation and delegates to generate_csr."""
        from certbot.crypto_util import init_save_csr

        fake_config = mock.MagicMock(must_staple=True, strict_permissions=True)
        mock_zope.getUtility.return_value = fake_config
        key = certbot.util.Key(file=None, pem=None)
        names = {"dummy"}

        with self.assertWarns(DeprecationWarning):
            init_save_csr(key, names, "/some/path")

        mock_generate.assert_called_with(
            key, {"dummy"}, "/some/path", must_staple=True, strict_permissions=True)
119
120
class ValidCSRTest(unittest.TestCase):
    """Exercise certbot.crypto_util.valid_csr with good and bad inputs."""

    @classmethod
    def _call(cls, csr):
        """Return the result of valid_csr for the given CSR data."""
        from certbot.crypto_util import valid_csr
        return valid_csr(csr)

    def test_valid_pem_true(self):
        pem_csr = test_util.load_vector('csr_512.pem')
        self.assertTrue(self._call(pem_csr))

    def test_valid_pem_san_true(self):
        san_csr = test_util.load_vector('csr-san_512.pem')
        self.assertTrue(self._call(san_csr))

    def test_valid_der_false(self):
        der_csr = test_util.load_vector('csr_512.der')
        self.assertFalse(self._call(der_csr))

    def test_empty_false(self):
        outcome = self._call('')
        self.assertFalse(outcome)

    def test_random_false(self):
        outcome = self._call('foo bar')
        self.assertFalse(outcome)
143
144
class CSRMatchesPubkeyTest(unittest.TestCase):
    """Exercise certbot.crypto_util.csr_matches_pubkey."""

    @classmethod
    def _call(cls, *args, **kwargs):
        """Forward all arguments to csr_matches_pubkey."""
        from certbot.crypto_util import csr_matches_pubkey
        return csr_matches_pubkey(*args, **kwargs)

    def test_valid_true(self):
        """The 512-bit CSR is reported as matching the 512-bit RSA key."""
        csr = test_util.load_vector('csr_512.pem')
        self.assertTrue(self._call(csr, RSA512_KEY))

    def test_invalid_false(self):
        """The 512-bit CSR is reported as not matching the 256-bit RSA key."""
        csr = test_util.load_vector('csr_512.pem')
        self.assertFalse(self._call(csr, RSA256_KEY))
160
161
class ImportCSRFileTest(unittest.TestCase):
    """Tests for certbot.crypto_util.import_csr_file."""

    @classmethod
    def _call(cls, *args, **kwargs):
        """Forward all arguments to import_csr_file."""
        from certbot.crypto_util import import_csr_file
        return import_csr_file(*args, **kwargs)

    def test_der_csr(self):
        """A DER CSR is returned as PEM data together with its domain list."""
        csrfile = test_util.vector_path('csr_512.der')
        data = test_util.load_vector('csr_512.der')
        data_pem = test_util.load_vector('csr_512.pem')

        expected = (
            OpenSSL.crypto.FILETYPE_PEM,
            util.CSR(file=csrfile, data=data_pem, form="pem"),
            ["Example.com"],
        )
        self.assertEqual(expected, self._call(csrfile, data))

    def test_pem_csr(self):
        """A PEM CSR is returned unchanged together with its domain list."""
        csrfile = test_util.vector_path('csr_512.pem')
        data = test_util.load_vector('csr_512.pem')

        expected = (
            OpenSSL.crypto.FILETYPE_PEM,
            util.CSR(file=csrfile, data=data, form="pem"),
            ["Example.com"],
        )
        self.assertEqual(expected, self._call(csrfile, data))

    def test_bad_csr(self):
        """Feeding a certificate instead of a CSR raises errors.Error."""
        with self.assertRaises(errors.Error):
            self._call(test_util.vector_path('cert_512.pem'),
                       test_util.load_vector('cert_512.pem'))
199
200
class MakeKeyTest(unittest.TestCase):
    """Tests for certbot.crypto_util.make_key."""

    def test_rsa(self):  # pylint: disable=no-self-use
        """A generated RSA key can be loaded back as a PEM private key."""
        from certbot.crypto_util import make_key
        # Do not test larger keys as it takes too long.
        OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, make_key(1024))

    def test_ec(self):
        """Each listed curve yields a loadable key of the expected bit size."""
        from certbot.crypto_util import make_key

        curves = [('secp256r1', 256), ('secp384r1', 384), ('secp521r1', 521)]
        for name, bits in curves:
            # subTest reports which curve failed instead of stopping at the first.
            with self.subTest(curve=name):
                pkey = OpenSSL.crypto.load_privatekey(
                    OpenSSL.crypto.FILETYPE_PEM,
                    make_key(elliptic_curve=name, key_type='ecdsa')
                )
                self.assertEqual(pkey.bits(), bits)

    def test_bad_key_sizes(self):
        """An unsupported RSA key length raises errors.Error with a clear message."""
        from certbot.crypto_util import make_key
        # Try a bad key size for RSA
        with self.assertRaises(errors.Error) as e:
            make_key(bits=512, key_type='rsa')
        self.assertEqual("Unsupported RSA key length: 512", str(e.exception))

    def test_bad_elliptic_curve_name(self):
        """An unknown curve name raises errors.Error with a clear message."""
        from certbot.crypto_util import make_key
        with self.assertRaises(errors.Error) as e:
            make_key(elliptic_curve="nothere", key_type='ecdsa')
        self.assertEqual("Unsupported elliptic curve: nothere", str(e.exception))

    def test_bad_key_type(self):
        """An unknown key type raises errors.Error with a clear message."""
        from certbot.crypto_util import make_key

        # Try a bad --key-type. make_key is expected to raise on its own, so
        # there is no need to hand its result to OpenSSL.
        with self.assertRaises(errors.Error) as e:
            make_key(1024, key_type='unf')
        self.assertEqual(
            "Invalid key_type specified: unf.  Use [rsa|ecdsa]",
            str(e.exception),
        )
254
255
class VerifyCertSetup(unittest.TestCase):
    """Shared fixture for the certificate verification test cases."""

    def setUp(self):
        """Build one consistent mock lineage and one inconsistent one."""
        self.renewable_cert = mock.MagicMock(
            cert_path=SS_CERT_PATH,
            chain_path=SS_CERT_PATH,
            key_path=RSA2048_KEY_PATH,
            fullchain_path=test_util.vector_path('cert_fullchain_2048.pem'),
        )

        # All three paths point at the same certificate file.
        self.bad_renewable_cert = mock.MagicMock(
            chain_path=SS_CERT_PATH,
            cert_path=SS_CERT_PATH,
            fullchain_path=SS_CERT_PATH,
        )
270
271
class VerifyRenewableCertTest(VerifyCertSetup):
    """Exercise certbot.crypto_util.verify_renewable_cert."""

    def _call(self, renewable_cert):
        """Run verify_renewable_cert on the given lineage."""
        from certbot.crypto_util import verify_renewable_cert
        return verify_renewable_cert(renewable_cert)

    def test_verify_renewable_cert(self):
        """The consistent lineage verifies and the call returns None."""
        outcome = self._call(self.renewable_cert)
        self.assertIsNone(outcome)

    @mock.patch('certbot.crypto_util.verify_renewable_cert_sig', side_effect=errors.Error(""))
    def test_verify_renewable_cert_failure(self, unused_verify_renewable_cert_sign):
        """An error from the signature check is surfaced as errors.Error."""
        with self.assertRaises(errors.Error):
            self._call(self.bad_renewable_cert)
285
286
class VerifyRenewableCertSigTest(VerifyCertSetup):
    """Tests for certbot.crypto_util.verify_renewable_cert_sig."""

    def _call(self, renewable_cert):
        """Run verify_renewable_cert_sig on the given lineage."""
        from certbot.crypto_util import verify_renewable_cert_sig
        return verify_renewable_cert_sig(renewable_cert)

    def test_cert_sig_match(self):
        """The self-signed RSA fixture verifies and the call returns None."""
        self.assertIsNone(self._call(self.renewable_cert))

    def test_cert_sig_match_ec(self):
        """The P-256 certificate, used as its own chain, verifies."""
        renewable_cert = mock.MagicMock()
        renewable_cert.cert_path = P256_CERT_PATH
        renewable_cert.chain_path = P256_CERT_PATH
        # key_path is a file location like the other *_path attributes, so
        # point it at the key file rather than assigning the key contents.
        renewable_cert.key_path = test_util.vector_path('nistp256_key.pem')
        self.assertIsNone(self._call(renewable_cert))

    def test_cert_sig_mismatch(self):
        """A certificate that does not match the chain raises errors.Error."""
        self.bad_renewable_cert.cert_path = test_util.vector_path('cert_512_bad.pem')
        with self.assertRaises(errors.Error):
            self._call(self.bad_renewable_cert)
307
308
class VerifyFullchainTest(VerifyCertSetup):
    """Tests for certbot.crypto_util.verify_fullchain."""

    def _call(self, renewable_cert):
        """Run verify_fullchain on the given lineage."""
        from certbot.crypto_util import verify_fullchain
        return verify_fullchain(renewable_cert)

    def test_fullchain_matches(self):
        """The consistent lineage verifies and the call returns None."""
        self.assertIsNone(self._call(self.renewable_cert))

    def test_fullchain_mismatch(self):
        """The inconsistent lineage raises errors.Error."""
        with self.assertRaises(errors.Error):
            self._call(self.bad_renewable_cert)

    def test_fullchain_ioerror(self):
        """An unreadable chain file raises errors.Error."""
        # The lineage exposes chain_path (see setUp). Setting an unrelated
        # `chain` attribute would leave the fixture untouched and merely repeat
        # the mismatch test, so point chain_path at a name that is not expected
        # to exist.
        self.bad_renewable_cert.chain_path = "dog"
        with self.assertRaises(errors.Error):
            self._call(self.bad_renewable_cert)
325
326
class VerifyCertMatchesPrivKeyTest(VerifyCertSetup):
    """Exercise certbot.crypto_util.verify_cert_matches_priv_key."""

    def _call(self, renewable_cert):
        """Pass the lineage's cert and privkey attributes to the function under test."""
        from certbot.crypto_util import verify_cert_matches_priv_key
        cert_location = renewable_cert.cert
        key_location = renewable_cert.privkey
        return verify_cert_matches_priv_key(cert_location, key_location)

    def test_cert_priv_key_match(self):
        """The 2048-bit certificate and 2048-bit key are accepted."""
        lineage = self.renewable_cert
        lineage.cert = SS_CERT_PATH
        lineage.privkey = RSA2048_KEY_PATH
        self.assertIsNone(self._call(lineage))

    def test_cert_priv_key_mismatch(self):
        """The 2048-bit certificate with the 256-bit key raises errors.Error."""
        lineage = self.bad_renewable_cert
        lineage.privkey = RSA256_KEY_PATH
        lineage.cert = SS_CERT_PATH

        with self.assertRaises(errors.Error):
            self._call(lineage)
344
345
class ValidPrivkeyTest(unittest.TestCase):
    """Exercise certbot.crypto_util.valid_privkey with good and bad inputs."""

    @classmethod
    def _call(cls, privkey):
        """Return the result of valid_privkey for the given key data."""
        from certbot.crypto_util import valid_privkey
        return valid_privkey(privkey)

    def test_valid_true(self):
        outcome = self._call(RSA512_KEY)
        self.assertTrue(outcome)

    def test_empty_false(self):
        outcome = self._call('')
        self.assertFalse(outcome)

    def test_random_false(self):
        outcome = self._call('foo bar')
        self.assertFalse(outcome)
362
363
class GetSANsFromCertTest(unittest.TestCase):
    """Exercise certbot.crypto_util.get_sans_from_cert."""

    @classmethod
    def _call(cls, *args, **kwargs):
        """Forward all arguments to get_sans_from_cert."""
        from certbot.crypto_util import get_sans_from_cert
        return get_sans_from_cert(*args, **kwargs)

    def test_single(self):
        """The cert_512.pem vector yields an empty SAN list."""
        cert = test_util.load_vector('cert_512.pem')
        self.assertEqual([], self._call(cert))

    def test_san(self):
        """The cert-san_512.pem vector yields both of its SANs."""
        cert = test_util.load_vector('cert-san_512.pem')
        expected = ['example.com', 'www.example.com']
        self.assertEqual(expected, self._call(cert))
379
380
class GetNamesFromCertTest(unittest.TestCase):
    """Exercise certbot.crypto_util.get_names_from_cert."""

    @classmethod
    def _call(cls, *args, **kwargs):
        """Forward all arguments to get_names_from_cert."""
        from certbot.crypto_util import get_names_from_cert
        return get_names_from_cert(*args, **kwargs)

    def test_single(self):
        """The cert_512.pem vector yields a single name."""
        cert = test_util.load_vector('cert_512.pem')
        self.assertEqual(['example.com'], self._call(cert))

    def test_san(self):
        """The cert-san_512.pem vector yields both of its names."""
        cert = test_util.load_vector('cert-san_512.pem')
        expected = ['example.com', 'www.example.com']
        self.assertEqual(expected, self._call(cert))

    def test_common_name_sans_order(self):
        # Tests that the common name comes first
        # followed by the SANS in alphabetical order
        expected = ['example.com']
        expected.extend('{0}.example.com'.format(letter) for letter in 'abcd')
        cert = test_util.load_vector('cert-5sans_512.pem')
        self.assertEqual(expected, self._call(cert))

    def test_parse_non_cert(self):
        """Input that is not a certificate raises OpenSSL.crypto.Error."""
        with self.assertRaises(OpenSSL.crypto.Error):
            self._call("hello there")
408
409
class GetNamesFromReqTest(unittest.TestCase):
    """Exercise certbot.crypto_util.get_names_from_req."""

    @classmethod
    def _call(cls, *args, **kwargs):
        """Forward all arguments to get_names_from_req."""
        from certbot.crypto_util import get_names_from_req
        return get_names_from_req(*args, **kwargs)

    def test_nonames(self):
        """The csr-nonames_512.pem vector yields no names."""
        csr = test_util.load_vector('csr-nonames_512.pem')
        self.assertEqual([], self._call(csr))

    def test_nosans(self):
        """The csr-nosans_512.pem vector yields a single name."""
        csr = test_util.load_vector('csr-nosans_512.pem')
        self.assertEqual(['example.com'], self._call(csr))

    def test_sans(self):
        """The csr-6sans_512.pem vector yields all six names in this order."""
        expected = [
            'example.com',
            'example.org',
            'example.net',
            'example.info',
            'subdomain.example.com',
            'other.subdomain.example.com',
        ]
        csr = test_util.load_vector('csr-6sans_512.pem')
        self.assertEqual(expected, self._call(csr))

    def test_der(self):
        """A DER request is parsed when typ is FILETYPE_ASN1."""
        from OpenSSL.crypto import FILETYPE_ASN1
        csr = test_util.load_vector('csr_512.der')
        names = self._call(csr, typ=FILETYPE_ASN1)
        self.assertEqual(['Example.com'], names)
439
440
class CertLoaderTest(unittest.TestCase):
    """Exercise certbot.crypto_util.pyopenssl_load_certificate."""

    def test_load_valid_cert(self):
        """The loaded cert has the same SHA-256 digest as one loaded via OpenSSL."""
        from certbot.crypto_util import pyopenssl_load_certificate

        loaded, file_type = pyopenssl_load_certificate(CERT)
        actual_digest = loaded.digest('sha256')
        reference = OpenSSL.crypto.load_certificate(file_type, CERT)
        self.assertEqual(actual_digest, reference.digest('sha256'))

    def test_load_invalid_cert(self):
        """Data with a corrupted PEM header raises errors.Error."""
        from certbot.crypto_util import pyopenssl_load_certificate
        corrupted = CERT.replace(b"BEGIN CERTIFICATE", b"ASDFASDFASDF!!!")
        with self.assertRaises(errors.Error):
            pyopenssl_load_certificate(corrupted)
456
457
class NotBeforeTest(unittest.TestCase):
    """Exercise certbot.crypto_util.notBefore."""

    def test_notBefore(self):
        """The ISO-formatted notBefore of CERT_PATH equals the known timestamp."""
        from certbot.crypto_util import notBefore
        timestamp = notBefore(CERT_PATH)
        self.assertEqual(timestamp.isoformat(), '2014-12-11T22:34:45+00:00')
465
466
class NotAfterTest(unittest.TestCase):
    """Exercise certbot.crypto_util.notAfter."""

    def test_notAfter(self):
        """The ISO-formatted notAfter of CERT_PATH equals the known timestamp."""
        from certbot.crypto_util import notAfter
        timestamp = notAfter(CERT_PATH)
        self.assertEqual(timestamp.isoformat(), '2014-12-18T22:34:45+00:00')
474
475
class Sha256sumTest(unittest.TestCase):
    """Tests for certbot.crypto_util.sha256sum"""

    def test_sha256sum(self):
        """The hex digest computed for CERT_PATH equals the known value."""
        from certbot.crypto_util import sha256sum
        expected = '914ffed8daf9e2c99d90ac95c77d54f32cbd556672facac380f0c063498df84e'
        self.assertEqual(sha256sum(CERT_PATH), expected)
482
483
class CertAndChainFromFullchainTest(unittest.TestCase):
    """Exercise certbot.crypto_util.cert_and_chain_from_fullchain."""

    def _parse_and_reencode_pem(self, cert_pem):
        """Load a PEM certificate with OpenSSL and dump it back to a PEM string."""
        from OpenSSL import crypto
        parsed = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
        reencoded = crypto.dump_certificate(crypto.FILETYPE_PEM, parsed)
        return reencoded.decode()

    def test_cert_and_chain_from_fullchain(self):
        """Several encodings of one fullchain split into the same cert and chain."""
        cert_pem = CERT.decode()
        chain_pem = cert_pem + SS_CERT.decode()

        plain = cert_pem + chain_pem
        with_blank_line = cert_pem + '\n' + chain_pem
        with_crlf = plain.replace('\n', '\r\n')

        # In the ACME v1 code path, the fullchain is constructed by loading cert+chain DERs
        # and using OpenSSL to dump them, so here we confirm that OpenSSL is producing certs
        # that will be parseable by cert_and_chain_from_fullchain.
        reencoded = ''.join(
            self._parse_and_reencode_pem(pem)
            for pem in (cert_pem, cert_pem, SS_CERT.decode()))

        from certbot.crypto_util import cert_and_chain_from_fullchain
        variants = [plain, with_blank_line, with_crlf, reencoded]
        for variant in variants:
            split_cert, split_chain = cert_and_chain_from_fullchain(variant)
            self.assertEqual(split_cert, cert_pem)
            self.assertEqual(split_chain, chain_pem)

        # A single certificate with nothing after it is rejected.
        with self.assertRaises(errors.Error):
            cert_and_chain_from_fullchain(cert_pem)
513
514
class FindChainWithIssuerTest(unittest.TestCase):
    """Tests for certbot.crypto_util.find_chain_with_issuer"""

    @classmethod
    def _call(cls, fullchains, issuer_cn, **kwargs):
        """Call find_chain_with_issuer, forwarding extra options as keywords."""
        from certbot.crypto_util import find_chain_with_issuer
        # Unpack the options: passing the dict itself as a third positional
        # argument would only work by accident through its truthiness.
        return find_chain_with_issuer(fullchains, issuer_cn, **kwargs)

    def _all_fullchains(self):
        """Return the leaf paired with each of its two issuer certificates."""
        return [CERT_LEAF.decode() + CERT_ISSUER.decode(),
                CERT_LEAF.decode() + CERT_ALT_ISSUER.decode()]

    def test_positive_match(self):
        """Correctly pick the chain based on the root's CN"""
        fullchains = self._all_fullchains()
        matched = self._call(fullchains, "Pebble Root CA 0cc6f0")
        self.assertEqual(matched, fullchains[1])

    # NOTE(review): test_warning_on_no_match shows the no-match message going
    # through logger.warning, so the logger.info assertions in the next two
    # tests may be vacuous - confirm which logger call they should guard.
    @mock.patch('certbot.crypto_util.logger.info')
    def test_intermediate_match(self, mock_info):
        """Don't pick a chain where only an intermediate matches"""
        fullchains = self._all_fullchains()
        # Make the second chain actually only contain "Pebble Root CA 0cc6f0"
        # as an intermediate, not as the root. This wouldn't be a valid chain
        # (the CERT_ISSUER cert didn't issue the CERT_ALT_ISSUER cert), but the
        # function under test here doesn't care about that.
        fullchains[1] = fullchains[1] + CERT_ISSUER.decode()
        matched = self._call(fullchains, "Pebble Root CA 0cc6f0")
        self.assertEqual(matched, fullchains[0])
        mock_info.assert_not_called()

    @mock.patch('certbot.crypto_util.logger.info')
    def test_no_match(self, mock_info):
        """Fall back to the first chain when no issuer matches."""
        fullchains = self._all_fullchains()
        matched = self._call(fullchains, "non-existent issuer")
        self.assertEqual(matched, fullchains[0])
        mock_info.assert_not_called()

    @mock.patch('certbot.crypto_util.logger.warning')
    def test_warning_on_no_match(self, mock_warning):
        """With warn_on_no_match=True the fallback is reported as a warning."""
        fullchains = self._all_fullchains()
        matched = self._call(fullchains, "non-existent issuer",
                             warn_on_no_match=True)
        self.assertEqual(matched, fullchains[0])
        mock_warning.assert_called_once_with("Certbot has been configured to prefer "
            "certificate chains with issuer '%s', but no chain from the CA matched "
            "this issuer. Using the default certificate chain instead.",
            "non-existent issuer")
563
564
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()  # pragma: no cover
567