1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7from cryptography import utils
8from cryptography.exceptions import (
9    InvalidSignature,
10    UnsupportedAlgorithm,
11    _Reasons,
12)
13from cryptography.hazmat.backends.openssl.utils import (
14    _calculate_digest_and_algorithm,
15    _check_not_prehashed,
16    _warn_sign_verify_deprecated,
17)
18from cryptography.hazmat.primitives import hashes
19from cryptography.hazmat.primitives.asymmetric import (
20    AsymmetricSignatureContext,
21    AsymmetricVerificationContext,
22    rsa,
23)
24from cryptography.hazmat.primitives.asymmetric.padding import (
25    AsymmetricPadding,
26    MGF1,
27    OAEP,
28    PKCS1v15,
29    PSS,
30    calculate_max_pss_salt_length,
31)
32from cryptography.hazmat.primitives.asymmetric.rsa import (
33    RSAPrivateKeyWithSerialization,
34    RSAPublicKeyWithSerialization,
35)
36
37
def _get_rsa_pss_salt_length(pss, key, hash_algorithm):
    """Resolve the concrete PSS salt length to pass to OpenSSL.

    ``MGF1.MAX_LENGTH`` and ``PSS.MAX_LENGTH`` are sentinels requesting the
    largest salt the key/digest pair allows; they are converted into a
    number with ``calculate_max_pss_salt_length``. Any other value stored
    on the padding object is returned unchanged.
    """
    requested = pss._salt_length
    if requested is not MGF1.MAX_LENGTH and requested is not PSS.MAX_LENGTH:
        return requested
    return calculate_max_pss_salt_length(key, hash_algorithm)
45
46
def _enc_dec_rsa(backend, key, data, padding):
    """Validate ``padding`` and run an RSA encrypt/decrypt of ``data``.

    Only ``PKCS1v15`` and ``OAEP`` are accepted. OAEP additionally requires
    an ``MGF1`` mask generation function and a padding/hash combination the
    backend reports as supported. The operation itself is delegated to
    ``_enc_dec_rsa_pkey_ctx``, which picks encrypt vs. decrypt by key type.

    :raises TypeError: if ``padding`` is not an ``AsymmetricPadding``.
    :raises UnsupportedAlgorithm: for any other padding type, a non-MGF1
        OAEP mask function, or an OAEP combination the backend rejects.
    """
    if not isinstance(padding, AsymmetricPadding):
        raise TypeError("Padding must be an instance of AsymmetricPadding.")

    if isinstance(padding, PKCS1v15):
        return _enc_dec_rsa_pkey_ctx(
            backend, key, data, backend._lib.RSA_PKCS1_PADDING, padding
        )

    # Anything that is neither PKCS1v15 nor OAEP is rejected outright.
    if not isinstance(padding, OAEP):
        raise UnsupportedAlgorithm(
            "{} is not supported by this backend.".format(padding.name),
            _Reasons.UNSUPPORTED_PADDING,
        )

    oaep_enum = backend._lib.RSA_PKCS1_OAEP_PADDING

    if not isinstance(padding._mgf, MGF1):
        raise UnsupportedAlgorithm(
            "Only MGF1 is supported by this backend.",
            _Reasons.UNSUPPORTED_MGF,
        )

    if not backend.rsa_padding_supported(padding):
        raise UnsupportedAlgorithm(
            "This combination of padding and hash algorithm is not "
            "supported by this backend.",
            _Reasons.UNSUPPORTED_PADDING,
        )

    return _enc_dec_rsa_pkey_ctx(backend, key, data, oaep_enum, padding)
