1# -*- coding: utf-8 -*-
2#
3#  Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
4#
5#  Licensed under the Apache License, Version 2.0 (the "License");
6#  you may not use this file except in compliance with the License.
7#  You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11#  Unless required by applicable law or agreed to in writing, software
12#  distributed under the License is distributed on an "AS IS" BASIS,
13#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#  See the License for the specific language governing permissions and
15#  limitations under the License.
16
17'''RSA key generation code.
18
19Create new keys with the newkeys() function. It will give you a PublicKey and a
20PrivateKey object.
21
22Loading and saving keys requires the pyasn1 module. This module is imported as
23late as possible, such that other functionality will remain working in absence
24of pyasn1.
25
26'''
27
28import logging
29from rsa._compat import b, bytes_type
30
31import rsa.prime
32import rsa.pem
33import rsa.common
34
35log = logging.getLogger(__name__)
36
37
38
class AbstractKey(object):
    '''Common base class of the public and private key classes.

    Concrete subclasses are expected to provide the four format-specific
    hooks that are dispatched to from here: ``_load_pkcs1_pem``,
    ``_load_pkcs1_der``, ``_save_pkcs1_pem`` and ``_save_pkcs1_der``.
    '''

    @classmethod
    def load_pkcs1(cls, keyfile, format='PEM'):
        '''Loads a key in PKCS#1 DER or PEM format.

        :param keyfile: contents of a DER- or PEM-encoded file that contains
            the key.
        :param format: the format of the file to load; 'PEM' or 'DER'
        :return: whatever the format-specific loader of ``cls`` returns.
        :raise ValueError: when ``format`` is neither 'PEM' nor 'DER'.

        '''

        loaders = {
            'PEM': cls._load_pkcs1_pem,
            'DER': cls._load_pkcs1_der,
        }

        if format in loaders:
            return loaders[format](keyfile)

        # Unknown format: tell the caller which ones are available.
        supported = ', '.join(sorted(loaders))
        raise ValueError('Unsupported format: %r, try one of %s' % (format,
            supported))

    def save_pkcs1(self, format='PEM'):
        '''Saves this key in PKCS#1 DER or PEM format.

        :param format: the format to save; 'PEM' or 'DER'
        :returns: the DER- or PEM-encoded key, as produced by the
            format-specific saver of this object.
        :raise ValueError: when ``format`` is neither 'PEM' nor 'DER'.

        '''

        savers = {
            'PEM': self._save_pkcs1_pem,
            'DER': self._save_pkcs1_der,
        }

        if format in savers:
            return savers[format]()

        # Unknown format: tell the caller which ones are available.
        supported = ', '.join(sorted(savers))
        raise ValueError('Unsupported format: %r, try one of %s' % (format,
            supported))
