1# -*- coding: utf-8 -*-
2#
3#  Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
4#
5#  Licensed under the Apache License, Version 2.0 (the "License");
6#  you may not use this file except in compliance with the License.
7#  You may obtain a copy of the License at
8#
9#      https://www.apache.org/licenses/LICENSE-2.0
10#
11#  Unless required by applicable law or agreed to in writing, software
12#  distributed under the License is distributed on an "AS IS" BASIS,
13#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#  See the License for the specific language governing permissions and
15#  limitations under the License.
16
17"""RSA key generation code.
18
19Create new keys with the newkeys() function. It will give you a PublicKey and a
20PrivateKey object.
21
22Loading and saving keys requires the pyasn1 module. This module is imported as
23late as possible, such that other functionality will remain working in absence
24of pyasn1.
25
26.. note::
27
28    Storing public and private keys via the `pickle` module is possible.
29    However, it is insecure to load a key from an untrusted source.
30    The pickle module is not secure against erroneous or maliciously
31    constructed data. Never unpickle data received from an untrusted
32    or unauthenticated source.
33
34"""
35
36import logging
37from rsa._compat import b
38
39import rsa.prime
40import rsa.pem
41import rsa.common
42import rsa.randnum
43import rsa.core
44
log = logging.getLogger(__name__)
# Public exponent used by default by calculate_keys(), gen_keys() and newkeys().
DEFAULT_EXPONENT = 65537
47
48
class AbstractKey(object):
    """Abstract superclass for private and public keys.

    Stores the modulus ``n`` and the exponent ``e``, dispatches loading and
    saving to format-specific methods, and implements message blinding.
    Subclasses must override ``_load_pkcs1_pem``, ``_load_pkcs1_der``,
    ``_save_pkcs1_pem`` and ``_save_pkcs1_der``.
    """

    __slots__ = ('n', 'e')

    def __init__(self, n, e):
        self.n = n
        self.e = e

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        """Loads a key in PKCS#1 PEM format; must be implemented by subclasses."""
        raise NotImplementedError('%s does not implement PEM loading' % cls.__name__)

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        """Loads a key in PKCS#1 DER format; must be implemented by subclasses."""
        raise NotImplementedError('%s does not implement DER loading' % cls.__name__)

    def _save_pkcs1_pem(self):
        """Saves the key in PKCS#1 PEM format; must be implemented by subclasses."""
        raise NotImplementedError('%s does not implement PEM saving'
                                  % type(self).__name__)

    def _save_pkcs1_der(self):
        """Saves the key in PKCS#1 DER format; must be implemented by subclasses."""
        raise NotImplementedError('%s does not implement DER saving'
                                  % type(self).__name__)

    @classmethod
    def load_pkcs1(cls, keyfile, format='PEM'):
        """Loads a key in PKCS#1 DER or PEM format.

        :param keyfile: contents of a DER- or PEM-encoded file that contains
            the key.
        :param format: the format of the file to load; 'PEM' or 'DER'

        :return: a key object of the class this method is called on
            (e.g. :py:class:`PublicKey` or :py:class:`PrivateKey`).
        :raise ValueError: when ``format`` is neither 'PEM' nor 'DER'.
        """

        methods = {
            'PEM': cls._load_pkcs1_pem,
            'DER': cls._load_pkcs1_der,
        }

        method = cls._assert_format_exists(format, methods)
        return method(keyfile)

    @staticmethod
    def _assert_format_exists(file_format, methods):
        """Returns methods[file_format], or raises ValueError listing the
        supported formats when ``file_format`` is not a key of ``methods``.
        """

        try:
            return methods[file_format]
        except KeyError:
            formats = ', '.join(sorted(methods.keys()))
            raise ValueError('Unsupported format: %r, try one of %s' % (file_format,
                                                                        formats))

    def save_pkcs1(self, format='PEM'):
        """Saves the key in PKCS#1 DER or PEM format.

        :param format: the format to save; 'PEM' or 'DER'
        :returns: the DER- or PEM-encoded key.
        :raise ValueError: when ``format`` is neither 'PEM' nor 'DER'.
        """

        methods = {
            'PEM': self._save_pkcs1_pem,
            'DER': self._save_pkcs1_der,
        }

        method = self._assert_format_exists(format, methods)
        return method()

    def blind(self, message, r):
        """Performs blinding on the message using random number 'r'.

        :param message: the message, as integer, to blind.
        :type message: int
        :param r: the random number to blind with.
        :type r: int
        :return: the blinded message, ``message * r**e mod n``.
        :rtype: int

        The blinding is such that message = unblind(decrypt(blind(encrypt(message))).

        See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
        """

        return (message * pow(r, self.e, self.n)) % self.n

    def unblind(self, blinded, r):
        """Performs unblinding on the message using random number 'r'.

        :param blinded: the blinded message, as integer, to unblind.
        :param r: the random number that was used for blinding.
        :return: the original message, ``blinded * r**-1 mod n``.

        Any ValueError raised by :py:func:`rsa.common.inverse` (when ``r`` has
        no inverse modulo ``n``) propagates to the caller.

        The blinding is such that message = unblind(decrypt(blind(encrypt(message))).

        See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
        """

        return (rsa.common.inverse(r, self.n) * blinded) % self.n
