1# This software is provided 'as-is', without any express or implied
2# warranty.  In no event will the author be held liable for any damages
3# arising from the use of this software.
4#
5# Permission is granted to anyone to use this software for any purpose,
6# including commercial applications, and to alter it and redistribute it
7# freely, subject to the following restrictions:
8#
9# 1. The origin of this software must not be misrepresented; you must not
10#    claim that you wrote the original software. If you use this software
11#    in a product, an acknowledgment in the product documentation would be
12#    appreciated but is not required.
13# 2. Altered source versions must be plainly marked as such, and must not be
14#    misrepresented as being the original software.
15# 3. This notice may not be removed or altered from any source distribution.
16#
17# Copyright (c) 2008 Greg Hewgill http://hewgill.com
18#
19# This has been modified from the original software.
20# Copyright (c) 2011 William Grant <me@williamgrant.id.au>
21# Copyright (c) 2018 Scott Kitterman <scott@kitterman.com>
22
# Public API of this module.  The RSA primitives and helpers below that are
# not listed here (EMSA_PKCS1_v1_5_encode, str2int, int2str, rsa_encrypt,
# rsa_decrypt) are internal.
__all__ = [
    'DigestTooLargeError',
    'HASH_ALGORITHMS',
    'ARC_HASH_ALGORITHMS',
    'parse_pem_private_key',
    'parse_private_key',
    'parse_public_key',
    'RSASSA_PKCS1_v1_5_sign',
    'RSASSA_PKCS1_v1_5_verify',
    'UnparsableKeyError',
    ]
34
35import base64
36import hashlib
37import re
38
39from dkim.asn1 import (
40    ASN1FormatError,
41    asn1_build,
42    asn1_parse,
43    BIT_STRING,
44    INTEGER,
45    SEQUENCE,
46    OBJECT_IDENTIFIER,
47    OCTET_STRING,
48    NULL,
49    )
50
51
# asn1_parse template for a DER X.509 SubjectPublicKeyInfo, as consumed by
# parse_public_key: an AlgorithmIdentifier SEQUENCE (algorithm OID followed
# by NULL parameters) and then the subjectPublicKey BIT STRING, which wraps
# the DER RSAPublicKey.
# NOTE(review): the template requires NULL parameters and does not check
# which OID is present — confirm this is acceptable for all expected keys.
ASN1_Object = [
    (SEQUENCE, [
        (SEQUENCE, [
            (OBJECT_IDENTIFIER,),
            (NULL,),
        ]),
        (BIT_STRING,),
    ])
]

# asn1_parse template for an RFC 8017 RSAPublicKey.  parse_public_key reads
# the two INTEGERs as modulus and publicExponent, in that order.
ASN1_RSAPublicKey = [
    (SEQUENCE, [
        (INTEGER,),
        (INTEGER,),
    ])
]

# asn1_parse template for an RFC 8017 RSAPrivateKey.  parse_private_key reads
# the nine INTEGERs, in order, as: version, modulus, publicExponent,
# privateExponent, prime1, prime2, exponent1, exponent2, coefficient.
ASN1_RSAPrivateKey = [
    (SEQUENCE, [
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
        (INTEGER,),
    ])
]
82
# Maps signature algorithm names (bytes) to hashlib constructors.
# Presumably the keys are DKIM "a=" tag values — confirm against callers.
# NOTE(review): b'ed25519-sha256' is listed here, but the signing code in
# this module is RSA-only; the Ed25519 operation presumably lives elsewhere.
HASH_ALGORITHMS = {
    b'rsa-sha1': hashlib.sha1,
    b'rsa-sha256': hashlib.sha256,
    b'ed25519-sha256': hashlib.sha256
    }

# Same mapping restricted to the algorithms accepted for ARC (only
# rsa-sha256 here).
ARC_HASH_ALGORITHMS = {
    b'rsa-sha256': hashlib.sha256,
    }
92
# These values come from RFC 8017, section 9.2 Notes, page 46.
# Each value is the DER content octets of the hash algorithm's OBJECT
# IDENTIFIER: 1.3.14.3.2.26 (SHA-1) and 2.16.840.1.101.3.4.2.1 (SHA-256).
# Keys are compared against hash.name.lower() in EMSA_PKCS1_v1_5_encode.
HASH_ID_MAP = {
    'sha1': b"\x2b\x0e\x03\x02\x1a",
    'sha256': b"\x60\x86\x48\x01\x65\x03\x04\x02\x01",
    }
98
99
class DigestTooLargeError(Exception):
    """Raised when an encoded digest cannot fit in the requested length.

    EMSA_PKCS1_v1_5_encode raises this when the key modulus is too short to
    hold the DER DigestInfo plus the mandatory PKCS#1 v1.5 framing/padding.
    """
103
104
class UnparsableKeyError(Exception):
    """Raised when input bytes cannot be decoded as an RSA key.

    Used by parse_public_key, parse_private_key and parse_pem_private_key;
    the message describes what failed to parse.
    """
