1from cryptography.hazmat.primitives.serialization import (
2    Encoding, PrivateFormat, PublicFormat,
3    BestAvailableEncryption, NoEncryption,
4)
5from cryptography.hazmat.primitives.asymmetric import rsa
6from cryptography.hazmat.primitives.asymmetric.rsa import (
7    RSAPublicKey, RSAPrivateKeyWithSerialization,
8    RSAPrivateNumbers, RSAPublicNumbers,
9    rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
10)
11from cryptography.hazmat.primitives.asymmetric import ec
12from cryptography.hazmat.primitives.asymmetric.ec import (
13    EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization,
14    EllipticCurvePrivateNumbers, EllipticCurvePublicNumbers,
15    SECP256R1, SECP384R1, SECP521R1,
16)
17from cryptography.hazmat.backends import default_backend
18from authlib.jose.rfc7517 import Key, load_pem_key
19from authlib.common.encoding import to_bytes
20from authlib.common.encoding import base64_to_int, int_to_base64
21
22
class RSAKey(Key):
    """Key class of the ``RSA`` key type.

    Wraps a raw ``cryptography`` RSA key and converts between that raw key
    and a dict of encoded integers (``n``, ``e`` and, for private keys,
    ``d``, ``p``, ``q``, ``dp``, ``dq``, ``qi``).
    """

    kty = 'RSA'
    # Raw key classes handled by this key type: (public, private).
    RAW_KEY_CLS = (RSAPublicKey, RSAPrivateKeyWithSerialization)
    # Fields a dict-form key must carry; presumably what the base class's
    # ``check_required_fields`` verifies (defined outside this file).
    REQUIRED_JSON_FIELDS = ['e', 'n']

    def as_pem(self, is_private=False, password=None):
        """Export key into PEM format bytes.

        :param is_private: export private key or public key
        :param password: encrypt private key with password
        :return: bytes
        :raise ValueError: if ``is_private`` is true but this is a public key
        """
        return export_key(self, is_private=is_private, password=password)

    @staticmethod
    def dumps_private_key(raw_key):
        """Dump a raw RSA private key into a dict of encoded integers.

        :param raw_key: raw private key exposing ``private_numbers()``
        :return: dict with ``n``, ``e``, ``d``, ``p``, ``q``, ``dp``,
            ``dq`` and ``qi``
        """
        numbers = raw_key.private_numbers()
        return {
            'n': int_to_base64(numbers.public_numbers.n),
            'e': int_to_base64(numbers.public_numbers.e),
            'd': int_to_base64(numbers.d),
            'p': int_to_base64(numbers.p),
            'q': int_to_base64(numbers.q),
            'dp': int_to_base64(numbers.dmp1),
            'dq': int_to_base64(numbers.dmq1),
            'qi': int_to_base64(numbers.iqmp)
        }

    @staticmethod
    def dumps_public_key(raw_key):
        """Dump a raw RSA public key into a dict of encoded integers.

        :param raw_key: raw public key exposing ``public_numbers()``
        :return: dict with ``n`` and ``e``
        """
        numbers = raw_key.public_numbers()
        return {
            'n': int_to_base64(numbers.n),
            'e': int_to_base64(numbers.e)
        }

    @staticmethod
    def loads_private_key(obj):
        """Build a raw RSA private key from its dict representation.

        ``n``, ``e`` and ``d`` are always read. The CRT parameters ``p``,
        ``q``, ``dp``, ``dq`` and ``qi`` must be given either all together
        or not at all; when absent they are recomputed from ``n``, ``e``
        and ``d``.

        :param obj: dict of encoded integers
        :return: raw private key created with the default backend
        :raise ValueError: if ``oth`` is present, or if only some of the
            CRT parameters are present
        """
        if 'oth' in obj:  # pragma: no cover
            # https://tools.ietf.org/html/rfc7518#section-6.3.2.7
            raise ValueError('"oth" is not supported yet')

        props = ['p', 'q', 'dp', 'dq', 'qi']
        props_found = [prop in obj for prop in props]
        any_props_found = any(props_found)

        if any_props_found and not all(props_found):
            raise ValueError(
                'RSA key must include all parameters '
                'if any are present besides d')

        # RSAPublicNumbers takes the exponent first, then the modulus.
        public_numbers = RSAPublicNumbers(
            base64_to_int(obj['e']), base64_to_int(obj['n']))

        if any_props_found:
            numbers = RSAPrivateNumbers(
                d=base64_to_int(obj['d']),
                p=base64_to_int(obj['p']),
                q=base64_to_int(obj['q']),
                dmp1=base64_to_int(obj['dp']),
                dmq1=base64_to_int(obj['dq']),
                iqmp=base64_to_int(obj['qi']),
                public_numbers=public_numbers)
        else:
            d = base64_to_int(obj['d'])
            # The documented signature is (n, e, d): public exponent before
            # private exponent. Keep that order rather than relying on the
            # routine treating e and d interchangeably.
            p, q = rsa_recover_prime_factors(
                public_numbers.n, public_numbers.e, d)
            numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=public_numbers)

        return numbers.private_key(default_backend())

    @staticmethod
    def loads_public_key(obj):
        """Build a raw RSA public key from its dict representation.

        :param obj: dict carrying encoded ``e`` and ``n``
        :return: raw public key created with the default backend
        """
        numbers = RSAPublicNumbers(
            base64_to_int(obj['e']),
            base64_to_int(obj['n'])
        )
        return numbers.public_key(default_backend())

    @classmethod
    def import_key(cls, raw, options=None):
        """Import a key from PEM or dict data.

        :param raw: an existing key of this class, a raw RSA key, a dict,
            or any other data, which is handed to ``load_pem_key`` with the
            ``ssh-rsa`` type marker
        :param options: optional dict; its ``password`` entry is used when
            loading non-dict, non-raw-key data
        :return: key instance of this class
        """
        return import_key(
            cls, raw,
            RSAPublicKey, RSAPrivateKeyWithSerialization,
            b'ssh-rsa', options
        )

    @classmethod
    def generate_key(cls, key_size=2048, options=None, is_private=False):
        """Generate a fresh RSA key with public exponent 65537.

        :param key_size: modulus size in bits; at least 512 and a
            multiple of 8
        :param options: forwarded to :meth:`import_key`
        :param is_private: return the private key when true, otherwise
            only its public half
        :return: key instance of this class
        :raise ValueError: if ``key_size`` is below 512 or not a multiple
            of 8
        """
        if key_size < 512:
            raise ValueError('key_size must not be less than 512')
        if key_size % 8 != 0:
            raise ValueError('Invalid key_size for RSAKey')
        raw_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )
        if not is_private:
            raw_key = raw_key.public_key()
        return cls.import_key(raw_key, options=options)