134
135
class PublicKey(AbstractKey):
    """Represents a public RSA key.

    This key is also known as the 'encryption key'. It contains the 'n' and 'e'
    values.

    Supports attributes as well as dictionary-like access. Attribute access is
    faster, though.

    >>> PublicKey(5, 3)
    PublicKey(5, 3)

    >>> key = PublicKey(5, 3)
    >>> key.n
    5
    >>> key['n']
    5
    >>> key.e
    3
    >>> key['e']
    3

    """

    # 'n' and 'e' are already slots on AbstractKey. Re-declaring them here
    # would shadow the base-class slot descriptors (undefined behaviour per the
    # Python data model) and waste space in every instance.
    __slots__ = ()

    def __getitem__(self, key):
        """Dictionary-style access to the key's attributes, e.g. key['n']."""
        return getattr(self, key)

    def __repr__(self):
        return 'PublicKey(%i, %i)' % (self.n, self.e)

    def __getstate__(self):
        """Returns the key as tuple for pickling."""
        return self.n, self.e

    def __setstate__(self, state):
        """Sets the key from tuple."""
        self.n, self.e = state

    def __eq__(self, other):
        # isinstance() also rejects None.
        if not isinstance(other, PublicKey):
            return False

        return self.n == other.n and self.e == other.e

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Defining __eq__ removes the inherited __hash__ on Python 3; hash the
        # same fields __eq__ compares so equal keys hash equally.
        return hash((self.n, self.e))

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        """Loads a key in PKCS#1 DER format.

        :param keyfile: contents of a DER-encoded file that contains the public
            key.
        :return: a PublicKey object

        First let's construct a DER encoded key:

        >>> import base64
        >>> b64der = 'MAwCBQCNGmYtAgMBAAE='
        >>> der = base64.standard_b64decode(b64der)

        This loads the file:

        >>> PublicKey._load_pkcs1_der(der)
        PublicKey(2367317549, 65537)

        """

        from pyasn1.codec.der import decoder
        from rsa.asn1 import AsnPubKey

        (asn_key, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
        return cls(n=int(asn_key['modulus']), e=int(asn_key['publicExponent']))

    def _save_pkcs1_der(self):
        """Saves the public key in PKCS#1 DER format.

        :returns: the DER-encoded public key.
        """

        from pyasn1.codec.der import encoder
        from rsa.asn1 import AsnPubKey

        # Create the ASN object
        asn_key = AsnPubKey()
        asn_key.setComponentByName('modulus', self.n)
        asn_key.setComponentByName('publicExponent', self.e)

        return encoder.encode(asn_key)

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        """Loads a PKCS#1 PEM-encoded public key file.

        The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
        after the "-----END RSA PUBLIC KEY-----" lines are ignored.

        :param keyfile: contents of a PEM-encoded file that contains the public
            key.
        :return: a PublicKey object
        """

        der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
        return cls._load_pkcs1_der(der)

    def _save_pkcs1_pem(self):
        """Saves a PKCS#1 PEM-encoded public key file.

        :return: contents of a PEM-encoded file that contains the public key.
        """

        der = self._save_pkcs1_der()
        return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')

    @classmethod
    def load_pkcs1_openssl_pem(cls, keyfile):
        """Loads a PEM-encoded public key file as written by OpenSSL.

        These files can be recognised in that they start with BEGIN PUBLIC KEY
        rather than BEGIN RSA PUBLIC KEY. They wrap the PKCS#1 public key in a
        header that identifies the key algorithm.

        The contents of the file before the "-----BEGIN PUBLIC KEY-----" and
        after the "-----END PUBLIC KEY-----" lines are ignored.

        :param keyfile: contents of a PEM-encoded file that contains the public
            key, from OpenSSL.
        :return: a PublicKey object
        :raise TypeError: when the key is not an RSA key.
        """

        der = rsa.pem.load_pem(keyfile, 'PUBLIC KEY')
        return cls.load_pkcs1_openssl_der(der)

    @classmethod
    def load_pkcs1_openssl_der(cls, keyfile):
        """Loads a DER-encoded public key file as written by OpenSSL.

        :param keyfile: contents of a DER-encoded file that contains the public
            key, from OpenSSL.
        :return: a PublicKey object
        :raise TypeError: when the algorithm OID in the header is not
            rsaEncryption (1.2.840.113549.1.1.1).

        """

        from rsa.asn1 import OpenSSLPubKey
        from pyasn1.codec.der import decoder
        from pyasn1.type import univ

        (keyinfo, _) = decoder.decode(keyfile, asn1Spec=OpenSSLPubKey())

        if keyinfo['header']['oid'] != univ.ObjectIdentifier('1.2.840.113549.1.1.1'):
            raise TypeError("This is not a DER-encoded OpenSSL-compatible public key")

        # The leading octet of a DER BIT STRING payload is the unused-bits
        # count; the PKCS#1 key follows it. (Assumes OpenSSLPubKey exposes
        # 'key' as raw octets -- see rsa.asn1.)
        return cls._load_pkcs1_der(keyinfo['key'][1:])
293
294
class PrivateKey(AbstractKey):
    """Represents a private RSA key.

    This key is also known as the 'decryption key'. It contains the 'n', 'e',
    'd', 'p', 'q' and other values.

    Supports attributes as well as dictionary-like access. Attribute access is
    faster, though.

    >>> PrivateKey(3247, 65537, 833, 191, 17)
    PrivateKey(3247, 65537, 833, 191, 17)

    exp1, exp2 and coef can be given, but if None or omitted they will be calculated:

    >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287, exp2=4)
    >>> pk.exp1
    55063
    >>> pk.exp2  # this is of course not a correct value, but it is the one we passed.
    4
    >>> pk.coef
    50797

    If you give exp1, exp2 or coef, they will be used as-is:

    >>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8)
    >>> pk.exp1
    6
    >>> pk.exp2
    7
    >>> pk.coef
    8

    """

    # 'n' and 'e' are inherited slots from AbstractKey; only the private
    # components are declared here, so no base-class slot is shadowed.
    __slots__ = ('d', 'p', 'q', 'exp1', 'exp2', 'coef')

    def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None):
        AbstractKey.__init__(self, n, e)
        self.d = d
        self.p = p
        self.q = q

        # CRT values: computed from d, p and q unless explicitly supplied.
        self.exp1 = int(d % (p - 1)) if exp1 is None else exp1
        self.exp2 = int(d % (q - 1)) if exp2 is None else exp2
        self.coef = rsa.common.inverse(q, p) if coef is None else coef

    def __getitem__(self, key):
        """Dictionary-style access to the key's attributes, e.g. key['d']."""
        return getattr(self, key)

    def __repr__(self):
        # '%' with a mapping-like right operand resolves names via __getitem__.
        return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self

    def __getstate__(self):
        """Returns the key as tuple for pickling."""
        return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef

    def __setstate__(self, state):
        """Sets the key from tuple."""
        self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state

    def __eq__(self, other):
        # isinstance() also rejects None.
        if not isinstance(other, PrivateKey):
            return False

        return (self.n == other.n and
                self.e == other.e and
                self.d == other.d and
                self.p == other.p and
                self.q == other.q and
                self.exp1 == other.exp1 and
                self.exp2 == other.exp2 and
                self.coef == other.coef)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Defining __eq__ removes the inherited __hash__ on Python 3; hash the
        # same fields __eq__ compares so equal keys hash equally.
        return hash((self.n, self.e, self.d, self.p, self.q,
                     self.exp1, self.exp2, self.coef))

    def blinded_decrypt(self, encrypted):
        """Decrypts the message using blinding to prevent side-channel attacks.

        :param encrypted: the encrypted message
        :type encrypted: int

        :returns: the decrypted message
        :rtype: int
        """

        blind_r = rsa.randnum.randint(self.n - 1)
        blinded = self.blind(encrypted, blind_r)  # blind before decrypting
        decrypted = rsa.core.decrypt_int(blinded, self.d, self.n)

        return self.unblind(decrypted, blind_r)

    def blinded_encrypt(self, message):
        """Encrypts the message with the private exponent, using blinding to
        prevent side-channel attacks.

        Because the private exponent ``d`` is used, this is the raw RSA
        signing operation rather than public-key encryption.

        :param message: the message to encrypt
        :type message: int

        :returns: the encrypted message
        :rtype: int
        """

        blind_r = rsa.randnum.randint(self.n - 1)
        blinded = self.blind(message, blind_r)  # blind before encrypting
        encrypted = rsa.core.encrypt_int(blinded, self.d, self.n)
        return self.unblind(encrypted, blind_r)

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        """Loads a key in PKCS#1 DER format.

        :param keyfile: contents of a DER-encoded file that contains the private
            key.
        :return: a PrivateKey object
        :raise ValueError: when the structure's version field is not 0.

        First let's construct a DER encoded key:

        >>> import base64
        >>> b64der = 'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
        >>> der = base64.standard_b64decode(b64der)

        This loads the file:

        >>> PrivateKey._load_pkcs1_der(der)
        PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        """

        from pyasn1.codec.der import decoder
        (priv, _) = decoder.decode(keyfile)

        # ASN.1 contents of DER encoded private key:
        #
        # RSAPrivateKey ::= SEQUENCE {
        #     version           Version,
        #     modulus           INTEGER,  -- n
        #     publicExponent    INTEGER,  -- e
        #     privateExponent   INTEGER,  -- d
        #     prime1            INTEGER,  -- p
        #     prime2            INTEGER,  -- q
        #     exponent1         INTEGER,  -- d mod (p-1)
        #     exponent2         INTEGER,  -- d mod (q-1)
        #     coefficient       INTEGER,  -- (inverse of q) mod p
        #     otherPrimeInfos   OtherPrimeInfos OPTIONAL
        # }

        if priv[0] != 0:
            raise ValueError('Unable to read this file, version %s != 0' % priv[0])

        # Fields 1..8 map positionally onto (n, e, d, p, q, exp1, exp2, coef).
        as_ints = tuple(int(x) for x in priv[1:9])
        return cls(*as_ints)

    def _save_pkcs1_der(self):
        """Saves the private key in PKCS#1 DER format.

        :returns: the DER-encoded private key.
        """

        from pyasn1.type import univ, namedtype
        from pyasn1.codec.der import encoder

        class AsnPrivKey(univ.Sequence):
            componentType = namedtype.NamedTypes(
                    namedtype.NamedType('version', univ.Integer()),
                    namedtype.NamedType('modulus', univ.Integer()),
                    namedtype.NamedType('publicExponent', univ.Integer()),
                    namedtype.NamedType('privateExponent', univ.Integer()),
                    namedtype.NamedType('prime1', univ.Integer()),
                    namedtype.NamedType('prime2', univ.Integer()),
                    namedtype.NamedType('exponent1', univ.Integer()),
                    namedtype.NamedType('exponent2', univ.Integer()),
                    namedtype.NamedType('coefficient', univ.Integer()),
            )

        # Create the ASN object
        asn_key = AsnPrivKey()
        asn_key.setComponentByName('version', 0)
        asn_key.setComponentByName('modulus', self.n)
        asn_key.setComponentByName('publicExponent', self.e)
        asn_key.setComponentByName('privateExponent', self.d)
        asn_key.setComponentByName('prime1', self.p)
        asn_key.setComponentByName('prime2', self.q)
        asn_key.setComponentByName('exponent1', self.exp1)
        asn_key.setComponentByName('exponent2', self.exp2)
        asn_key.setComponentByName('coefficient', self.coef)

        return encoder.encode(asn_key)

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        """Loads a PKCS#1 PEM-encoded private key file.

        The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
        after the "-----END RSA PRIVATE KEY-----" lines are ignored.

        :param keyfile: contents of a PEM-encoded file that contains the private
            key.
        :return: a PrivateKey object
        """

        der = rsa.pem.load_pem(keyfile, b('RSA PRIVATE KEY'))
        return cls._load_pkcs1_der(der)

    def _save_pkcs1_pem(self):
        """Saves a PKCS#1 PEM-encoded private key file.

        :return: contents of a PEM-encoded file that contains the private key.
        """

        der = self._save_pkcs1_der()
        return rsa.pem.save_pem(der, b('RSA PRIVATE KEY'))