108
109
def parse_public_key(data):
    """Decode an RSA public key from DER.

    A DER X.509 SubjectPublicKeyInfo wrapping an RFC8017 RSAPublicKey is
    tried first; if that fails, C{data} is parsed as a bare RSAPublicKey.

    @param data: DER-encoded X.509 subjectPublicKeyInfo containing an
        RFC8017 RSAPublicKey, or a bare DER RSAPublicKey.
    @return: dict with 'modulus' and 'publicExponent' integer entries
    @raise UnparsableKeyError: if C{data} matches neither encoding; the
        message carries both underlying parse errors.
    """
    try:
        spki = asn1_parse(ASN1_Object, data)
        bit_string = spki[0][1]
        # In DER, the first content octet of a BIT STRING is the count of
        # unused trailing bits (X.690 8.6.2), not key material.  The DER
        # RSAPublicKey starts right after it.  This assumes asn1_parse
        # returns the raw BIT STRING contents.
        parsed = asn1_parse(ASN1_RSAPublicKey, bit_string[1:])
    except ASN1FormatError as spki_error:
        try:
            parsed = asn1_parse(ASN1_RSAPublicKey, data)
        except ASN1FormatError as rsa_error:
            raise UnparsableKeyError(
                'Unparsable public key; SubjectPublicKeyInfo: '
                + str(spki_error)
                + '; RSAPublicKey: '
                + str(rsa_error))
    modulus = parsed[0][0]
    exponent = parsed[0][1]
    return {'modulus': modulus, 'publicExponent': exponent}
131
132
def parse_private_key(data):
    """Decode a DER RFC8017 RSAPrivateKey into a dict of its fields.

    @param data: DER-encoded RFC8017 RSAPrivateKey.
    @return: dict mapping 'version', 'modulus', 'publicExponent',
        'privateExponent', 'prime1', 'prime2', 'exponent1', 'exponent2'
        and 'coefficient' to the corresponding INTEGER values, in the
        order they appear in the structure.
    @raise UnparsableKeyError: if C{data} does not match
        ASN1_RSAPrivateKey.
    """
    field_names = (
        'version',
        'modulus',
        'publicExponent',
        'privateExponent',
        'prime1',
        'prime2',
        'exponent1',
        'exponent2',
        'coefficient',
    )
    try:
        parsed = asn1_parse(ASN1_RSAPrivateKey, data)
    except ASN1FormatError as err:
        raise UnparsableKeyError('Unparsable private key: ' + str(err))
    fields = parsed[0]
    return {name: fields[index] for index, name in enumerate(field_names)}
155
156
def parse_pem_private_key(data):
    """Parse a PEM RSA private key.

    Only the base64 body between the armour lines is used; the BEGIN/END
    labels themselves are not inspected.

    @param data: RFC8017 RSAPrivateKey in PEM format, as bytes.  Both LF
        and CRLF line endings are accepted.
    @return: RSA private key (see parse_private_key)
    @raise UnparsableKeyError: if no PEM body is found, the body is not
        valid base64, or the decoded DER is not an RSAPrivateKey.
    """
    # Capture everything between the dashes that end the BEGIN line and the
    # dashes that start the END line.  The optional \r lets CRLF files
    # match.  b64decode (non-validating) discards any \r or \n left inside
    # the captured body.
    m = re.search(b"--\r?\n(.*?)\r?\n--", data, re.DOTALL)
    if m is None:
        raise UnparsableKeyError("Private key not found")
    try:
        pkdata = base64.b64decode(m.group(1))
    except (TypeError, ValueError) as e:
        # Python 2 reports malformed base64 as TypeError; Python 3 raises
        # binascii.Error, which is a ValueError subclass.
        raise UnparsableKeyError(str(e))
    return parse_private_key(pkdata)
171
172
def EMSA_PKCS1_v1_5_encode(hash, mlen):
    """Build the RFC8017 EMSA-PKCS1-v1_5 encoding of a digest.

    The result is laid out as 0x00 || 0x01 || PS || 0x00 || T, where T is
    the DER DigestInfo for C{hash} and PS is a run of 0xff bytes.

    @param hash: hash object to encode; C{hash.name.lower()} must be a key
        of HASH_ID_MAP
    @param mlen: desired message length in bytes
    @return: encoded digest byte string, exactly C{mlen} bytes long
    @raise DigestTooLargeError: if C{mlen} cannot hold T plus 3 framing
        bytes and at least 8 padding bytes
    """
    algorithm_oid = HASH_ID_MAP[hash.name.lower()]
    algorithm_id = (SEQUENCE, [
        (OBJECT_IDENTIFIER, algorithm_oid),
        (NULL, None),
    ])
    digest_info = asn1_build(
        (SEQUENCE, [
            algorithm_id,
            (OCTET_STRING, hash.digest()),
        ]))
    # 3 framing bytes (00 01 ... 00) plus a minimum of 8 bytes of padding.
    if mlen < len(digest_info) + 11:
        raise DigestTooLargeError()
    padding_len = mlen - len(digest_info) - 3
    header = b"\x00\x01" + b"\xff" * padding_len + b"\x00"
    return header + digest_info
