1"""RSA module
2
3Module for calculating large primes, and RSA encryption, decryption,
4signing and verification. Includes generating public and private keys.
5
6WARNING: this implementation does not use random padding, compression of the
7cleartext input to prevent repetitions, or other common security improvements.
8Use with care.
9
10"""
11
# Module metadata.
__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead"
__date__ = "2010-02-08"
__version__ = "2.0"
15
import math
import numbers
import os
import random
import sys
import types

from rsa._compat import byte
22
# This legacy implementation is insecure, so announce every import of it.
import warnings

warnings.warn('Insecure version of the RSA module is imported as %s' % (__name__,))
26
27
def bit_size(number):
    """Returns the number of bits required to hold a non-negative integer.

    This is the position of the most significant set bit, so bit_size(0)
    is 0, bit_size(1) is 1 and bit_size(8) (0b1000) is 4.

    The old ceil(log(number, 2)) formula was one bit short for exact powers
    of two. It could also be wrong for large numbers because of
    floating-point rounding (e.g. 2**29 gave 30).

    >>> bit_size(255), bit_size(256)
    (8, 9)

    :raises ValueError: if number is negative.
    """
    if number < 0:
        raise ValueError('Only non-negative numbers are supported: %r' % (number,))
    if number == 0:
        return 0
    # bin() renders e.g. '0b1000'; drop the '0b' prefix. Works for int and
    # long on Python 2.6+ and for int on Python 3.
    return len(bin(number)) - 2
32
def gcd(p, q):
    """Returns the greatest common divisor of p and q

    >>> gcd(48, 180)
    12
    >>> gcd(0, 7)
    7
    """
    # Iterative Euclid: faster and uses much less stack space than recursion.
    # No explicit swap is needed, because when p < q the first step
    # (p, q) = (q, p % q) swaps them automatically. The old pre-swap turned
    # gcd(0, n) into n % 0 and raised ZeroDivisionError.
    while q != 0:
        p, q = q, p % q
    # abs() keeps the result non-negative for negative inputs; positive
    # results are unchanged.
    return abs(p)
43
44
def bytes2int(bytes):
    """Converts a list of bytes or a string to an integer

    The input is read big-endian: the first element is the most
    significant byte. Elements may be integers or one-character strings.
    Lists, tuples, bytearrays and byte/text strings are accepted, so this
    works on both Python 2 and Python 3.

    >>> (((128 * 256) + 64) * 256) + 15
    8405007
    >>> l = [128, 64, 15]
    >>> bytes2int(l)              #same as bytes2int('\\x80@\\x0f')
    8405007

    :raises TypeError: if bytes is not one of the accepted sequence types.
    """
    # The parameter name shadows the builtin ``bytes``, so take the byte
    # string type from a literal: str on Python 2, bytes on Python 3.
    byte_string_type = type(b'')
    if not isinstance(bytes, (list, tuple, bytearray, byte_string_type, str)):
        raise TypeError("You must pass a string or a list")

    # Convert byte stream to integer (big-endian, base 256)
    integer = 0
    for value in bytes:
        # Iterating a Python 2 str or a text string yields 1-char strings;
        # Python 3 bytes and bytearray already yield ints.
        if isinstance(value, (byte_string_type, str)):
            value = ord(value)
        integer = integer * 256 + value

    return integer
66
def int2bytes(number):
    """
    Converts a number to a string of bytes

    The result is big-endian with no leading zero bytes, so int2bytes(0)
    (and any negative number) gives the empty string. For numbers > 0 this
    is the inverse of bytes2int.

    :raises TypeError: if number is not an integer.
    """
    if not isinstance(number, numbers.Integral):
        raise TypeError("You must pass a long or an int")

    # Collect bytes least-significant first and reverse once at the end;
    # prepending to a string on every iteration is quadratic. Shifting
    # (instead of "/= 256") keeps the value an integer on Python 3.
    buf = bytearray()
    while number > 0:
        buf.append(number & 0xFF)
        number >>= 8
    buf.reverse()

    # bytes is str on Python 2 and bytes on Python 3.
    return bytes(buf)