521
522
def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
    """Returns a tuple of two different primes of nbits bits each.

    The resulting p * q has exactly 2 * nbits bits, and the returned p and q
    will not be equal.

    :param nbits: the number of bits in each of p and q.
    :param getprime_func: the getprime function, defaults to
        :py:func:`rsa.prime.getprime`.

        *Introduced in Python-RSA 3.1*

    :param accurate: whether to enable accurate mode or not.
    :returns: (p, q), where p > q

    >>> (p, q) = find_p_q(128)
    >>> from rsa import common
    >>> common.bit_size(p * q)
    256

    When not in accurate mode, the number of bits can be slightly less

    >>> (p, q) = find_p_q(128, accurate=False)
    >>> from rsa import common
    >>> common.bit_size(p * q) <= 256
    True
    >>> common.bit_size(p * q) > 240
    True

    """

    wanted_bits = 2 * nbits

    # If p and q are too close, n is easy to factor. Give p a few more and
    # q a few fewer bits than nbits.
    delta = nbits // 16
    p_bits, q_bits = nbits + delta, nbits - delta

    # Draw the first candidates.
    log.debug('find_p_q(%i): Finding p', nbits)
    p = getprime_func(p_bits)
    log.debug('find_p_q(%i): Finding q', nbits)
    q = getprime_func(q_bits)

    def acceptable(first, second):
        """Tells whether the candidate pair may be used:

            - the two primes differ
            - their product has exactly the wanted size (only when accurate)
        """
        if first == second:
            return False
        return not accurate or rsa.common.bit_size(first * second) == wanted_bits

    # Replace q, then p, then q, ... until the pair is acceptable.
    replace_q = True
    while not acceptable(p, q):
        if replace_q:
            q = getprime_func(q_bits)
        else:
            p = getprime_func(p_bits)
        replace_q = not replace_q

    # We want p > q as described on
    # http://www.di-mgt.com.au/rsa_alg.html#crt
    # (p != q is guaranteed by the loop above.)
    return (p, q) if p > q else (q, p)