134
135
class ECKey(Key):
    """Key class of the ``EC`` key type.

    Wraps a raw ``cryptography`` elliptic-curve key and converts between
    that raw key and a dict with ``crv``, ``x``, ``y`` (plus ``d`` for
    private keys).
    """

    kty = 'EC'
    # "crv" label -> curve class.
    DSS_CURVES = {
        'P-256': SECP256R1,
        'P-384': SECP384R1,
        'P-521': SECP521R1,
    }
    # Reverse lookup: curve name as reported by the curve class -> "crv" label.
    CURVES_DSS = {
        curve_cls.name: label for label, curve_cls in DSS_CURVES.items()
    }
    REQUIRED_JSON_FIELDS = ['crv', 'x', 'y']
    RAW_KEY_CLS = (EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization)

    def as_pem(self, is_private=False, password=None):
        """Serialize this key as PEM bytes.

        :param is_private: export the private key instead of the public one
        :param password: when given, encrypt the private key with it
        :return: bytes
        """
        pem = export_key(self, is_private=is_private, password=password)
        return pem

    def exchange_shared_key(self, pubkey):
        """Run an ECDH exchange between this private key and ``pubkey``.

        :param pubkey: the peer's raw public key
        :return: result of the raw key's ``exchange`` call
        :raise ValueError: if this key does not hold a private raw key
        """
        # used in ECDHAlgorithm
        own_key = self.raw_key
        if not isinstance(own_key, EllipticCurvePrivateKeyWithSerialization):
            raise ValueError('Invalid key for exchanging shared key')
        return own_key.exchange(ec.ECDH(), pubkey)

    @property
    def curve_key_size(self):
        """Key size reported by the raw key's curve."""
        curve = self.raw_key.curve
        return curve.key_size

    @classmethod
    def loads_private_key(cls, obj):
        """Build a raw EC private key from dict fields ``crv``, ``x``, ``y``, ``d``."""
        curve_cls = cls.DSS_CURVES[obj['crv']]
        curve = curve_cls()
        x = base64_to_int(obj['x'])
        y = base64_to_int(obj['y'])
        point = EllipticCurvePublicNumbers(x, y, curve)
        private_value = base64_to_int(obj['d'])
        numbers = EllipticCurvePrivateNumbers(private_value, point)
        return numbers.private_key(default_backend())

    @classmethod
    def loads_public_key(cls, obj):
        """Build a raw EC public key from dict fields ``crv``, ``x``, ``y``."""
        curve_cls = cls.DSS_CURVES[obj['crv']]
        curve = curve_cls()
        x = base64_to_int(obj['x'])
        y = base64_to_int(obj['y'])
        point = EllipticCurvePublicNumbers(x, y, curve)
        return point.public_key(default_backend())

    @classmethod
    def dumps_private_key(cls, raw_key):
        """Dump a raw EC private key into a dict with ``crv``, ``x``, ``y``, ``d``."""
        private_numbers = raw_key.private_numbers()
        point = private_numbers.public_numbers
        return {
            'crv': cls.CURVES_DSS[raw_key.curve.name],
            'x': int_to_base64(point.x),
            'y': int_to_base64(point.y),
            'd': int_to_base64(private_numbers.private_value),
        }

    @classmethod
    def dumps_public_key(cls, raw_key):
        """Dump a raw EC public key into a dict with ``crv``, ``x``, ``y``."""
        point = raw_key.public_numbers()
        crv = cls.CURVES_DSS[point.curve.name]
        return {
            'crv': crv,
            'x': int_to_base64(point.x),
            'y': int_to_base64(point.y)
        }

    @classmethod
    def import_key(cls, raw, options=None):
        """Import a key from PEM or dict data."""
        return import_key(
            cls,
            raw,
            EllipticCurvePublicKey,
            EllipticCurvePrivateKeyWithSerialization,
            b'ecdsa-sha2-',
            options,
        )

    @classmethod
    def generate_key(cls, crv='P-256', options=None, is_private=False):
        """Generate a fresh EC key on the curve labelled ``crv``.

        :param crv: one of the labels in ``DSS_CURVES``
        :param options: forwarded to :meth:`import_key`
        :param is_private: return the private key when true, otherwise
            only its public half
        :raise ValueError: if ``crv`` is not a known label
        """
        supported = cls.DSS_CURVES
        if crv not in supported:
            raise ValueError('Invalid crv value: "{}"'.format(crv))
        private_key = ec.generate_private_key(
            curve=supported[crv](),
            backend=default_backend(),
        )
        raw_key = private_key if is_private else private_key.public_key()
        return cls.import_key(raw_key, options=options)