82
def to64(number):
    """Converts a number in the range of 0 to 63 into base 64 digit
    character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'.

    >>> to64(10)
    'A'

    :raises TypeError: if number is not an integer.
    :raises ValueError: if number is outside 0..63.
    """
    if not isinstance(number, numbers.Integral):
        raise TypeError("You must pass a long or an int")

    # Digit table: 0-9 -> '0'-'9', 10-35 -> 'A'-'Z', 36-61 -> 'a'-'z',
    # 62 -> '-', 63 -> '_'. Indexing a native str gives a native str on
    # both Python 2 and Python 3, matching the doctest ('A', not b'A').
    digits = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_'
    if not 0 <= number <= 63:
        raise ValueError('Invalid Base64 value: %i' % number)
    return digits[number]
110
111
def from64(number):
    """Converts an ordinal character value in the range of
    0-9,A-Z,a-z,-,_ to a number in the range of 0-63.

    This is the inverse of to64, applied to ord() of the digit character.

    >>> from64(49)
    1

    :raises TypeError: if number is not an integer.
    :raises ValueError: if number is not the ordinal of a base64 digit.
    """
    # numbers.Integral covers int and long on Python 2 and int on Python 3
    # (the types.IntType/LongType constants no longer exist on Python 3).
    if not isinstance(number, numbers.Integral):
        raise TypeError("You must pass a long or an int")

    if 48 <= number <= 57:         # ord('0') - ord('9') translates to 0-9
        return number - 48

    if 65 <= number <= 90:         # ord('A') - ord('Z') translates to 10-35
        return number - 55

    if 97 <= number <= 122:        # ord('a') - ord('z') translates to 36-61
        return number - 61

    if number == 45:               # ord('-') translates to 62
        return 62

    if number == 95:               # ord('_') translates to 63
        return 63

    raise ValueError('Invalid Base64 value: %i' % number)
139
140
def int2str64(number):
    """Converts a number to a string of base64 encoded characters in
    the range of '0'-'9','A'-'Z','a'-'z','-','_'.

    The most significant digit comes first; 0 (and any negative number)
    encodes to the empty string.

    >>> int2str64(123456789)
    '7MyqL'

    :raises TypeError: if number is not an integer.
    """
    if not isinstance(number, numbers.Integral):
        raise TypeError("You must pass a long or an int")

    # Emit 6-bit digits least-significant first and reverse once at the end
    # instead of prepending (quadratic). Shifting (instead of "/= 64")
    # keeps the value an integer on Python 3.
    digits = []
    while number > 0:
        digits.append(to64(number & 0x3F))
        number >>= 6
    digits.reverse()

    return ''.join(digits)
159
160
def str642int(string):
    """Converts a base64 encoded string into an integer.
    The chars of this string are in the range '0'-'9','A'-'Z','a'-'z','-','_'

    Lists/tuples of characters or ordinals, bytearrays and byte/text strings
    are accepted, so this works on both Python 2 and Python 3. The empty
    string decodes to 0.

    >>> str642int('7MyqL')
    123456789

    :raises TypeError: if string is not one of the accepted sequence types.
    """
    # bytes is an alias of str on Python 2, so this tuple matches the old
    # list/str check there and also admits bytes/bytearray on Python 3.
    if not isinstance(string, (list, tuple, bytearray, bytes, str)):
        raise TypeError("You must pass a string or a list")

    integer = 0
    for char in string:
        # from64 expects an ordinal; Python 3 bytes already yield ints.
        if isinstance(char, (bytes, str)):
            char = ord(char)
        integer = integer * 64 + from64(char)

    return integer
179
def read_random_int(nbits):
    """Return a random non-negative integer built from os.urandom.

    nbits is rounded up to whole bytes, so the result has at most
    ceil(nbits / 8) * 8 bits.
    """
    byte_count = int(math.ceil(nbits / 8.0))
    return bytes2int(os.urandom(byte_count))
187
def randint(minvalue, maxvalue):
    """Returns a random integer x with minvalue <= x <= maxvalue

    Every value in the range is equally likely. Random bits come from
    read_random_int (os.urandom).

    :raises ValueError: if maxvalue < minvalue.
    """
    span = maxvalue - minvalue + 1
    if span < 1:
        raise ValueError("maxvalue (%r) is smaller than minvalue (%r)"
                         % (maxvalue, minvalue))

    # Bits needed for the largest offset, span - 1. bin(0) is '0b0', so a
    # one-value range still draws a single bit.
    nbits = len(bin(span - 1)) - 2
    mask = (1 << nbits) - 1

    # Rejection sampling: take exactly nbits random bits and retry when the
    # offset is out of range. Each attempt succeeds with probability at
    # least 1/2 and there is no modulo bias. The old "value % range" was
    # biased, and because it drew a random 32..rangebits bit count, for
    # large ranges most values were unreachable.
    while True:
        offset = read_random_int(nbits) & mask
        if offset < span:
            return minvalue + offset