191
192
def str2int(s):
    """Interpret a byte string as a big-endian unsigned integer.

    @param s: byte string (anything accepted by C{bytearray}), most
        significant byte first
    @return: the decoded non-negative integer; 0 for empty input
    """
    value = 0
    for octet in bytearray(s):
        value = value * 256 + octet
    return value
204
205
def int2str(n, length=-1):
    """Convert a non-negative integer to a big-endian byte string.

    @param n: non-negative integer to convert
    @param length: if non-negative, the exact output length; the result is
        left-padded with zero bytes.  If negative (the default), the
        shortest encoding is returned (a single zero byte when C{n} is 0).
    @return: converted bytearray
    @raise OverflowError: if C{n} is negative, or if C{length} is given and
        C{n} does not fit in that many bytes (mirroring C{int.to_bytes}).
    """
    # An explicit check instead of assert: under -O the assert vanishes and
    # a negative n would loop forever (n >>= 8 never reaches 0).
    if n < 0:
        raise OverflowError("can't convert negative int to bytes")
    r = bytearray()
    while length < 0 or len(r) < length:
        r.append(n & 0xff)
        n >>= 8
        if length < 0 and n == 0:
            break
    # Anything left in n would be high-order bytes that did not fit in
    # `length`; refuse instead of silently truncating them.
    if n != 0:
        raise OverflowError('int too big to convert to %d bytes' % length)
    r.reverse()
    return r
224
225
def rsa_decrypt(message, pk, mlen):
    """Apply the RSA private-key (decryption/signing) primitive via CRT.

    Follows RFC8017 RSADP/RSASP1 step 2.b (Garner recombination):
    m1 = c^dP mod p, m2 = c^dQ mod q, h = qInv * (m1 - m2) mod p,
    m = m2 + q * h.

    @param message: byte string (big-endian integer) to operate on
    @param pk: private key dict with 'prime1', 'prime2', 'exponent1',
        'exponent2' and 'coefficient' entries
    @param mlen: length of the output in bytes (passed to int2str)
    @return: byte string result of the operation
    """
    ciphertext = str2int(message)

    mod_p = pow(ciphertext, pk['exponent1'], pk['prime1'])
    mod_q = pow(ciphertext, pk['exponent2'], pk['prime2'])

    # Python's % always returns a value in [0, prime1) for a positive
    # modulus, so a negative (mod_p - mod_q) needs no explicit "+ p" fix-up.
    h = pk['coefficient'] * (mod_p - mod_q) % pk['prime1']

    return int2str(mod_q + h * pk['prime2'], mlen)
245
246
def rsa_encrypt(message, pk, mlen):
    """Apply the RSA public-key (encryption/verification) primitive.

    Computes message ** publicExponent mod modulus.

    @param message: byte string (big-endian integer) to operate on
    @param pk: public key dict with 'publicExponent' and 'modulus' entries
    @param mlen: length of the output in bytes (passed to int2str)
    @return: byte string result of the operation
    """
    base = str2int(message)
    result = pow(base, pk['publicExponent'], pk['modulus'])
    return int2str(result, mlen)
257
258
def RSASSA_PKCS1_v1_5_sign(hash, private_key):
    """Produce an RFC8017 RSASSA-PKCS1-v1_5 signature over a digest.

    @param hash: hash object whose digest is signed
    @param private_key: private key dict, as returned by parse_private_key
    @return: signature byte string, as long as the modulus in bytes
    @raise DigestTooLargeError: if the modulus is too short for the
        encoded digest
    """
    # Byte length of the modulus fixes both the encoding and output sizes.
    key_bytes = len(int2str(private_key['modulus']))
    return rsa_decrypt(
        EMSA_PKCS1_v1_5_encode(hash, key_bytes), private_key, key_bytes)
269
270
def RSASSA_PKCS1_v1_5_verify(hash, signature, public_key):
    """Verify a digest signed with RFC8017 RSASSA-PKCS1-v1_5.

    @param hash: hash object to check
    @param signature: signed digest byte string
    @param public_key: public key dict with 'modulus' and 'publicExponent'
        entries, as returned by parse_public_key
    @return: True if the signature is valid, False otherwise
    @raise DigestTooLargeError: if the modulus is too short to hold the
        encoded digest
    """
    modulus = public_key['modulus']
    modlen = len(int2str(modulus))
    encoded_digest = EMSA_PKCS1_v1_5_encode(hash, modlen)
    # RFC8017 RSAVP1 step 1: a signature representative outside [0, n-1]
    # is invalid.  Without this check s + k*n would verify just like s,
    # because the modular exponentiation reduces it back into range.
    if str2int(signature) >= modulus:
        return False
    signed_digest = rsa_encrypt(signature, public_key, modlen)
    return encoded_digest == signed_digest
283