599
600
def calculate_keys_custom_exponent(p, q, exponent):
    """Computes the key exponents for the primes p and q and a chosen public
    exponent, returning them as a tuple (e, d).

    :param p: the first large prime
    :param q: the second large prime
    :param exponent: the exponent for the key; only change this if you know
        what you're doing, as the exponent influences how difficult your
        private key can be cracked. A very common choice for e is 65537.
    :type exponent: int

    :return: tuple (e, d), where e is ``exponent`` and d is its inverse
        modulo (p - 1) * (q - 1).
    :raise ValueError: when ``exponent`` has no inverse modulo
        (p - 1) * (q - 1), or the computed inverse fails the sanity check.
    """

    totient = (p - 1) * (q - 1)

    try:
        private_exponent = rsa.common.inverse(exponent, totient)
    except ValueError:
        message = "e (%d) and phi_n (%d) are not relatively prime" % (exponent, totient)
        raise ValueError(message)

    # Sanity check: e * d must be congruent to 1 modulo phi(n).
    if exponent * private_exponent % totient == 1:
        return exponent, private_exponent

    raise ValueError("e (%d) and d (%d) are not mult. inv. modulo "
                     "phi_n (%d)" % (exponent, private_exponent, totient))
627
628
def calculate_keys(p, q):
    """Computes the key exponents for p and q using the default public
    exponent (DEFAULT_EXPONENT).

    :param p: the first large prime
    :param q: the second large prime

    :return: tuple (e, d) with the encryption and decryption exponents.
    :raise ValueError: see :py:func:`calculate_keys_custom_exponent`.
    """

    e_and_d = calculate_keys_custom_exponent(p, q, exponent=DEFAULT_EXPONENT)
    return e_and_d