76
77
def _enc_dec_rsa_pkey_ctx(backend, key, data, padding_enum, padding):
    """Run an RSA encryption or decryption through an EVP_PKEY_CTX.

    An ``_RSAPublicKey`` selects ``EVP_PKEY_encrypt``; any other key selects
    ``EVP_PKEY_decrypt``.

    :param backend: the OpenSSL backend providing ``_lib``/``_ffi``.
    :param key: ``_RSAPublicKey`` (encrypt) or private key (decrypt); its
        ``_evp_pkey`` is used to build the context.
    :param data: input bytes for the operation.
    :param padding_enum: OpenSSL padding constant chosen by the caller.
    :param padding: padding object; for ``OAEP`` its MGF1/OAEP digests and
        non-empty label are applied to the context.
    :returns: the first ``outlen`` bytes written by OpenSSL into a buffer
        of ``EVP_PKEY_size`` bytes.
    :raises ValueError: if the encrypt/decrypt call itself fails.
    """
    if isinstance(key, _RSAPublicKey):
        init = backend._lib.EVP_PKEY_encrypt_init
        crypt = backend._lib.EVP_PKEY_encrypt
    else:
        init = backend._lib.EVP_PKEY_decrypt_init
        crypt = backend._lib.EVP_PKEY_decrypt

    pkey_ctx = backend._lib.EVP_PKEY_CTX_new(key._evp_pkey, backend._ffi.NULL)
    backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
    # Attach EVP_PKEY_CTX_free so the context is released with the cdata.
    pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
    res = init(pkey_ctx)
    backend.openssl_assert(res == 1)
    res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, padding_enum)
    backend.openssl_assert(res > 0)
    # Output buffer is sized to the key's maximum output length.
    buf_size = backend._lib.EVP_PKEY_size(key._evp_pkey)
    backend.openssl_assert(buf_size > 0)
    # The OAEP/MGF1 digests are only configured when the bound OpenSSL
    # exposes these controls (Cryptography_HAS_RSA_OAEP_MD); presumably
    # OpenSSL's defaults apply otherwise — TODO confirm.
    if isinstance(padding, OAEP) and backend._lib.Cryptography_HAS_RSA_OAEP_MD:
        mgf1_md = backend._evp_md_non_null_from_algorithm(
            padding._mgf._algorithm
        )
        res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1_md)
        backend.openssl_assert(res > 0)
        oaep_md = backend._evp_md_non_null_from_algorithm(padding._algorithm)
        res = backend._lib.EVP_PKEY_CTX_set_rsa_oaep_md(pkey_ctx, oaep_md)
        backend.openssl_assert(res > 0)

    # A None or empty OAEP label is simply not set on the context.
    if (
        isinstance(padding, OAEP)
        and padding._label is not None
        and len(padding._label) > 0
    ):
        # set0_rsa_oaep_label takes ownership of the char * so we need to
        # copy it into some new memory
        labelptr = backend._lib.OPENSSL_malloc(len(padding._label))
        backend.openssl_assert(labelptr != backend._ffi.NULL)
        backend._ffi.memmove(labelptr, padding._label, len(padding._label))
        res = backend._lib.EVP_PKEY_CTX_set0_rsa_oaep_label(
            pkey_ctx, labelptr, len(padding._label)
        )
        backend.openssl_assert(res == 1)

    outlen = backend._ffi.new("size_t *", buf_size)
    buf = backend._ffi.new("unsigned char[]", buf_size)
    # Everything from this line onwards is written with the goal of being as
    # constant-time as is practical given the constraints of Python and our
    # API. See Bleichenbacher's '98 attack on RSA, and its many many variants.
    # As such, you should not attempt to change this (particularly to "clean it
    # up") without understanding why it was written this way (see
    # Chesterton's Fence), and without measuring to verify you have not
    # introduced observable time differences.
    res = crypt(pkey_ctx, buf, outlen, data, len(data))
    # The slice and the error-queue clear run regardless of ``res`` so the
    # success and failure paths do the same work before the branch below.
    resbuf = backend._ffi.buffer(buf)[: outlen[0]]
    backend._lib.ERR_clear_error()
    if res <= 0:
        raise ValueError("Encryption/decryption failed.")
    return resbuf
