1# -*- coding: utf-8 -*-
2#
3#  Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
4#
5#  Licensed under the Apache License, Version 2.0 (the "License");
6#  you may not use this file except in compliance with the License.
7#  You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11#  Unless required by applicable law or agreed to in writing, software
12#  distributed under the License is distributed on an "AS IS" BASIS,
13#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#  See the License for the specific language governing permissions and
15#  limitations under the License.
16
17'''Functions for PKCS#1 version 1.5 encryption and signing
18
19This module implements certain functionality from PKCS#1 version 1.5. For a
20very clear example, read http://www.di-mgt.com.au/rsa_alg.html#pkcs1schemes
21
22At least 8 bytes of random padding is used when encrypting a message. This makes
23these methods much more secure than the ones in the ``rsa`` module.
24
25WARNING: this module leaks information when decryption or verification fails.
26The exceptions that are raised contain the Python traceback information, which
27can be used to deduce where in the process the failure occurred. DO NOT PASS
28SUCH INFORMATION to your users.
29'''
30
31import hashlib
32import os
33
34from rsa._compat import b
35from rsa import common, transform, core, varblock
36
# ASN.1 codes that describe the hash algorithm used.
#
# Each value is the DER-encoded DigestInfo prefix for the named hash method
# (the same prefixes listed in RFC 8017, section 9.2, note 1). sign() appends
# the raw digest to this prefix before padding it, and _find_method_hash()
# matches a decrypted signature payload against these prefixes to recognise
# which hash method was used. The last byte of every prefix is the length of
# the digest that follows it (16, 20, 32, 48 and 64 bytes respectively).
HASH_ASN1 = {
    'MD5': b('\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'),
    'SHA-1': b('\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'),
    'SHA-256': b('\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'),
    'SHA-384': b('\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'),
    'SHA-512': b('\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'),
}
45
# hashlib constructors for every supported hash method. The names are the
# same keys used by HASH_ASN1, so a method recognised from a signature can be
# looked up here to hash the message.
HASH_METHODS = dict(zip(
    ('MD5', 'SHA-1', 'SHA-256', 'SHA-384', 'SHA-512'),
    (hashlib.md5, hashlib.sha1, hashlib.sha256,
     hashlib.sha384, hashlib.sha512),
))
53
class CryptoError(Exception):
    '''Root of the exception hierarchy raised by this module.'''


class DecryptionError(CryptoError):
    '''Signals that a crypto text could not be decrypted.'''


class VerificationError(CryptoError):
    '''Signals that a signature does not match its message.'''
62
def _pad_for_encryption(message, target_length):
    r'''Pads the message for encryption, returning the padded message.

    The padding consists of random non-zero bytes. Because it never contains
    a 00 byte, the first 00 byte after the 00 02 marker unambiguously marks
    the start of the message.

    :param message: the message to pad; at most ``target_length - 11`` bytes,
        which guarantees at least 8 bytes of random padding.
    :param target_length: the length of the padded block, in bytes.
    :return: 00 02 RANDOM_DATA 00 MESSAGE
    :raise OverflowError: when the message does not fit in the block.

    >>> block = _pad_for_encryption('hello', 16)
    >>> len(block)
    16
    >>> block[0:2]
    '\x00\x02'
    >>> block[-6:]
    '\x00hello'

    '''

    max_msglength = target_length - 11
    msglength = len(message)

    if msglength > max_msglength:
        raise OverflowError('%i bytes needed for message, but there is only'
            ' space for %i' % (msglength, max_msglength))

    # Get random padding
    padding = b('')
    padding_length = target_length - msglength - 3

    # We remove 0-bytes, so we'll end up with less padding than we've asked for,
    # so keep adding data until we're at the correct length.
    while len(padding) < padding_length:
        needed_bytes = padding_length - len(padding)

        # Read 5 bytes more than we need, and trim off the surplus after
        # removing the 0-bytes. The surplus increases the chance that a single
        # read yields enough non-zero bytes, especially when needed_bytes is
        # small.
        new_padding = os.urandom(needed_bytes + 5)
        new_padding = new_padding.replace(b('\x00'), b(''))
        padding = padding + new_padding[:needed_bytes]

    # Internal invariant: the loop above appends at most needed_bytes per pass.
    assert len(padding) == padding_length

    return b('').join([b('\x00\x02'),
                    padding,
                    b('\x00'),
                    message])
