"""RSA and EC key classes for JOSE, backed by the ``cryptography`` package."""
from cryptography.hazmat.primitives.serialization import (
    Encoding, PrivateFormat, PublicFormat,
    BestAvailableEncryption, NoEncryption,
)
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import (
    RSAPublicKey, RSAPrivateKeyWithSerialization,
    RSAPrivateNumbers, RSAPublicNumbers,
    rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.ec import (
    EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization,
    EllipticCurvePrivateNumbers, EllipticCurvePublicNumbers,
    SECP256R1, SECP384R1, SECP521R1,
)
from cryptography.hazmat.backends import default_backend
from authlib.jose.rfc7517 import Key, load_pem_key
from authlib.common.encoding import to_bytes
from authlib.common.encoding import base64_to_int, int_to_base64


class RSAKey(Key):
    """Key class of the ``RSA`` key type.

    Holds a ``cryptography`` RSA public or private key object in
    ``raw_key`` alongside its JWK members: ``n`` and ``e``, plus ``d``,
    ``p``, ``q``, ``dp``, ``dq`` and ``qi`` for private keys.
    """

    kty = 'RSA'
    RAW_KEY_CLS = (RSAPublicKey, RSAPrivateKeyWithSerialization)
    REQUIRED_JSON_FIELDS = ['e', 'n']

    def as_pem(self, is_private=False, password=None):
        """Export key into PEM format bytes.

        :param is_private: export private key or public key
        :param password: encrypt private key with password
        :return: bytes
        """
        return export_key(self, is_private=is_private, password=password)

    @staticmethod
    def dumps_private_key(raw_key):
        """Convert an RSA private key object into a dict of JWK members.

        Every integer is encoded with ``int_to_base64``.
        """
        numbers = raw_key.private_numbers()
        return {
            'n': int_to_base64(numbers.public_numbers.n),
            'e': int_to_base64(numbers.public_numbers.e),
            'd': int_to_base64(numbers.d),
            'p': int_to_base64(numbers.p),
            'q': int_to_base64(numbers.q),
            'dp': int_to_base64(numbers.dmp1),
            'dq': int_to_base64(numbers.dmq1),
            'qi': int_to_base64(numbers.iqmp)
        }

    @staticmethod
    def dumps_public_key(raw_key):
        """Convert an RSA public key object into ``{'n': ..., 'e': ...}``."""
        numbers = raw_key.public_numbers()
        return {
            'n': int_to_base64(numbers.n),
            'e': int_to_base64(numbers.e)
        }

    @staticmethod
    def loads_private_key(obj):
        """Build an RSA private key object from a dict of JWK members.

        ``n``, ``e`` and ``d`` are always read. The CRT members ``p``, ``q``,
        ``dp``, ``dq`` and ``qi`` must be either all present or all absent.
        When they are absent, they are recomputed from ``n``, ``e`` and ``d``.

        :raises ValueError: if ``oth`` is present, or if only some of the CRT
            members are present
        """
        if 'oth' in obj:  # pragma: no cover
            # https://tools.ietf.org/html/rfc7518#section-6.3.2.7
            raise ValueError('"oth" is not supported yet')

        props = ['p', 'q', 'dp', 'dq', 'qi']
        props_found = [prop in obj for prop in props]
        any_props_found = any(props_found)

        # A partial set of CRT parameters is rejected rather than half-used.
        if any_props_found and not all(props_found):
            raise ValueError(
                'RSA key must include all parameters '
                'if any are present besides d')

        public_numbers = RSAPublicNumbers(
            base64_to_int(obj['e']), base64_to_int(obj['n']))

        if any_props_found:
            numbers = RSAPrivateNumbers(
                d=base64_to_int(obj['d']),
                p=base64_to_int(obj['p']),
                q=base64_to_int(obj['q']),
                dmp1=base64_to_int(obj['dp']),
                dmq1=base64_to_int(obj['dq']),
                iqmp=base64_to_int(obj['qi']),
                public_numbers=public_numbers)
        else:
            # Only d was supplied: recover the primes p and q from (n, e, d),
            # then derive the CRT exponents and coefficient from them.
            d = base64_to_int(obj['d'])
            p, q = rsa_recover_prime_factors(
                public_numbers.n, d, public_numbers.e)
            numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=public_numbers)

        return numbers.private_key(default_backend())

    @staticmethod
    def loads_public_key(obj):
        """Build an RSA public key object from the JWK members ``e`` and ``n``."""
        numbers = RSAPublicNumbers(
            base64_to_int(obj['e']),
            base64_to_int(obj['n'])
        )
        return numbers.public_key(default_backend())

    @classmethod
    def import_key(cls, raw, options=None):
        """Import a key from PEM or dict data."""
        return import_key(
            cls, raw,
            RSAPublicKey, RSAPrivateKeyWithSerialization,
            b'ssh-rsa', options
        )

    @classmethod
    def generate_key(cls, key_size=2048, options=None, is_private=False):
        """Generate a new RSA key with public exponent 65537.

        :param key_size: modulus size in bits; must be >= 512 and a
            multiple of 8
        :param options: forwarded to :meth:`import_key`
        :param is_private: return the private key when True, otherwise only
            its public half
        :raises ValueError: on an invalid ``key_size``
        """
        if key_size < 512:
            raise ValueError('key_size must not be less than 512')
        if key_size % 8 != 0:
            raise ValueError('Invalid key_size for RSAKey')
        raw_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )
        if not is_private:
            raw_key = raw_key.public_key()
        return cls.import_key(raw_key, options=options)