87
class PublicKey(AbstractKey):
    '''Represents a public RSA key.

    This key is also known as the 'encryption key'. It contains the 'n' and 'e'
    values.

    Supports attributes as well as dictionary-like access. Attribute access is
    faster, though.

    >>> PublicKey(5, 3)
    PublicKey(5, 3)

    >>> key = PublicKey(5, 3)
    >>> key.n
    5
    >>> key['n']
    5
    >>> key.e
    3
    >>> key['e']
    3

    Keys compare and hash by value, so equal keys can be used interchangeably
    in sets and as dictionary keys. Do not modify a key while it is stored in
    such a container.

    >>> hash(PublicKey(5, 3)) == hash(PublicKey(5, 3))
    True

    '''

    __slots__ = ('n', 'e')

    def __init__(self, n, e):
        self.n = n
        self.e = e

    def __getitem__(self, key):
        # Dictionary-style access simply maps onto the attributes.
        return getattr(self, key)

    def __repr__(self):
        return 'PublicKey(%i, %i)' % (self.n, self.e)

    def __eq__(self, other):
        if other is None:
            return False

        if not isinstance(other, PublicKey):
            return False

        return self.n == other.n and self.e == other.e

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Defining __eq__ without __hash__ makes instances unhashable on
        # Python 3; hash the same fields that __eq__ compares.
        return hash((self.n, self.e))

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        r'''Loads a key in PKCS#1 DER format.

        :param keyfile: contents of a DER-encoded file that contains the public
            key.
        :return: a PublicKey object

        First let's construct a DER encoded key:

        >>> import base64
        >>> b64der = 'MAwCBQCNGmYtAgMBAAE='
        >>> der = base64.b64decode(b64der)

        This loads the file:

        >>> PublicKey._load_pkcs1_der(der)
        PublicKey(2367317549, 65537)

        '''

        from pyasn1.codec.der import decoder
        from rsa.asn1 import AsnPubKey

        # The second element of the decoder result is not needed here.
        (pub, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
        return cls(n=int(pub['modulus']), e=int(pub['publicExponent']))

    def _save_pkcs1_der(self):
        '''Saves the public key in PKCS#1 DER format.

        :returns: the DER-encoded public key.
        '''

        from pyasn1.codec.der import encoder
        from rsa.asn1 import AsnPubKey

        # Create the ASN object
        asn_key = AsnPubKey()
        asn_key.setComponentByName('modulus', self.n)
        asn_key.setComponentByName('publicExponent', self.e)

        return encoder.encode(asn_key)

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        '''Loads a PKCS#1 PEM-encoded public key file.

        The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
        after the "-----END RSA PUBLIC KEY-----" lines is ignored.

        :param keyfile: contents of a PEM-encoded file that contains the public
            key.
        :return: a PublicKey object
        '''

        der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
        return cls._load_pkcs1_der(der)

    def _save_pkcs1_pem(self):
        '''Saves a PKCS#1 PEM-encoded public key file.

        :return: contents of a PEM-encoded file that contains the public key.
        '''

        der = self._save_pkcs1_der()
        return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')

    @classmethod
    def load_pkcs1_openssl_pem(cls, keyfile):
        '''Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL.

        These files can be recognised in that they start with BEGIN PUBLIC KEY
        rather than BEGIN RSA PUBLIC KEY.

        The contents of the file before the "-----BEGIN PUBLIC KEY-----" and
        after the "-----END PUBLIC KEY-----" lines is ignored.

        :param keyfile: contents of a PEM-encoded file that contains the public
            key, from OpenSSL.
        :return: a PublicKey object
        '''

        der = rsa.pem.load_pem(keyfile, 'PUBLIC KEY')
        return cls.load_pkcs1_openssl_der(der)

    @classmethod
    def load_pkcs1_openssl_der(cls, keyfile):
        '''Loads a PKCS#1 DER-encoded public key file from OpenSSL.

        :param keyfile: contents of a DER-encoded file that contains the public
            key, from OpenSSL.
        :return: a PublicKey object
        :raise TypeError: when the object identifier in the header is not
            1.2.840.113549.1.1.1.
        '''

        from rsa.asn1 import OpenSSLPubKey
        from pyasn1.codec.der import decoder
        from pyasn1.type import univ

        (keyinfo, _) = decoder.decode(keyfile, asn1Spec=OpenSSLPubKey())

        if keyinfo['header']['oid'] != univ.ObjectIdentifier('1.2.840.113549.1.1.1'):
            raise TypeError("This is not a DER-encoded OpenSSL-compatible public key")

        # NOTE(review): the first element of the decoded 'key' field is
        # skipped before handing the rest to the PKCS#1 loader; presumably a
        # leading padding octet -- confirm against rsa.asn1.OpenSSLPubKey.
        return cls._load_pkcs1_der(keyinfo['key'][1:])
240
241
242
243
class PrivateKey(AbstractKey):
    '''Represents a private RSA key.

    This key is also known as the 'decryption key'. It contains the 'n', 'e',
    'd', 'p', 'q' and other values.

    Supports attributes as well as dictionary-like access. Attribute access is
    faster, though.

    >>> PrivateKey(3247, 65537, 833, 191, 17)
    PrivateKey(3247, 65537, 833, 191, 17)

    exp1, exp2 and coef don't have to be given, they will be calculated:

    >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
    >>> pk.exp1
    55063
    >>> pk.exp2
    10095
    >>> pk.coef
    50797

    If you give exp1, exp2 or coef, they will be used as-is:

    >>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8)
    >>> pk.exp1
    6
    >>> pk.exp2
    7
    >>> pk.coef
    8

    Each of the three values is handled independently of the others; giving
    only exp1 still results in exp2 and coef being calculated:

    >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287, exp1=55063)
    >>> pk.exp2
    10095
    >>> pk.coef
    50797

    Keys compare and hash by value. Do not modify a key while it is stored in
    a set or used as a dictionary key.

    '''

    __slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')

    def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None):
        self.n = n
        self.e = e
        self.d = d
        self.p = p
        self.q = q

        # Calculate the other values if they aren't supplied
        if exp1 is None:
            self.exp1 = int(d % (p - 1))
        else:
            self.exp1 = exp1

        # This must test exp2 itself: testing exp1 here would discard a
        # supplied exp2, or leave exp2 as None when only exp1 was given.
        if exp2 is None:
            self.exp2 = int(d % (q - 1))
        else:
            self.exp2 = exp2

        if coef is None:
            self.coef = rsa.common.inverse(q, p)
        else:
            self.coef = coef

    def __getitem__(self, key):
        # Dictionary-style access simply maps onto the attributes.
        return getattr(self, key)

    def __repr__(self):
        # The %(name)i lookups are served by __getitem__.
        return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self

    def __eq__(self, other):
        if other is None:
            return False

        if not isinstance(other, PrivateKey):
            return False

        return (self.n == other.n and
            self.e == other.e and
            self.d == other.d and
            self.p == other.p and
            self.q == other.q and
            self.exp1 == other.exp1 and
            self.exp2 == other.exp2 and
            self.coef == other.coef)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Defining __eq__ without __hash__ makes instances unhashable on
        # Python 3; hash the same fields that __eq__ compares.
        return hash((self.n, self.e, self.d, self.p, self.q,
            self.exp1, self.exp2, self.coef))

    @classmethod
    def _load_pkcs1_der(cls, keyfile):
        r'''Loads a key in PKCS#1 DER format.

        :param keyfile: contents of a DER-encoded file that contains the private
            key.
        :return: a PrivateKey object
        :raise ValueError: when the version field of the decoded key is not 0.

        First let's construct a DER encoded key:

        >>> import base64
        >>> b64der = 'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
        >>> der = base64.b64decode(b64der)

        This loads the file:

        >>> PrivateKey._load_pkcs1_der(der)
        PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        '''

        from pyasn1.codec.der import decoder
        (priv, _) = decoder.decode(keyfile)

        # ASN.1 contents of DER encoded private key:
        #
        # RSAPrivateKey ::= SEQUENCE {
        #     version           Version,
        #     modulus           INTEGER,  -- n
        #     publicExponent    INTEGER,  -- e
        #     privateExponent   INTEGER,  -- d
        #     prime1            INTEGER,  -- p
        #     prime2            INTEGER,  -- q
        #     exponent1         INTEGER,  -- d mod (p-1)
        #     exponent2         INTEGER,  -- d mod (q-1)
        #     coefficient       INTEGER,  -- (inverse of q) mod p
        #     otherPrimeInfos   OtherPrimeInfos OPTIONAL
        # }

        if priv[0] != 0:
            raise ValueError('Unable to read this file, version %s != 0' % priv[0])

        # Fields 1..8 line up with the constructor's positional arguments
        # (n, e, d, p, q, exp1, exp2, coef).
        as_ints = tuple(int(x) for x in priv[1:9])
        return cls(*as_ints)

    def _save_pkcs1_der(self):
        '''Saves the private key in PKCS#1 DER format.

        :returns: the DER-encoded private key.
        '''

        from pyasn1.type import univ, namedtype
        from pyasn1.codec.der import encoder

        class AsnPrivKey(univ.Sequence):
            componentType = namedtype.NamedTypes(
                namedtype.NamedType('version', univ.Integer()),
                namedtype.NamedType('modulus', univ.Integer()),
                namedtype.NamedType('publicExponent', univ.Integer()),
                namedtype.NamedType('privateExponent', univ.Integer()),
                namedtype.NamedType('prime1', univ.Integer()),
                namedtype.NamedType('prime2', univ.Integer()),
                namedtype.NamedType('exponent1', univ.Integer()),
                namedtype.NamedType('exponent2', univ.Integer()),
                namedtype.NamedType('coefficient', univ.Integer()),
            )

        # Create the ASN object
        asn_key = AsnPrivKey()
        asn_key.setComponentByName('version', 0)
        asn_key.setComponentByName('modulus', self.n)
        asn_key.setComponentByName('publicExponent', self.e)
        asn_key.setComponentByName('privateExponent', self.d)
        asn_key.setComponentByName('prime1', self.p)
        asn_key.setComponentByName('prime2', self.q)
        asn_key.setComponentByName('exponent1', self.exp1)
        asn_key.setComponentByName('exponent2', self.exp2)
        asn_key.setComponentByName('coefficient', self.coef)

        return encoder.encode(asn_key)

    @classmethod
    def _load_pkcs1_pem(cls, keyfile):
        '''Loads a PKCS#1 PEM-encoded private key file.

        The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
        after the "-----END RSA PRIVATE KEY-----" lines is ignored.

        :param keyfile: contents of a PEM-encoded file that contains the private
            key.
        :return: a PrivateKey object
        '''

        der = rsa.pem.load_pem(keyfile, b('RSA PRIVATE KEY'))
        return cls._load_pkcs1_der(der)

    def _save_pkcs1_pem(self):
        '''Saves a PKCS#1 PEM-encoded private key file.

        :return: contents of a PEM-encoded file that contains the private key.
        '''

        der = self._save_pkcs1_der()
        return rsa.pem.save_pem(der, b('RSA PRIVATE KEY'))