107
108
def _pad_for_signing(message, target_length):
    r'''Pads the message for signing, returning the padded message.

    The padding is always a repetition of FF bytes. Unlike the padding used
    for encryption, the result is therefore fully deterministic.

    :param message: the payload to pad; at most ``target_length - 11`` bytes.
    :param target_length: the length of the resulting block, in bytes.
    :return: 00 01 PADDING 00 MESSAGE
    :raise OverflowError: when the message is too long for the block.

    >>> block = _pad_for_signing('hello', 16)
    >>> len(block)
    16
    >>> block[0:2]
    '\x00\x01'
    >>> block[-6:]
    '\x00hello'
    >>> block[2:-6]
    '\xff\xff\xff\xff\xff\xff\xff\xff'

    '''

    size = len(message)
    limit = target_length - 11
    if size > limit:
        raise OverflowError('%i bytes needed for message, but there is only'
            ' space for %i' % (size, limit))

    # Header, FF filler up to the requested length, separator, payload.
    filler = (target_length - size - 3) * b('\xff')
    return b('\x00\x01') + filler + b('\x00') + message
141
142
def encrypt(message, pub_key):
    '''Encrypts the given message using PKCS#1 v1.5

    The message is padded with random non-zero bytes to the byte length of
    the public modulus, and the padded block is then raised to the public
    exponent.

    :param message: the message to encrypt. Must be a byte string no longer than
        ``k-11`` bytes, where ``k`` is the number of bytes needed to encode
        the ``n`` component of the public key.
    :param pub_key: the :py:class:`rsa.PublicKey` to encrypt with.
    :return: the crypto text, exactly ``k`` bytes long.
    :raise OverflowError: when the message is too large to fit in the padded
        block.

    >>> from rsa import key, common
    >>> (pub_key, priv_key) = key.newkeys(256)
    >>> message = 'hello'
    >>> crypto = encrypt(message, pub_key)

    The crypto text should be just as long as the public key 'n' component:

    >>> len(crypto) == common.byte_size(pub_key.n)
    True

    '''

    keylength = common.byte_size(pub_key.n)
    padded_int = transform.bytes2int(_pad_for_encryption(message, keylength))
    crypto_int = core.encrypt_int(padded_int, pub_key.e, pub_key.n)
    return transform.int2bytes(crypto_int, keylength)
173
def decrypt(crypto, priv_key):
    r'''Decrypts the given message using PKCS#1 v1.5

    The decryption is considered 'failed' in any of these cases:

    * the crypto text is longer than the key;
    * the resulting cleartext doesn't start with the bytes 00 02;
    * the 00 byte between the padding and the message cannot be found at or
      after index 10. PKCS#1 v1.5 requires at least 8 bytes of padding.

    :param crypto: the crypto text as returned by :py:func:`rsa.encrypt`
    :param priv_key: the :py:class:`rsa.PrivateKey` to decrypt with.
    :return: the decrypted message.
    :raise DecryptionError: when the decryption fails. No details are given as
        to why the code thinks the decryption fails, as this would leak
        information about the private key.


    >>> import rsa
    >>> (pub_key, priv_key) = rsa.newkeys(256)

    It works with strings:

    >>> crypto = encrypt('hello', pub_key)
    >>> decrypt(crypto, priv_key)
    'hello'

    And with binary data:

    >>> crypto = encrypt('\x00\x00\x00\x00\x01', pub_key)
    >>> decrypt(crypto, priv_key)
    '\x00\x00\x00\x00\x01'

    Altering the encrypted information will *likely* cause a
    :py:class:`rsa.pkcs1.DecryptionError`. If you want to be *sure*, use
    :py:func:`rsa.sign`.


    .. warning::

        Never display the stack trace of a
        :py:class:`rsa.pkcs1.DecryptionError` exception. It shows where in the
        code the exception occurred, and thus leaks information about the key.
        It's only a tiny bit of information, but every bit makes cracking the
        keys easier.

    >>> crypto = encrypt('hello', pub_key)
    >>> crypto = crypto[0:5] + 'X' + crypto[6:] # change a byte
    >>> decrypt(crypto, priv_key)
    Traceback (most recent call last):
    ...
    DecryptionError: Decryption failed

    '''

    blocksize = common.byte_size(priv_key.n)

    # Leading zero bytes in the crypto text are not reflected in its integer
    # value, so an over-long crypto text would otherwise decrypt exactly like
    # the genuine one (CVE-2020-13757). The length is public information, so
    # rejecting it early leaks nothing about the key.
    if len(crypto) > blocksize:
        raise DecryptionError('Decryption failed')

    encrypted = transform.bytes2int(crypto)
    decrypted = core.decrypt_int(encrypted, priv_key.d, priv_key.n)
    cleartext = transform.int2bytes(decrypted, blocksize)

    # The cleartext must start with the 00 02 marker.
    marker_bad = cleartext[0:2] != b('\x00\x02')

    # Find the 00 separator between the padding and the message; find()
    # returns -1 when there is none.
    sep_idx = cleartext.find(b('\x00'), 2)

    # The padding must be at least 8 bytes long (RFC 8017, section 7.2.2,
    # step 3), so the separator must sit at index 10 or later. This also
    # rejects a missing separator (-1).
    sep_idx_bad = sep_idx < 10

    # Evaluate every check before raising, so all malformed blocks fail from
    # the same place instead of revealing which check tripped.
    if marker_bad or sep_idx_bad:
        raise DecryptionError('Decryption failed')

    return cleartext[sep_idx+1:]