208
def jacobi(a, b):
    """Calculates the value of the Jacobi symbol (a/b)
    where both a and b are positive integers, and b is odd

    Returns 1 or -1, or 0 when the reduction ends with a == 0 (a and b
    share a factor, or a was 0).
    """
    sign = 1
    while a > 1:
        if not a & 1:
            # a is even: factor out a 2. (2/b) is -1 exactly when
            # (b*b - 1) / 8 is odd.
            if ((b * b - 1) >> 3) & 1:
                sign = -sign
            a >>= 1
            continue
        # a is odd: quadratic reciprocity flips the sign when
        # (a - 1)(b - 1) / 4 is odd, then b is reduced modulo a.
        if ((a - 1) * (b - 1) >> 2) & 1:
            sign = -sign
        a, b = b % a, a
    return 0 if a == 0 else sign
227
def jacobi_witness(x, n):
    """Returns False if n is an Euler pseudo-prime with base x, and
    True otherwise.

    x is a witness for the compositeness of n when the Jacobi symbol
    (x/n), taken modulo n, differs from x ** ((n - 1) / 2) mod n (Euler's
    criterion). n is expected to be odd.
    """
    # jacobi() returns -1, 0 or 1; "% n" maps -1 to n - 1 so it can be
    # compared with the modular power.
    j = jacobi(x, n) % n
    # Floor division keeps the exponent an integer on Python 3 (pow() with a
    # modulus rejects float exponents); it is identical to "/" on Python 2.
    f = pow(x, (n - 1) // 2, n)

    return j != f
238
def randomized_primality_testing(n, k):
    """Probabilistic (Solovay-Strassen) primality test with k rounds.

    Returns False when n is certainly composite (this answer is always
    correct) and True when n is probably prime (wrong with probability at
    most 2**-k).
    """
    # Each round draws a random base in [1, n - 1]; for a composite n at
    # least half of the bases are Jacobi witnesses, and a single witness
    # proves compositeness. any() stops at the first witness found.
    return not any(jacobi_witness(randint(1, n - 1), n) for _ in range(k))
254
def is_prime(number):
    """Returns True if the number is prime, and False otherwise.

    Numbers below 4 and even numbers are decided directly. Odd numbers
    >= 5 go through 6 rounds of the randomized Jacobi test, so a composite
    is misreported as prime with probability at most 2**-6.

    >>> is_prime(42)
    False
    >>> is_prime(41)
    True
    """
    # The randomized test cannot handle these cases itself:
    # randint(1, number - 1) fails for number < 2, and the Jacobi symbol is
    # only defined for odd moduli, so even numbers could pass as prime.
    if number < 4:
        return number >= 2
    if number % 2 == 0:
        return False

    # Prime, according to Jacobi
    return randomized_primality_testing(number, 6)
270
271
def getprime(nbits):
    """Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
    other words: nbits is rounded up to whole bytes.

    Bit nbits - 1 of every candidate is forced on, so the prime also has
    at least nbits bits (at least 2 bits when nbits < 2). Without this,
    roughly half of the results would be shorter than requested, which
    shrinks any RSA modulus built from them.

    >>> p = getprime(8)
    >>> is_prime(p-1)
    False
    >>> is_prime(p)
    True
    >>> is_prime(p+1)
    False
    """
    # Mask that sets the most significant requested bit of each candidate.
    # read_random_int(nbits) returns at least nbits bits, so this bit never
    # pushes a candidate past the rounded-up maximum size.
    top_bit = 1 << (max(int(nbits), 2) - 1)

    while True:
        # Force full length (top bit) and oddness (low bit), then test;
        # retry if not prime.
        integer = read_random_int(nbits) | top_bit | 1
        if is_prime(integer):
            return integer
297
def are_relatively_prime(a, b):
    """Return True when a and b share no common factor other than 1.

    >>> are_relatively_prime(2, 3)
    True
    >>> are_relatively_prime(2, 4)
    False
    """
    return gcd(a, b) == 1
310
def find_p_q(nbits):
    """Returns a tuple of two different primes of about nbits bits

    p is drawn with nbits + nbits // 16 bits and q with nbits - nbits // 16
    bits, so p and q aren't too close (close primes let factoring programs
    factor n).
    """
    # Floor division keeps the bit counts integral on Python 3; it is
    # identical to the old "/" for Python 2 ints.
    spread = nbits // 16
    p = getprime(nbits + spread)
    while True:
        q = getprime(nbits - spread)
        # Make sure p and q are different.
        if q != p:
            return (p, q)
321
def extended_gcd(a, b):
    """Returns a tuple (r, i, j) where r = gcd(a, b).

    When r == 1, i is the multiplicative inverse of a modulo b and j is the
    multiplicative inverse of b modulo a.

    Euclid's algorithm first yields Bezout coefficients with r = i*a + j*b.
    A negative i is then shifted up by the original b, and a negative j by
    the original a, so both returned values are non-negative. After that
    shift the identity only holds as i*a == r (mod b) and j*b == r (mod a).

    >>> extended_gcd(3, 7)
    (1, 5, 1)
    """
    # Iterative version: faster and uses much less stack space.
    x = 0
    y = 1
    lx = 1
    ly = 0
    oa = a                             # Remember original a/b to remove
    ob = b                             # negative values from return results
    while b != 0:
        # Exact floor division. The old "long(a/b)" fails on Python 3:
        # "long" is gone, and "/" is float division, which loses precision
        # for key-sized numbers.
        q = a // b
        (a, b) = (b, a % b)
        (x, lx) = ((lx - (q * x)), x)
        (y, ly) = ((ly - (q * y)), y)
    if lx < 0:
        lx += ob                       # If negative, wrap modulo original b
    if ly < 0:
        ly += oa                       # If negative, wrap modulo original a
    return (a, lx, ly)                 # Return only non-negative values
343
# Main function: calculate encryption and decryption keys
def calculate_keys(p, q, nbits):
    """Calculates an encryption and a decryption key for p and q, and
    returns them as a tuple (e, d)

    e is the larger of 65537 and a random prime of about nbits // 4 bits,
    chosen to be coprime with both n = p * q and phi_n = (p - 1) * (q - 1).
    d is the inverse of e modulo phi_n.

    :raises Exception: if the sanity checks on the computed inverse fail.
    """
    n = p * q
    phi_n = (p - 1) * (q - 1)

    while True:
        # Make sure e has enough bits so we ensure "wrapping" through
        # modulo n. Floor division keeps the bit count an integer on
        # Python 3 and is identical to the old "/" for Python 2 ints.
        # NOTE(review): when nbits // 4 rounds up to at most 16 bits, the
        # random prime is always below 65537, so e is always 65537 and this
        # loop never terminates if 65537 divides phi_n.
        e = max(65537, getprime(nbits // 4))
        if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n):
            break

    (d, i, j) = extended_gcd(e, phi_n)

    if not d == 1:
        raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))
    if i < 0:
        raise Exception("New extended_gcd shouldn't return negative values")
    if not (e * i) % phi_n == 1:
        raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n))

    return (e, i)