432
def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
    '''Returns a tuple of two different primes of roughly nbits bits each.

    The two primes are requested with slightly different sizes
    (``nbits + nbits // 16`` and ``nbits - nbits // 16`` bits), so that they
    are not too close to each other. In accurate mode the resulting p * q has
    exactly 2 * nbits bits. The returned p and q will never be equal.

    :param nbits: the nominal number of bits in each of p and q.
    :param getprime_func: the getprime function, defaults to
        :py:func:`rsa.prime.getprime`.

        *Introduced in Python-RSA 3.1*

    :param accurate: whether to enable accurate mode or not.
    :returns: (p, q), where p > q

    >>> (p, q) = find_p_q(128)
    >>> from rsa import common
    >>> common.bit_size(p * q)
    256

    When not in accurate mode, the number of bits can be slightly less

    >>> (p, q) = find_p_q(128, accurate=False)
    >>> from rsa import common
    >>> common.bit_size(p * q) <= 256
    True
    >>> common.bit_size(p * q) > 240
    True

    '''

    wanted_bits = nbits * 2

    # Keep p and q apart in size; primes that are too close to each other
    # make n easy to factor.
    offset = nbits // 16
    p_size = nbits + offset
    q_size = nbits - offset

    def must_retry(cand_p, cand_q):
        '''Returns True iff this pair of candidates has to be rejected:

            - the candidates are equal, or
            - (cand_p * cand_q) has the wrong nr of bits (accurate mode only)
        '''

        if cand_p == cand_q:
            return True

        if accurate:
            return rsa.common.bit_size(cand_p * cand_q) != wanted_bits

        return False

    # Initial candidates
    log.debug('find_p_q(%i): Finding p', nbits)
    p = getprime_func(p_size)
    log.debug('find_p_q(%i): Finding q', nbits)
    q = getprime_func(q_size)

    # Replace one candidate per round, starting with q and then alternating,
    # until the pair is acceptable.
    replace_q = True
    while must_retry(p, q):
        if replace_q:
            q = getprime_func(q_size)
        else:
            p = getprime_func(p_size)

        replace_q = not replace_q

    # We want p > q as described on
    # http://www.di-mgt.com.au/rsa_alg.html#crt
    if q > p:
        p, q = q, p

    return (p, q)