241
def sign(message, priv_key, hash):
    '''Signs the message with the private key.

    Hashes the message, then signs the hash with the given key. This is known
    as a "detached signature", because the message itself isn't altered.

    The signed block is 00 01 FF..FF 00 ASN1 HASH, where ASN1 is the entry of
    :py:const:`HASH_ASN1` for the chosen hash method.

    :param message: the message to sign. Can be an 8-bit string or a file-like
        object. If ``message`` has a ``read()`` method, it is assumed to be a
        file-like object.
    :param priv_key: the :py:class:`rsa.PrivateKey` to sign with
    :param hash: the hash method used on the message. Use 'MD5', 'SHA-1',
        'SHA-256', 'SHA-384' or 'SHA-512'.
    :return: a message signature block.
    :raise ValueError: when ``hash`` is not a supported hash method.
    :raise OverflowError: if the private key is too small to contain the
        requested hash.

    '''

    # Validate the hash name first, so an unknown name fails before any of
    # the message is read.
    if hash not in HASH_ASN1:
        raise ValueError('Invalid hash method: %s' % hash)
    prefix = HASH_ASN1[hash]

    # Only the prefixed digest is signed, not the message itself.
    digest = _hash(message, hash)
    payload = prefix + digest
    keylength = common.byte_size(priv_key.n)
    padded_int = transform.bytes2int(_pad_for_signing(payload, keylength))

    # Raise the padded digest to the private exponent.
    signed_int = core.encrypt_int(padded_int, priv_key.d, priv_key.n)
    return transform.int2bytes(signed_int, keylength)