135
136
def _rsa_sig_determine_padding(backend, key, padding, algorithm):
    """Validate a signature padding and map it to its OpenSSL constant.

    :param backend: the OpenSSL backend.
    :param key: RSA key whose ``_evp_pkey`` size bounds the PSS digest.
    :param padding: ``PKCS1v15`` or ``PSS`` instance.
    :param algorithm: hash algorithm; may be None for ``PKCS1v15`` but must
        be a ``hashes.HashAlgorithm`` for ``PSS``.
    :returns: ``RSA_PKCS1_PADDING`` or ``RSA_PKCS1_PSS_PADDING``.
    :raises TypeError: if ``padding`` is not an ``AsymmetricPadding``, or
        PSS is used without a ``hashes.HashAlgorithm``.
    :raises UnsupportedAlgorithm: for other paddings or a non-MGF1 PSS mask.
    :raises ValueError: if the digest cannot fit a PSS block for this key.
    """
    if not isinstance(padding, AsymmetricPadding):
        raise TypeError("Expected provider of AsymmetricPadding.")

    key_bytes = backend._lib.EVP_PKEY_size(key._evp_pkey)
    backend.openssl_assert(key_bytes > 0)

    if isinstance(padding, PKCS1v15):
        # The hash algorithm plays no part here and may even be None.
        return backend._lib.RSA_PKCS1_PADDING

    if not isinstance(padding, PSS):
        raise UnsupportedAlgorithm(
            "{} is not supported by this backend.".format(padding.name),
            _Reasons.UNSUPPORTED_PADDING,
        )

    if not isinstance(padding._mgf, MGF1):
        raise UnsupportedAlgorithm(
            "Only MGF1 is supported by this backend.",
            _Reasons.UNSUPPORTED_MGF,
        )

    # Unlike PKCS1v15, PSS cannot work without a real hash algorithm.
    if not isinstance(algorithm, hashes.HashAlgorithm):
        raise TypeError("Expected instance of hashes.HashAlgorithm.")

    # The digest plus two bytes must fit in the key size in bytes; the salt
    # length is checked later.
    if algorithm.digest_size + 2 > key_bytes:
        raise ValueError(
            "Digest too large for key size. Use a larger "
            "key or different digest."
        )

    return backend._lib.RSA_PKCS1_PSS_PADDING
174
175
# Hash algorithm can be absent (None) to initialize the context without setting
# any message digest algorithm. This is currently only valid for the PKCS1v15
# padding type, where it means that the signature data is encoded/decoded
# as provided, without being wrapped in a DigestInfo structure.
def _rsa_sig_setup(backend, padding, algorithm, key, init_func):
    """Create and configure an EVP_PKEY_CTX for an RSA signature operation.

    :param backend: the OpenSSL backend.
    :param padding: ``PKCS1v15`` or ``PSS`` padding instance.
    :param algorithm: hash algorithm, or None (see the comment above this
        function).
    :param key: RSA key whose ``_evp_pkey`` the context is built on.
    :param init_func: the ``EVP_PKEY_*_init`` function for the operation
        (sign, verify or verify_recover in this module).
    :returns: the configured context; ``EVP_PKEY_CTX_free`` is attached via
        ``ffi.gc``.
    :raises UnsupportedAlgorithm: if OpenSSL rejects the digest or padding.
    :raises TypeError, ValueError: propagated from
        ``_rsa_sig_determine_padding``.
    """
    padding_enum = _rsa_sig_determine_padding(backend, key, padding, algorithm)
    pkey_ctx = backend._lib.EVP_PKEY_CTX_new(key._evp_pkey, backend._ffi.NULL)
    backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
    pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
    res = init_func(pkey_ctx)
    backend.openssl_assert(res == 1)
    if algorithm is not None:
        evp_md = backend._evp_md_non_null_from_algorithm(algorithm)
        res = backend._lib.EVP_PKEY_CTX_set_signature_md(pkey_ctx, evp_md)
        # EVP_PKEY_CTX ctrl wrappers signal failure with 0 *or* a negative
        # value (-2 means "operation not supported"). Treat every
        # non-positive result as an error rather than carrying on with no
        # digest configured; this matches the padding check below.
        if res <= 0:
            backend._consume_errors()
            raise UnsupportedAlgorithm(
                "{} is not supported by this backend for RSA signing.".format(
                    algorithm.name
                ),
                _Reasons.UNSUPPORTED_HASH,
            )
    res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, padding_enum)
    if res <= 0:
        backend._consume_errors()
        raise UnsupportedAlgorithm(
            "{} is not supported for the RSA signature operation.".format(
                padding.name
            ),
            _Reasons.UNSUPPORTED_PADDING,
        )
    if isinstance(padding, PSS):
        # Salt length sentinels (MAX_LENGTH) are resolved to a number here.
        res = backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(
            pkey_ctx, _get_rsa_pss_salt_length(padding, key, algorithm)
        )
        backend.openssl_assert(res > 0)

        mgf1_md = backend._evp_md_non_null_from_algorithm(
            padding._mgf._algorithm
        )
        res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1_md)
        backend.openssl_assert(res > 0)

    return pkey_ctx