235
236
def import_key(cls, raw, public_key_cls, private_key_cls, ssh_type=None, options=None):
    """Wrap ``raw`` in a key object of type ``cls``.

    ``raw`` is handled according to its type:

    * an instance of ``cls`` is updated with ``options`` and returned as is;
    * an instance of ``public_key_cls`` / ``private_key_cls`` is used
      directly as the raw key and dumped into a payload dict;
    * a dict is validated with ``cls.check_required_fields`` and loaded as a
      private key when it has a ``d`` entry, otherwise as a public key;
    * anything else is passed to ``load_pem_key`` together with ``ssh_type``
      and the ``password`` entry of ``options`` (if any).

    :param cls: key class to instantiate
    :param raw: key material in one of the forms above
    :param public_key_cls: raw public key class for this key type
    :param private_key_cls: raw private key class for this key type
    :param ssh_type: type marker forwarded to ``load_pem_key``
    :param options: optional dict of extra key parameters; ``password`` is
        only used for loading and is never stored on the key
    :return: instance of ``cls`` with ``raw_key`` and ``key_type`` set
    :raise ValueError: if the loaded object is neither a public nor a
        private key of the expected classes
    """
    if isinstance(raw, cls):
        if options is not None:
            raw.update(options)
        return raw

    payload = None
    if isinstance(raw, (public_key_cls, private_key_cls)):
        raw_key = raw
    elif isinstance(raw, dict):
        cls.check_required_fields(raw)
        payload = raw
        if 'd' in payload:
            raw_key = cls.loads_private_key(payload)
        else:
            raw_key = cls.loads_public_key(payload)
    else:
        password = options.get('password') if options is not None else None
        raw_key = load_pem_key(raw, ssh_type, password=password)

    # Test the private class first: the order of these checks decides the
    # reported key type if an object were to satisfy both.
    if isinstance(raw_key, private_key_cls):
        if payload is None:
            payload = cls.dumps_private_key(raw_key)
        key_type = 'private'
    elif isinstance(raw_key, public_key_cls):
        if payload is None:
            payload = cls.dumps_public_key(raw_key)
        key_type = 'public'
    else:
        raise ValueError('Invalid data for importing key')

    obj = cls(payload)
    if options:
        # Apply caller-supplied parameters to freshly built keys as well, so
        # they are not silently dropped on this path. The PEM passphrase is a
        # loading parameter, not key metadata, and must not end up in a dict
        # that may later be serialized.
        obj.update(
            (name, value) for name, value in options.items()
            if name != 'password'
        )
    obj.raw_key = raw_key
    obj.key_type = key_type
    return obj