class ECKey(Key):
    """Key class of the ``EC`` key type.

    Supports the curves P-256, P-384 and P-521 (see ``DSS_CURVES``).
    """

    kty = 'EC'
    # JWK "crv" name -> cryptography curve class.
    DSS_CURVES = {
        'P-256': SECP256R1,
        'P-384': SECP384R1,
        'P-521': SECP521R1,
    }
    # cryptography curve name -> JWK "crv" name (inverse of DSS_CURVES).
    CURVES_DSS = {
        SECP256R1.name: 'P-256',
        SECP384R1.name: 'P-384',
        SECP521R1.name: 'P-521',
    }
    REQUIRED_JSON_FIELDS = ['crv', 'x', 'y']
    RAW_KEY_CLS = (EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization)

    def as_pem(self, is_private=False, password=None):
        """Export key into PEM format bytes.

        :param is_private: export private key or public key
        :param password: encrypt private key with password
        :return: bytes
        """
        return export_key(self, is_private=is_private, password=password)

    def exchange_shared_key(self, pubkey):
        """Perform an ECDH exchange with ``pubkey`` and return the shared secret.

        ``pubkey`` is passed straight to ``cryptography``'s ``exchange``;
        presumably it must be a raw EC public key object on the same curve.
        TODO confirm against ECDHAlgorithm.

        :raises ValueError: if this key is not a private key
        """
        # Used by ECDHAlgorithm.
        if isinstance(self.raw_key, EllipticCurvePrivateKeyWithSerialization):
            return self.raw_key.exchange(ec.ECDH(), pubkey)
        raise ValueError('Invalid key for exchanging shared key')

    @property
    def curve_key_size(self):
        """Key size (in bits) of the underlying curve."""
        return self.raw_key.curve.key_size

    @classmethod
    def loads_private_key(cls, obj):
        """Build an EC private key object from JWK members ``crv``, ``x``, ``y``, ``d``.

        An unsupported ``crv`` raises ``KeyError`` from the ``DSS_CURVES`` lookup.
        """
        curve = cls.DSS_CURVES[obj['crv']]()
        public_numbers = EllipticCurvePublicNumbers(
            base64_to_int(obj['x']),
            base64_to_int(obj['y']),
            curve,
        )
        private_numbers = EllipticCurvePrivateNumbers(
            base64_to_int(obj['d']),
            public_numbers
        )
        return private_numbers.private_key(default_backend())

    @classmethod
    def loads_public_key(cls, obj):
        """Build an EC public key object from JWK members ``crv``, ``x``, ``y``.

        An unsupported ``crv`` raises ``KeyError`` from the ``DSS_CURVES`` lookup.
        """
        curve = cls.DSS_CURVES[obj['crv']]()
        public_numbers = EllipticCurvePublicNumbers(
            base64_to_int(obj['x']),
            base64_to_int(obj['y']),
            curve,
        )
        return public_numbers.public_key(default_backend())

    @classmethod
    def dumps_private_key(cls, raw_key):
        """Convert an EC private key object into a dict of JWK members."""
        numbers = raw_key.private_numbers()
        return {
            'crv': cls.CURVES_DSS[raw_key.curve.name],
            'x': int_to_base64(numbers.public_numbers.x),
            'y': int_to_base64(numbers.public_numbers.y),
            'd': int_to_base64(numbers.private_value),
        }

    @classmethod
    def dumps_public_key(cls, raw_key):
        """Convert an EC public key object into ``{'crv', 'x', 'y'}`` members."""
        numbers = raw_key.public_numbers()
        return {
            'crv': cls.CURVES_DSS[numbers.curve.name],
            'x': int_to_base64(numbers.x),
            'y': int_to_base64(numbers.y)
        }

    @classmethod
    def import_key(cls, raw, options=None):
        """Import a key from PEM or dict data."""
        return import_key(
            cls, raw,
            EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization,
            b'ecdsa-sha2-', options
        )

    @classmethod
    def generate_key(cls, crv='P-256', options=None, is_private=False):
        """Generate a new EC key on the named curve.

        :param crv: one of the ``DSS_CURVES`` names
        :param options: forwarded to :meth:`import_key`
        :param is_private: return the private key when True, otherwise only
            its public half
        :raises ValueError: if ``crv`` is not supported
        """
        if crv not in cls.DSS_CURVES:
            raise ValueError('Invalid crv value: "{}"'.format(crv))
        raw_key = ec.generate_private_key(
            curve=cls.DSS_CURVES[crv](),
            backend=default_backend(),
        )
        if not is_private:
            raw_key = raw_key.public_key()
        return cls.import_key(raw_key, options=options)