509
def calculate_keys(p, q, nbits):
    '''Derives the public and private exponent that belong to p and q.

    The public exponent is always 65537; the private exponent is its inverse
    modulo (p - 1) * (q - 1).

    :param p: the first prime.
    :param q: the second prime.
    :param nbits: accepted, but not used by this function.
    :returns: a tuple (e, d)
    :raise ValueError: when no inverse of e exists, or when the calculated
        inverse fails the verification.

    '''

    # A very common choice for e is 65537
    e = 65537
    phi_n = (p - 1) * (q - 1)

    try:
        d = rsa.common.inverse(e, phi_n)
    except ValueError:
        raise ValueError("e (%d) and phi_n (%d) are not relatively prime" %
                (e, phi_n))

    # Double-check the result before handing it out.
    if (e * d) % phi_n == 1:
        return (e, d)

    raise ValueError("e (%d) and d (%d) are not mult. inv. modulo "
            "phi_n (%d)" % (e, d, phi_n))
532
def gen_keys(nbits, getprime_func, accurate=True):
    '''Generate RSA keys of nbits bits. Returns (p, q, e, d).

    Note: this can take a long time, depending on the key size.

    :param nbits: the total number of bits in ``p`` and ``q`` together; the
        primes are searched for with ``nbits // 2`` as their size.
    :param getprime_func: either :py:func:`rsa.prime.getprime` or a function
        with similar signature.
    :param accurate: passed on unchanged to :py:func:`find_p_q`.
    '''

    prime_bits = nbits // 2

    p, q = find_p_q(prime_bits, getprime_func, accurate)
    e, d = calculate_keys(p, q, prime_bits)

    return (p, q, e, d)