220
221
def _rsa_sig_sign(backend, padding, algorithm, private_key, data):
    """Sign ``data`` with ``private_key`` and return the signature bytes.

    ``data`` is handed directly to ``EVP_PKEY_sign``; callers in this module
    pass an already-computed digest.

    :raises ValueError: if the signing call fails (typically the digest or
        PSS salt is too long for the key), carrying the OpenSSL error text.
    """
    ffi = backend._ffi
    lib = backend._lib
    pkey_ctx = _rsa_sig_setup(
        backend, padding, algorithm, private_key, lib.EVP_PKEY_sign_init
    )

    # A first call with a NULL output buffer reports the required length.
    sig_len = ffi.new("size_t *")
    res = lib.EVP_PKEY_sign(pkey_ctx, ffi.NULL, sig_len, data, len(data))
    backend.openssl_assert(res == 1)

    sig_buf = ffi.new("unsigned char[]", sig_len[0])
    res = lib.EVP_PKEY_sign(pkey_ctx, sig_buf, sig_len, data, len(data))
    if res == 1:
        return ffi.buffer(sig_buf)[:]

    errors = backend._consume_errors_with_text()
    raise ValueError(
        "Digest or salt length too long for key size. Use a larger key "
        "or shorter salt length if you are specifying a PSS salt",
        errors,
    )
246
247
def _rsa_sig_verify(backend, padding, algorithm, public_key, signature, data):
    """Check ``signature`` over ``data`` with ``public_key``.

    Returns None on success.

    :raises InvalidSignature: if OpenSSL reports a mismatch (result 0).
    """
    lib = backend._lib
    pkey_ctx = _rsa_sig_setup(
        backend, padding, algorithm, public_key, lib.EVP_PKEY_verify_init
    )
    outcome = lib.EVP_PKEY_verify(
        pkey_ctx, signature, len(signature), data, len(data)
    )
    # Negative results mean an internal error rather than a bad signature,
    # so they must fail loudly instead of being reported as InvalidSignature.
    backend.openssl_assert(outcome >= 0)
    if not outcome:
        backend._consume_errors()
        raise InvalidSignature
266
267
def _rsa_sig_recover(backend, padding, algorithm, public_key, signature):
    """Recover the data embedded in an RSA ``signature`` using ``public_key``.

    The context is prepared by ``_rsa_sig_setup`` with
    ``EVP_PKEY_verify_recover_init``; ``algorithm`` is forwarded unchanged.

    :returns: the first ``buflen`` bytes written by
        ``EVP_PKEY_verify_recover`` into a buffer of ``EVP_PKEY_size`` bytes.
    :raises InvalidSignature: if ``EVP_PKEY_verify_recover`` does not
        return 1.
    """
    pkey_ctx = _rsa_sig_setup(
        backend,
        padding,
        algorithm,
        public_key,
        backend._lib.EVP_PKEY_verify_recover_init,
    )

    # Attempt to keep the rest of the code in this function as constant-time
    # as possible. See the comment in _enc_dec_rsa_pkey_ctx. Note that the
    # outlen parameter is used even though its value may be undefined in the
    # error case. Due to the tolerant nature of Python slicing this does not
    # trigger any exceptions.
    maxlen = backend._lib.EVP_PKEY_size(public_key._evp_pkey)
    backend.openssl_assert(maxlen > 0)
    buf = backend._ffi.new("unsigned char[]", maxlen)
    buflen = backend._ffi.new("size_t *", maxlen)
    res = backend._lib.EVP_PKEY_verify_recover(
        pkey_ctx, buf, buflen, signature, len(signature)
    )
    # Slice and clear errors unconditionally, before branching on ``res``.
    resbuf = backend._ffi.buffer(buf)[: buflen[0]]
    backend._lib.ERR_clear_error()
    # Assume that all parameter errors are handled during the setup phase and
    # any error here is due to invalid signature.
    if res != 1:
        raise InvalidSignature
    return resbuf
296
297
@utils.register_interface(AsymmetricSignatureContext)
class _RSASignatureContext(object):
    """Incremental RSA signing context returned by the deprecated signer().

    Bytes passed to ``update`` are hashed with ``algorithm``; ``finalize``
    signs the resulting digest with the private key.
    """

    def __init__(self, backend, private_key, padding, algorithm):
        # _rsa_sig_setup validates the padding again when signing; checking
        # here as well keeps the API contract of failing at construction
        # time for invalid padding/algorithm combinations.
        _rsa_sig_determine_padding(backend, private_key, padding, algorithm)

        self._backend = backend
        self._private_key = private_key
        self._padding = padding
        self._algorithm = algorithm
        self._hash_ctx = hashes.Hash(algorithm, backend)

    def update(self, data):
        """Feed more message bytes into the running hash."""
        self._hash_ctx.update(data)

    def finalize(self):
        """Finish hashing and return the signature over the digest."""
        digest = self._hash_ctx.finalize()
        return _rsa_sig_sign(
            self._backend,
            self._padding,
            self._algorithm,
            self._private_key,
            digest,
        )