def import_key(cls, raw, public_key_cls, private_key_cls, ssh_type=None, options=None):
    """Import a key of class ``cls`` from one of several input forms.

    :param cls: the ``Key`` subclass to instantiate (e.g. ``RSAKey``)
    :param raw: an existing ``cls`` instance, a ``cryptography`` key object,
        a dict of JWK members, or PEM/SSH data handed to ``load_pem_key``
    :param public_key_cls: ``cryptography`` public key type for this kty
    :param private_key_cls: ``cryptography`` private key type for this kty
    :param ssh_type: SSH key prefix forwarded to ``load_pem_key``
    :param options: optional dict. ``options['password']`` is used to
        decrypt PEM input. For an existing ``cls`` instance, the remaining
        options are merged into it.
    :return: a ``cls`` instance with ``raw_key`` and ``key_type`` set
    :raises ValueError: if the loaded key is not of the expected public or
        private type
    """
    if isinstance(raw, cls):
        if options is not None:
            # 'password' only decrypts PEM input; copying it into the key
            # would leak the secret into the serialized JWK.
            raw.update({k: v for k, v in options.items() if k != 'password'})
        return raw

    payload = None
    if isinstance(raw, (public_key_cls, private_key_cls)):
        raw_key = raw
    elif isinstance(raw, dict):
        cls.check_required_fields(raw)
        payload = raw
        # The presence of "d" (the private exponent/value) marks a private key.
        if 'd' in payload:
            raw_key = cls.loads_private_key(payload)
        else:
            raw_key = cls.loads_public_key(payload)
    else:
        if options is not None:
            password = options.get('password')
        else:
            password = None
        raw_key = load_pem_key(raw, ssh_type, password=password)

    # NOTE(review): for freshly imported keys, options other than 'password'
    # are not applied to the payload, unlike the existing-instance branch
    # above. Confirm this is intended.
    if isinstance(raw_key, private_key_cls):
        if payload is None:
            payload = cls.dumps_private_key(raw_key)
        key_type = 'private'
    elif isinstance(raw_key, public_key_cls):
        if payload is None:
            payload = cls.dumps_public_key(raw_key)
        key_type = 'public'
    else:
        raise ValueError('Invalid data for importing key')

    obj = cls(payload)
    obj.raw_key = raw_key
    obj.key_type = key_type
    return obj


def export_key(key, encoding=None, is_private=False, password=None):
    """Serialize ``key.raw_key`` to PEM or DER bytes.

    :param key: a key object with ``raw_key`` and ``key_type`` attributes
    :param encoding: ``None``/``'PEM'``/``Encoding.PEM`` or
        ``'DER'``/``Encoding.DER``
    :param is_private: export the private key (PKCS#8) instead of the
        public key (SubjectPublicKeyInfo)
    :param password: if given, encrypt the private key export with it
    :return: bytes
    :raises ValueError: on an unknown encoding, or when a private export is
        requested for a public key
    """
    # Accept both the string names and the cryptography Encoding members.
    if encoding is None or encoding in ('PEM', Encoding.PEM):
        encoding = Encoding.PEM
    elif encoding in ('DER', Encoding.DER):
        encoding = Encoding.DER
    else:
        raise ValueError('Invalid encoding: {!r}'.format(encoding))

    if is_private:
        if key.key_type == 'private':
            if password is None:
                encryption_algorithm = NoEncryption()
            else:
                encryption_algorithm = BestAvailableEncryption(to_bytes(password))
            return key.raw_key.private_bytes(
                encoding=encoding,
                format=PrivateFormat.PKCS8,
                encryption_algorithm=encryption_algorithm,
            )
        raise ValueError('This is a public key')

    # A public export of a private key uses its derived public half.
    if key.key_type == 'private':
        raw_key = key.raw_key.public_key()
    else:
        raw_key = key.raw_key

    return raw_key.public_bytes(
        encoding=encoding,
        format=PublicFormat.SubjectPublicKeyInfo,
    )