278
def verify(message, signature, pub_key):
    '''Verifies that the signature matches the message.

    The hash method is detected automatically from the signature.

    The signature block is decrypted with the public key. The hash method is
    recognised from the ASN.1 prefix that follows the padding, and the message
    is hashed with that method. The padded block that :py:func:`rsa.sign`
    would have produced (00 01 FF..FF 00 ASN1 HASH) is then rebuilt and
    compared with the decrypted block as a whole. Comparing the complete block,
    rather than only the embedded hash, rejects signatures whose padding is not
    all FF bytes. Such signatures enable Bleichenbacher-style forgeries with
    small public exponents.

    :param message: the signed message. Can be an 8-bit string or a file-like
        object. If ``message`` has a ``read()`` method, it is assumed to be a
        file-like object.
    :param signature: the signature block, as created with :py:func:`rsa.sign`.
    :param pub_key: the :py:class:`rsa.PublicKey` of the person signing the message.
    :return: True when the signature matches.
    :raise VerificationError: when the signature doesn't match the message.

    .. warning::

        Never display the stack trace of a
        :py:class:`rsa.pkcs1.VerificationError` exception. It shows where in
        the code the exception occurred, and thus leaks information about the
        key. It's only a tiny bit of information, but every bit makes cracking
        the keys easier.

    '''

    blocksize = common.byte_size(pub_key.n)

    # Leading zero bytes don't change the integer value of the signature, so
    # an over-long signature would otherwise verify just like the real one.
    if len(signature) > blocksize:
        raise VerificationError('Verification failed')

    encrypted = transform.bytes2int(signature)
    decrypted = core.decrypt_int(encrypted, pub_key.e, pub_key.n)
    clearsig = transform.int2bytes(decrypted, blocksize)

    # If we can't find the signature marker, verification failed.
    if clearsig[0:2] != b('\x00\x01'):
        raise VerificationError('Verification failed')

    # Find the 00 separator between the padding and the payload
    try:
        sep_idx = clearsig.index(b('\x00'), 2)
    except ValueError:
        raise VerificationError('Verification failed')

    # Recognise the hash method from the ASN.1 prefix of the payload. The
    # embedded hash itself is not trusted; the full block is compared below.
    (method_name, _signature_hash) = _find_method_hash(clearsig[sep_idx+1:])
    message_hash = _hash(message, method_name)

    # Rebuild exactly what sign() would have produced for this message. With
    # a key too small for the recognised hash method, padding overflows; that
    # can only come from a bogus signature.
    try:
        expected = _pad_for_signing(HASH_ASN1[method_name] + message_hash,
                                    blocksize)
    except OverflowError:
        raise VerificationError('Verification failed')

    # Compare the whole block: marker, FF padding, separator, ASN.1 prefix
    # and hash must all match.
    if expected != clearsig:
        raise VerificationError('Verification failed')

    return True
325
def _hash(message, method_name):
    '''Returns the message digest.

    :param message: the signed message. Can be an 8-bit string or a file-like
        object. If ``message`` has a ``read()`` method, it is assumed to be a
        file-like object.
    :param method_name: the hash method, must be a key of
        :py:const:`HASH_METHODS`.
    :return: the raw digest computed with the selected hash method.
    :raise ValueError: when ``method_name`` is not a key of HASH_METHODS.

    '''

    if method_name not in HASH_METHODS:
        raise ValueError('Invalid hash method: %s' % method_name)

    hasher = HASH_METHODS[method_name]()

    is_filelike = hasattr(message, 'read') and hasattr(message.read, '__call__')
    if not is_filelike:
        # Not file-like: feed the object to the hasher in one go.
        hasher.update(message)
        return hasher.digest()

    # File-like: feed the hasher the blocks produced by
    # varblock.yield_fixedblocks with a block size of 1024.
    for chunk in varblock.yield_fixedblocks(message, 1024):
        hasher.update(chunk)
    return hasher.digest()
352
353
def _find_method_hash(method_hash):
    '''Finds the hash method and the hash itself.

    The candidates in :py:const:`HASH_ASN1` are tried in dictionary order, and
    the first one whose ASN.1 code is a prefix of ``method_hash`` wins.

    :param method_hash: ASN1 code for the hash method concatenated with the
        hash itself.

    :return: tuple (method, hash) where ``method`` is the used hash method, and
        ``hash`` is everything in ``method_hash`` after the ASN1 code.

    :raise VerificationError: when the hash method cannot be found

    '''

    for (hashname, asn1code) in HASH_ASN1.items():
        if method_hash.startswith(asn1code):
            return (hashname, method_hash[len(asn1code):])

    raise VerificationError('Verification failed')
374
375
# Public API of this module: the four PKCS#1 v1.5 operations and their
# exception hierarchy.
__all__ = [
    'encrypt',
    'decrypt',
    'sign',
    'verify',
    'DecryptionError',
    'VerificationError',
    'CryptoError',
]
378
if __name__ == '__main__':
    # Padding is random, so the doctests are repeated many times to catch
    # failures that only show up for some random values.
    print('Running doctests 1000x or until failure')
    import doctest

    for iteration in range(1000):
        failures, _tests = doctest.testmod()
        if failures:
            break

        # Progress report every 100 runs (skipping the very first run).
        if iteration > 0 and iteration % 100 == 0:
            print('%i times' % iteration)

    print('Doctests done')
392