323
324
@utils.register_interface(AsymmetricVerificationContext)
class _RSAVerificationContext(object):
    """Incremental RSA verification context returned by verifier().

    Bytes passed to ``update`` are hashed with ``algorithm``; ``verify``
    checks ``signature`` against the resulting digest and raises
    ``InvalidSignature`` on mismatch.
    """

    def __init__(self, backend, public_key, signature, padding, algorithm):
        self._backend = backend
        self._public_key = public_key
        self._signature = signature
        self._padding = padding
        # We now call _rsa_sig_determine_padding in _rsa_sig_setup. However
        # we need to make a pointless call to it here so we maintain the
        # API of erroring on init with this context if the values are invalid.
        _rsa_sig_determine_padding(backend, public_key, padding, algorithm)

        self._algorithm = algorithm
        self._hash_ctx = hashes.Hash(self._algorithm, self._backend)

    def update(self, data):
        """Feed more message bytes into the running hash."""
        self._hash_ctx.update(data)

    def verify(self):
        """Finish hashing and verify the stored signature over the digest."""
        return _rsa_sig_verify(
            self._backend,
            self._padding,
            self._algorithm,
            self._public_key,
            self._signature,
            self._hash_ctx.finalize(),
        )
353
354
@utils.register_interface(RSAPrivateKeyWithSerialization)
class _RSAPrivateKey(object):
    """OpenSSL-backed RSA private key.

    Wraps an ``RSA *`` (``rsa_cdata``) and the matching ``EVP_PKEY *``
    (``evp_pkey``), and provides signing, decryption, public-key derivation
    and serialization on top of them.
    """

    def __init__(self, backend, rsa_cdata, evp_pkey):
        ffi = backend._ffi
        lib = backend._lib

        if lib.RSA_check_key(rsa_cdata) != 1:
            errors = backend._consume_errors_with_text()
            raise ValueError("Invalid private key", errors)

        # Blinding is on by default in many versions of OpenSSL, but it is
        # enabled explicitly here to be conservative.
        backend.openssl_assert(lib.RSA_blinding_on(rsa_cdata, ffi.NULL) == 1)

        self._backend = backend
        self._rsa_cdata = rsa_cdata
        self._evp_pkey = evp_pkey

        # Key size is the bit length of the modulus n.
        modulus = ffi.new("BIGNUM **")
        lib.RSA_get0_key(rsa_cdata, modulus, ffi.NULL, ffi.NULL)
        backend.openssl_assert(modulus[0] != ffi.NULL)
        self._key_size = lib.BN_num_bits(modulus[0])

    key_size = utils.read_only_property("_key_size")

    def signer(self, padding, algorithm):
        """Return a deprecated incremental signing context."""
        _warn_sign_verify_deprecated()
        _check_not_prehashed(algorithm)
        return _RSASignatureContext(self._backend, self, padding, algorithm)

    def decrypt(self, ciphertext, padding):
        """Decrypt ``ciphertext``, which must be exactly key-size bytes."""
        expected_len = (self.key_size + 7) // 8
        if len(ciphertext) != expected_len:
            raise ValueError("Ciphertext length must be equal to key size.")
        return _enc_dec_rsa(self._backend, self, ciphertext, padding)

    def public_key(self):
        """Return the ``_RSAPublicKey`` corresponding to this key."""
        backend = self._backend
        rsa_pub = backend._lib.RSAPublicKey_dup(self._rsa_cdata)
        backend.openssl_assert(rsa_pub != backend._ffi.NULL)
        rsa_pub = backend._ffi.gc(rsa_pub, backend._lib.RSA_free)
        return _RSAPublicKey(
            backend, rsa_pub, backend._rsa_cdata_to_evp_pkey(rsa_pub)
        )

    def private_numbers(self):
        """Return this key's components as ``rsa.RSAPrivateNumbers``."""
        backend = self._backend
        ffi, lib = backend._ffi, backend._lib
        n, e, d, p, q, dmp1, dmq1, iqmp = (
            ffi.new("BIGNUM **") for _ in range(8)
        )

        def _assert_all_set(*ptrs):
            # Each getter must have filled in every requested BIGNUM.
            for ptr in ptrs:
                backend.openssl_assert(ptr[0] != ffi.NULL)

        lib.RSA_get0_key(self._rsa_cdata, n, e, d)
        _assert_all_set(n, e, d)
        lib.RSA_get0_factors(self._rsa_cdata, p, q)
        _assert_all_set(p, q)
        lib.RSA_get0_crt_params(self._rsa_cdata, dmp1, dmq1, iqmp)
        _assert_all_set(dmp1, dmq1, iqmp)

        to_int = backend._bn_to_int
        return rsa.RSAPrivateNumbers(
            p=to_int(p[0]),
            q=to_int(q[0]),
            d=to_int(d[0]),
            dmp1=to_int(dmp1[0]),
            dmq1=to_int(dmq1[0]),
            iqmp=to_int(iqmp[0]),
            public_numbers=rsa.RSAPublicNumbers(
                e=to_int(e[0]),
                n=to_int(n[0]),
            ),
        )

    def private_bytes(self, encoding, format, encryption_algorithm):
        """Serialize this key via the backend's private-key encoder."""
        backend = self._backend
        return backend._private_key_bytes(
            encoding,
            format,
            encryption_algorithm,
            self,
            self._evp_pkey,
            self._rsa_cdata,
        )

    def sign(self, data, padding, algorithm):
        """Hash ``data`` (unless prehashed) and return its signature."""
        digest, hash_alg = _calculate_digest_and_algorithm(
            self._backend, data, algorithm
        )
        return _rsa_sig_sign(self._backend, padding, hash_alg, self, digest)