548
def newkeys(nbits, accurate=True, poolsize=1):
    '''Generates public and private keys, and returns them as (pub, priv).

    The public key is also known as the 'encryption key', and is a
    :py:class:`rsa.PublicKey` object. The private key is also known as the
    'decryption key' and is a :py:class:`rsa.PrivateKey` object.

    :param nbits: the number of bits required to store ``n = p*q``.
    :param accurate: when True, ``n`` will have exactly the number of bits you
        asked for. However, this makes key generation much slower. When False,
        ``n`` may have slightly less bits.
    :param poolsize: the number of processes to use to generate the prime
        numbers. If set to a number > 1, a parallel algorithm will be used.
        This requires Python 2.6 or newer.

    :returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`)
    :raise ValueError: when ``nbits`` is smaller than 16 or ``poolsize`` is
        smaller than 1.

    The ``poolsize`` parameter was added in *Python-RSA 3.1* and requires
    Python 2.6 or newer.

    '''

    if nbits < 16:
        raise ValueError('Key too small')

    if poolsize < 1:
        raise ValueError('Pool size (%i) should be >= 1' % poolsize)

    # Determine which getprime function to use
    if poolsize > 1:
        # Imported late, so the parallel machinery is only loaded on demand.
        from rsa import parallel
        import functools

        getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
    else:
        getprime_func = rsa.prime.getprime

    # Generate the key components. 'accurate' has to be forwarded explicitly,
    # otherwise gen_keys() silently falls back to its own default.
    (p, q, e, d) = gen_keys(nbits, getprime_func, accurate)

    # Create the key objects
    n = p * q

    return (
        PublicKey(n, e),
        PrivateKey(n, e, d, p, q)
    )
595
# Public API of this module: the names exported by a star-import.
__all__ = ['PublicKey', 'PrivateKey', 'newkeys']
597
if __name__ == '__main__':
    # When executed as a script: run this module's doctests up to 100 times,
    # stopping at the first run that reports a failure.
    import doctest

    try:
        for count in range(100):
            failures = doctest.testmod()[0]
            if failures:
                break

            # Report progress after run 1 and after every tenth run.
            if count == 1 or (count > 0 and count % 10 == 0):
                print('%i times' % count)
    except KeyboardInterrupt:
        print('Aborted')
    else:
        print('Doctests done')
613