368
369
def gen_keys(nbits):
    """Generate the raw RSA key material for an nbits-bit key.

    Returns the tuple (p, q, e, d). Prime generation makes this slow for
    large key sizes.
    """
    prime_p, prime_q = find_p_q(nbits)
    public_exp, private_exp = calculate_keys(prime_p, prime_q, nbits)
    return prime_p, prime_q, public_exp, private_exp
380
def newkeys(nbits):
    """Generate a fresh key pair and return it as (pub, priv).

    The public key is the dict {'e': ..., 'n': ...}; the private key is the
    dict {'d': ..., 'p': ..., 'q': ...}. nbits is raised to at least 9.
    """
    # Don't let nbits go below 9 bits.
    p, q, e, d = gen_keys(max(9, nbits))
    public_key = {'e': e, 'n': p * q}
    private_key = {'d': d, 'p': p, 'q': q}
    return (public_key, private_key)
392
def encrypt_int(message, ekey, n):
    """Encrypts a message using encryption key 'ekey', working modulo n

    A "safe bit" (bit number bit_size(n) - 2) is added before
    exponentiation, and decrypt_int subtracts it again. For that round trip
    to work, the padded value must stay below n:
    0 <= message and message + 2**(bit_size(n) - 2) < n.

    :raises TypeError: if message is not an integer.
    :raises OverflowError: if message is negative or too large for n.
    """
    if not isinstance(message, numbers.Integral):
        raise TypeError("You must pass a long or int")

    # Note: bit exponents start at zero (bit counts start at 1), so this is
    # the bit just below the most significant bit of n (MSB - 1).
    safebit = 1 << (bit_size(n) - 2)

    # The old check only rejected message > n, which let values in
    # [n - safebit, n] through. Those wrap modulo n once the safe bit is
    # added and decrypt to a different number.
    if message < 0 or message + safebit >= n:
        raise OverflowError("The message is too long")

    return pow(message + safebit, ekey, n)