453
454
@utils.register_interface(RSAPublicKeyWithSerialization)
class _RSAPublicKey(object):
    """OpenSSL-backed RSA public key.

    Holds the ``RSA *`` (``rsa_cdata``) and the corresponding ``EVP_PKEY *``
    (``evp_pkey``), and provides encryption, signature verification and
    recovery, and serialization.
    """

    def __init__(self, backend, rsa_cdata, evp_pkey):
        self._backend = backend
        self._rsa_cdata = rsa_cdata
        self._evp_pkey = evp_pkey

        # Key size is the bit length of the modulus n.
        ffi = backend._ffi
        modulus = ffi.new("BIGNUM **")
        backend._lib.RSA_get0_key(rsa_cdata, modulus, ffi.NULL, ffi.NULL)
        backend.openssl_assert(modulus[0] != ffi.NULL)
        self._key_size = backend._lib.BN_num_bits(modulus[0])

    key_size = utils.read_only_property("_key_size")

    def verifier(self, signature, padding, algorithm):
        """Return a deprecated incremental verification context."""
        _warn_sign_verify_deprecated()
        utils._check_bytes("signature", signature)
        _check_not_prehashed(algorithm)
        return _RSAVerificationContext(
            self._backend, self, signature, padding, algorithm
        )

    def encrypt(self, plaintext, padding):
        """Encrypt ``plaintext`` with this key using ``padding``."""
        return _enc_dec_rsa(self._backend, self, plaintext, padding)

    def public_numbers(self):
        """Return the modulus and exponent as ``rsa.RSAPublicNumbers``."""
        backend = self._backend
        ffi = backend._ffi
        n_ptr = ffi.new("BIGNUM **")
        e_ptr = ffi.new("BIGNUM **")
        backend._lib.RSA_get0_key(self._rsa_cdata, n_ptr, e_ptr, ffi.NULL)
        backend.openssl_assert(n_ptr[0] != ffi.NULL)
        backend.openssl_assert(e_ptr[0] != ffi.NULL)
        return rsa.RSAPublicNumbers(
            e=backend._bn_to_int(e_ptr[0]),
            n=backend._bn_to_int(n_ptr[0]),
        )

    def public_bytes(self, encoding, format):
        """Serialize this key via the backend's public-key encoder."""
        backend = self._backend
        return backend._public_key_bytes(
            encoding, format, self, self._evp_pkey, self._rsa_cdata
        )

    def verify(self, signature, data, padding, algorithm):
        """Hash ``data`` (unless prehashed) and verify ``signature``."""
        backend = self._backend
        digest, hash_alg = _calculate_digest_and_algorithm(
            backend, data, algorithm
        )
        return _rsa_sig_verify(
            backend, padding, hash_alg, self, signature, digest
        )

    def recover_data_from_signature(self, signature, padding, algorithm):
        """Return the data embedded in ``signature``."""
        _check_not_prehashed(algorithm)
        return _rsa_sig_recover(
            self._backend, padding, algorithm, self, signature
        )
517