640
641
def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT):
    """Generates the raw components of an RSA key of nbits bits.

    Note: this can take a long time, depending on the key size.

    :param nbits: the total number of bits in ``p`` and ``q``. Both ``p`` and
        ``q`` will use ``nbits/2`` bits.
    :param getprime_func: either :py:func:`rsa.prime.getprime` or a function
        with similar signature.
    :param accurate: passed on to :py:func:`find_p_q`; when True, ``p * q``
        has exactly ``nbits`` bits.
    :param exponent: the exponent for the key; only change this if you know
        what you're doing, as the exponent influences how difficult your
        private key can be cracked. A very common choice for e is 65537.
    :type exponent: int

    :return: tuple (p, q, e, d).
    """

    # calculate_keys_custom_exponent raises ValueError when the exponent is
    # not invertible for this (p, q). In that case we draw a fresh pair.
    while True:
        p, q = find_p_q(nbits // 2, getprime_func, accurate)
        try:
            e, d = calculate_keys_custom_exponent(p, q, exponent=exponent)
        except ValueError:
            continue
        return p, q, e, d
668
669
def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT):
    """Generates public and private keys, and returns them as (pub, priv).

    The public key is also known as the 'encryption key', and is a
    :py:class:`rsa.PublicKey` object. The private key is also known as the
    'decryption key' and is a :py:class:`rsa.PrivateKey` object.

    :param nbits: the number of bits required to store ``n = p*q``.
    :param accurate: when True, ``n`` will have exactly the number of bits you
        asked for. However, this makes key generation much slower. When False,
        ``n`` may have slightly fewer bits.
    :param poolsize: the number of processes to use to generate the prime
        numbers. If set to a number > 1, a parallel algorithm will be used.
        This requires Python 2.6 or newer.
    :param exponent: the exponent for the key; only change this if you know
        what you're doing, as the exponent influences how difficult your
        private key can be cracked. A very common choice for e is 65537.
    :type exponent: int

    :returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`)
    :raise ValueError: when ``nbits`` is smaller than 16 or ``poolsize`` is
        smaller than 1.

    The ``poolsize`` parameter was added in *Python-RSA 3.1* and requires
    Python 2.6 or newer.

    """

    if nbits < 16:
        raise ValueError('Key too small')

    if poolsize < 1:
        raise ValueError('Pool size (%i) should be >= 1' % poolsize)

    # Determine which getprime function to use
    if poolsize > 1:
        # Imported lazily so the single-process path does not need it.
        from rsa import parallel
        import functools

        getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
    else:
        getprime_func = rsa.prime.getprime

    # Generate the key components
    (p, q, e, d) = gen_keys(nbits, getprime_func, accurate=accurate, exponent=exponent)

    # Create the key objects
    n = p * q

    return (
        PublicKey(n, e),
        PrivateKey(n, e, d, p, q)
    )
721
722
__all__ = ['PublicKey', 'PrivateKey', 'newkeys']

if __name__ == '__main__':
    # Run this module's doctests up to 100 times (key generation is random),
    # stopping at the first run that reports failures.
    import doctest

    try:
        for run in range(100):
            failures, _tests = doctest.testmod()
            if failures:
                break

            # Progress report after run 1 and every tenth run.
            if run == 1 or (run and not run % 10):
                print('%i times' % run)
    except KeyboardInterrupt:
        print('Aborted')
    else:
        print('Doctests done')
740