275
276
def export_key(key, encoding=None, is_private=False, password=None):
    """Serialize ``key`` to bytes.

    :param key: key object exposing ``raw_key`` and ``key_type``
    :param encoding: ``'PEM'`` (also the default when ``None``) or ``'DER'``
    :param is_private: export the private key (PKCS8) instead of the public
        key (SubjectPublicKeyInfo)
    :param password: when given, the private key is encrypted with it;
        otherwise it is written unencrypted
    :return: bytes
    :raise ValueError: on an unknown ``encoding``, or when a private export
        is requested from a key whose ``key_type`` is not ``'private'``
    """
    # Validate the encoding before looking at the key at all.
    if encoding == 'DER':
        output_encoding = Encoding.DER
    elif encoding is None or encoding == 'PEM':
        output_encoding = Encoding.PEM
    else:
        raise ValueError('Invalid encoding: {!r}'.format(encoding))

    holds_private = key.key_type == 'private'

    if not is_private:
        # A private key can always hand out its public half.
        public_key = key.raw_key.public_key() if holds_private else key.raw_key
        return public_key.public_bytes(
            encoding=output_encoding,
            format=PublicFormat.SubjectPublicKeyInfo,
        )

    if not holds_private:
        raise ValueError('This is a public key')

    if password is not None:
        protection = BestAvailableEncryption(to_bytes(password))
    else:
        protection = NoEncryption()

    return key.raw_key.private_bytes(
        encoding=output_encoding,
        format=PrivateFormat.PKCS8,
        encryption_algorithm=protection,
    )
307