410
def decrypt_int(cyphertext, dkey, n):
    """Decrypts a cypher text using the decryption key 'dkey', working
    modulo n

    Undoes encrypt_int: the modular exponentiation recovers the padded
    value, then the safe bit (bit number bit_size(n) - 2) is subtracted.
    """
    padded = pow(cyphertext, dkey, n)
    return padded - (1 << (bit_size(n) - 2))
421
def encode64chops(chops):
    """Encode each integer chop with int2str64 and join them with commas."""
    return ','.join([int2str64(value) for value in chops])
434
def decode64chops(string):
    """Split a comma-delimited string and decode each piece with str642int."""
    return [str642int(chip) for chip in string.split(',')]
446
def chopstring(message, key, n, funcref):
    """Chops the 'message' into integers that fit into n,
    leaving room for a safebit to be added to ensure that all
    messages fold during exponentiation.  The MSB of the number n
    is not independent modulo n (setting it could cause overflow), so
    use the next lower bit for the safebit.  Therefore reserve 2-bits
    in the number n for non-data bits.  Calls specified encryption
    function for each chop.

    Used by 'encrypt' and 'sign'.

    NOTE: leading NUL bytes of a block are not preserved. bytes2int
    ignores them, and int2bytes (used when gluing chops back together)
    never emits them.

    :raises ValueError: if n is too small to hold one byte per block.
    """
    # Set aside 2 bits so setting the safebit won't overflow modulo n.
    nbits = bit_size(n) - 2             # leave room for safebit
    nbytes = nbits // 8
    if nbytes < 1:
        # Previously this surfaced as a ZeroDivisionError.
        raise ValueError("The modulus n is too small to hold one byte per block")

    # Offsets 0, nbytes, 2*nbytes, ... < len(message): a final partial
    # block gets its own chop. Floor division and range() keep this integral
    # on Python 3.
    cypher = [funcref(bytes2int(message[offset:offset + nbytes]), key, n)
              for offset in range(0, len(message), nbytes)]

    return encode64chops(cypher)   # Encode encrypted ints to base64 strings
478
def gluechops(string, key, n, funcref):
    """Glues chops back together into a string.  Calls
    funcref(integer, key, n) for each chop.

    Used by 'decrypt' and 'verify'.
    """
    chops = decode64chops(string)  # Decode base64 strings into integer chops

    # Decrypt each chop and convert it back to bytes. Joining once avoids
    # the quadratic cost of repeated "+=". b'' is str on Python 2 (same
    # result as before) and bytes on Python 3.
    parts = [int2bytes(funcref(cpart, key, n)) for cpart in chops]
    return b''.join(parts)
494
def encrypt(message, key):
    """Encrypt the string 'message' with the public key dict 'key' ({'e', 'n'})."""
    if 'n' in key:
        return chopstring(message, key['e'], key['n'], encrypt_int)
    raise Exception("You must use the public key with encrypt")
501
def sign(message, key):
    """Sign the string 'message' with the private key dict 'key' ({'d', 'p', 'q'})."""
    if 'p' in key:
        return chopstring(message, key['d'], key['p'] * key['q'], encrypt_int)
    raise Exception("You must use the private key with sign")
508
def decrypt(cypher, key):
    """Decrypt the string 'cypher' with the private key dict 'key' ({'d', 'p', 'q'})."""
    if 'p' in key:
        return gluechops(cypher, key['d'], key['p'] * key['q'], decrypt_int)
    raise Exception("You must use the private key with decrypt")
515
def verify(cypher, key):
    """Verify the string 'cypher' with the public key dict 'key' ({'e', 'n'})."""
    if 'n' in key:
        return gluechops(cypher, key['e'], key['n'], decrypt_int)
    raise Exception("You must use the public key with verify")
522
# Run this module's doctests when it is executed directly (not imported).
if __name__ == "__main__":
    from doctest import testmod
    testmod()
527
__all__ = [
    "newkeys",
    "encrypt",
    "decrypt",
    "sign",
    "verify",
]
529
530