# from future import standard_library
# standard_library.install_aliases()
"""JSON Web Encryption (JWE) support.

Compact-serialization JWE encryption/decryption using RSA key transport
(RSA1_5, RSA-OAEP) or AES key wrap (A128KW, A192KW, A256KW) for the content
encryption key, and AES-GCM / AES-CBC-HMAC for the content itself.
``JWE_EC`` only implements ECDH-ES key setup and is not dispatched to by
``JWE.encrypt``/``JWE.decrypt``.
"""
try:
    from builtins import object
except ImportError:
    pass

import struct
import io
import logging
import zlib
import six

from Crypto import Random
from Crypto.Hash import SHA
from Crypto.Util.number import bytes_to_long
from Crypto.Util.number import long_to_bytes
from Crypto.Cipher import PKCS1_v1_5
from Crypto.Cipher import PKCS1_OAEP

from jwkest import b64d, as_bytes
from jwkest import b64e
from jwkest import JWKESTException
from jwkest import MissingKey
from jwkest.aes_gcm import AES_GCM
from jwkest.aes_key_wrap import aes_wrap_key
from jwkest.aes_key_wrap import aes_unwrap_key
from jwkest.ecc import NISTEllipticCurve
from jwkest.extra import aes_cbc_hmac_encrypt
from jwkest.extra import ecdh_derive_key
from jwkest.extra import aes_cbc_hmac_decrypt
from jwkest.jwk import intarr2str
from jwkest.jwk import ECKey
from jwkest.jws import JWx
from jwkest.jwt import JWT, b64encode_item

logger = logging.getLogger(__name__)

__author__ = 'rohe0002'

ENC = 1
DEC = 0


class JWEException(JWKESTException):
    """Base class for the JWE errors raised by this module."""


class CannotDecode(JWEException):
    pass


class NotSupportedAlgorithm(JWEException):
    """Raised for an unsupported "alg" or "enc" value."""


class MethodNotSupported(JWEException):
    pass


class ParameterError(JWEException):
    """Raised when a header parameter has an invalid value (e.g. "zip")."""


class NoSuitableEncryptionKey(JWEException):
    """Raised when no key usable for encryption could be found."""


class NoSuitableDecryptionKey(JWEException):
    """Raised when no key usable for decryption could be found."""


class DecryptionFailed(JWEException):
    """Raised when decryption or authentication-tag verification fails."""


class WrongEncryptionAlgorithm(JWEException):
    """Raised when the token's "alg" differs from the one the caller expects."""


# ---------------------------------------------------------------------------
# Base class

# Derived key length per "enc" value, used by JWE_EC for direct ECDH-ES key
# derivation. Values appear to be in bits (they are 8x ENCALGLEN1/2 below).
KEYLEN = {
    "A128GCM": 128,
    "A192GCM": 192,
    "A256GCM": 256,
    "A128CBC-HS256": 256,
    "A192CBC-HS384": 384,
    "A256CBC-HS512": 512
}


class Encrypter(object):
    """Abstract base class for encryption algorithms."""

    def __init__(self, with_digest=False):
        # When True, RSAEncrypter's PKCS#1 v1.5 path appends (on encrypt) and
        # verifies/strips (on decrypt) a SHA digest of the plaintext.
        self.with_digest = with_digest

    def encrypt(self, msg, key):
        """Encrypt ``msg`` with ``key`` and return the encrypted message."""
        raise NotImplementedError

    def decrypt(self, msg, key):
        """Return decrypted message."""
        raise NotImplementedError


class RSAEncrypter(Encrypter):
    """RSA encryption with PKCS#1 v1.5 or OAEP padding."""

    def encrypt(self, msg, key, padding="pkcs1_padding"):
        """Encrypt ``msg`` with the RSA ``key``.

        :param msg: Plain text bytes (in this module, a content encryption key)
        :param key: RSA key object accepted by PKCS1_v1_5.new / PKCS1_OAEP.new
        :param padding: "pkcs1_padding" (PKCS#1 v1.5) or "pkcs1_oaep_padding"
        :return: The RSA ciphertext
        :raises Exception: On an unknown padding name
        """
        if padding == "pkcs1_padding":
            cipher = PKCS1_v1_5.new(key)
            if self.with_digest:  # add a SHA digest to the message
                h = SHA.new(msg)
                msg += h.digest()
        elif padding == "pkcs1_oaep_padding":
            cipher = PKCS1_OAEP.new(key)
        else:
            raise Exception("Unsupported padding")
        return cipher.encrypt(msg)

    def decrypt(self, ciphertext, key, padding="pkcs1_padding"):
        """Decrypt ``ciphertext`` with the RSA ``key``.

        :param ciphertext: RSA ciphertext
        :param key: RSA key object accepted by PKCS1_v1_5.new / PKCS1_OAEP.new
        :param padding: "pkcs1_padding" (PKCS#1 v1.5) or "pkcs1_oaep_padding"
        :return: The decrypted bytes (digest stripped when ``with_digest``)
        :raises DecryptionFailed: If PKCS#1 v1.5 decryption or the digest
            check fails
        :raises Exception: On an unknown padding name
        """
        if padding == "pkcs1_padding":
            cipher = PKCS1_v1_5.new(key)
            if self.with_digest:
                dsize = SHA.digest_size
            else:
                dsize = 0
            # PKCS1_v1_5.decrypt returns the sentinel instead of raising when
            # the padding is invalid; a random sentinel is detected below.
            sentinel = Random.new().read(32 + dsize)
            text = cipher.decrypt(ciphertext, sentinel)
            if dsize:
                _digest = text[-dsize:]
                _msg = text[:-dsize]
                digest = SHA.new(_msg).digest()
                if digest == _digest:
                    text = _msg
                else:
                    raise DecryptionFailed()
            else:
                if text == sentinel:
                    raise DecryptionFailed()
        elif padding == "pkcs1_oaep_padding":
            cipher = PKCS1_OAEP.new(key)
            text = cipher.decrypt(ciphertext)
        else:
            raise Exception("Unsupported padding")

        return text


# ---------------------------------------------------------------------------


def int2bigendian(n):
    """Return ``n`` packed as a 4-byte big-endian unsigned int, as a list of
    byte values (ints).

    ``bytearray`` yields ints on both Python 2 and 3; calling ``ord()`` on
    the elements of ``bytes`` fails on Python 3.
    """
    return list(bytearray(struct.pack('>I', n)))


def party_value(pv):
    """Encode a party info value as length-prefixed base64url bytes.

    :return: A list holding a 4-byte big-endian length followed by the
        base64url encoding of ``pv``, or four zero bytes if ``pv`` is empty.
    """
    if pv:
        s = b64e(pv)
        r = int2bigendian(len(s))
        r.extend(s)
        return r
    else:
        return [0, 0, 0, 0]


def _hash_input(cmk, enc, label, rond=1, length=128, hashsize=256,
                epu="", epv=""):
    """Assemble a key-derivation hash input as a list of byte values.

    The layout is: round counter, ``cmk``, length, ``enc``, the two party
    values and ``label``. ``hashsize`` is accepted but not used.
    """
    r = [0, 0, 0, rond]
    r.extend(cmk)
    r.extend([0, 0, 0, length])
    r.extend([ord(c) for c in enc])
    r.extend(party_value(epu))
    r.extend(party_value(epv))
    r.extend(label)
    return r


# ---------------------------------------------------------------------------

def cipher_filter(cipher, inf, outf):
    """Stream ``inf`` through ``cipher`` into ``outf``.

    :param cipher: An object exposing ``update(data)`` and ``final()``
    :param inf: Readable file-like input
    :param outf: Writable in-memory buffer (must support ``getvalue()``)
    :return: Everything written to ``outf``
    """
    while True:
        buf = inf.read()
        if not buf:
            break
        outf.write(cipher.update(buf))
    outf.write(cipher.final())
    return outf.getvalue()


def aes_enc(key, txt):
    """Encrypt ``txt`` with the cipher object ``key`` (see cipher_filter).

    Cipher output is binary, so byte buffers are used; ``txt`` is converted
    to bytes first if it is text.
    """
    with io.BytesIO(as_bytes(txt)) as pbuf, io.BytesIO() as cbuf:
        return cipher_filter(key, pbuf, cbuf)


def aes_dec(key, ciptxt):
    """Decrypt the bytes ``ciptxt`` with the cipher object ``key``."""
    with io.BytesIO() as pbuf, io.BytesIO(ciptxt) as cbuf:
        return cipher_filter(key, cbuf, pbuf)


def keysize(spec):
    """Return the size encoded in an algorithm name, e.g. 256 for "HS256"
    or 128 for "A128KW"; 0 if the prefix is not recognised."""
    if spec.startswith("HS"):
        return int(spec[2:])
    elif spec.startswith("CS"):
        return int(spec[2:])
    elif spec.startswith("A"):
        return int(spec[1:4])
    return 0


ENC2ALG = {"A128CBC": "aes_128_cbc", "A192CBC": "aes_192_cbc",
           "A256CBC": "aes_256_cbc"}

SUPPORTED = {
    "alg": ["RSA1_5", "RSA-OAEP", "A128KW", "A192KW", "A256KW",
            "ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"],
    "enc": ["A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512",
            # "A128GCM", "A192GCM",
            "A256GCM"],
}


def alg2keytype(alg):
    """Map a JWE "alg" value to a JWK key type ("RSA", "oct", "EC")."""
    if alg.startswith("RSA"):
        return "RSA"
    elif alg.startswith("A"):
        return "oct"
    elif alg.startswith("ECDH"):
        return "EC"
    else:
        return None


# =============================================================================

# Content encryption key length in bytes for the GCM "enc" values.
ENCALGLEN1 = {
    "A128GCM": 16,
    "A192GCM": 24,
    "A256GCM": 32
}

# Content encryption key length in bytes for the CBC-HMAC "enc" values.
ENCALGLEN2 = {
    "A128CBC-HS256": 32,
    "A192CBC-HS384": 48,
    "A256CBC-HS512": 64,
}


class JWEnc(JWT):
    """A JWE token: accessors for its five compact-serialization parts."""

    def b64_protected_header(self):
        return self.b64part[0]

    def b64_encrypted_key(self):
        return self.b64part[1]

    def b64_initialization_vector(self):
        return self.b64part[2]

    def b64_ciphertext(self):
        return self.b64part[3]

    def b64_authentication_tag(self):
        return self.b64part[4]

    def protected_header(self):
        return self.part[0]

    def encrypted_key(self):
        return self.part[1]

    def initialization_vector(self):
        return self.part[2]

    def ciphertext(self):
        return self.part[3]

    def authentication_tag(self):
        return self.part[4]

    def b64_encode_header(self):
        return b64encode_item(self.headers)

    def is_jwe(self):
        """Return True if the headers describe a JWE this module supports.

        A "typ" of "jwe" (case-insensitive) is accepted outright. Otherwise
        both "alg" and "enc" must be present and listed in SUPPORTED.
        Explicit checks are used rather than ``assert`` so the validation
        still runs under ``python -O``.
        """
        if "typ" in self.headers and self.headers["typ"].lower() == "jwe":
            return True

        if "alg" not in self.headers or "enc" not in self.headers:
            return False

        for typ in ["alg", "enc"]:
            if self.headers[typ] not in SUPPORTED[typ]:
                logger.debug("Not supported %s algorithm: %s",
                             typ, self.headers[typ])
                return False
        return True


class JWe(JWx):
    """Shared content-encryption helpers for the JWE key-management classes."""

    @staticmethod
    def _generate_key_and_iv(encalg, cek="", iv=""):
        """Return ``(cek, iv)``, generating whichever was not supplied.

        GCM algorithms get a 12-byte IV and CBC-HMAC algorithms a 16-byte IV.
        The key length comes from ENCALGLEN1/ENCALGLEN2.

        :raises Exception: If ``encalg`` is unknown and a value must be
            generated
        """
        if cek and iv:
            return cek, iv

        try:
            _key = Random.get_random_bytes(ENCALGLEN1[encalg])
            _iv = Random.get_random_bytes(12)
        except KeyError:
            try:
                _key = Random.get_random_bytes(ENCALGLEN2[encalg])
                _iv = Random.get_random_bytes(16)
            except KeyError:
                raise Exception("Unsupported encryption algorithm %s" % encalg)
        if cek:
            _key = cek
        if iv:
            _iv = iv

        return _key, _iv

    def alg2keytype(self, alg):
        return alg2keytype(alg)

    def enc_setup(self, enc_alg, msg, auth_data, key=None, iv=""):
        """ Encrypt JWE content.

        :param enc_alg: The JWE "enc" value specifying the encryption
            algorithm. A256GCM and the three CBC-HMAC values are supported.
        :param msg: The plain text message (bytes)
        :param auth_data: Additional authenticated data
        :param key: Key (CEK); generated when not given
        :param iv: Initialization vector; generated when not given
        :return: Tuple (ciphertext, tag, cek)
        :raises NotSupportedAlgorithm: For any other ``enc_alg``
        """

        key, iv = self._generate_key_and_iv(enc_alg, key, iv)

        if enc_alg == "A256GCM":
            gcm = AES_GCM(bytes_to_long(key))
            ctxt, tag = gcm.encrypt(bytes_to_long(iv), msg, auth_data)
            tag = long_to_bytes(tag)
        elif enc_alg in ["A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512"]:
            ctxt, tag = aes_cbc_hmac_encrypt(key, iv, auth_data, msg)
        else:
            raise NotSupportedAlgorithm(enc_alg)

        return ctxt, tag, key

    @staticmethod
    def _decrypt(enc, key, ctxt, auth_data, iv, tag):
        """ Decrypt JWE content.

        :param enc: The JWE "enc" value specifying the encryption algorithm
        :param key: Key (CEK)
        :param ctxt : Ciphertext
        :param auth_data: Additional authenticated data (AAD)
        :param iv : Initialization vector
        :param tag: Authentication tag
        :return: Tuple (plain text message, success flag). For GCM this is
            (None, False) when DecryptionFailed is caught. For CBC-HMAC the
            tuple comes from aes_cbc_hmac_decrypt.
        :raises Exception: For an unsupported ``enc``
        """
        if enc in ["A128GCM", "A192GCM", "A256GCM"]:
            gcm = AES_GCM(bytes_to_long(key))
            try:
                text = gcm.decrypt(bytes_to_long(iv), ctxt, bytes_to_long(tag),
                                   auth_data)
                return text, True
            except DecryptionFailed:
                # NOTE(review): confirm AES_GCM.decrypt raises DecryptionFailed
                # on a tag mismatch. If it raises its own exception type,
                # that error propagates instead of returning (None, False).
                return None, False
        elif enc in ["A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512"]:
            return aes_cbc_hmac_decrypt(key, iv, auth_data, ctxt, tag)
        else:
            raise Exception("Unsupported encryption algorithm %s" % enc)


class JWE_SYM(JWe):
    """JWE with AES key wrap (A*KW) of the content encryption key."""

    args = JWe.args[:]
    args.append("enc")

    def encrypt(self, key, iv="", cek="", **kwargs):
        """Produce a JWE whose CEK is AES-key-wrapped with ``key``.

        :param key: Shared symmetric key, as bytes or as an integer array
            converted with intarr2str
        :param iv: initialization vector; generated when empty
        :param cek: content encryption key; generated when empty
        :param kwargs: Extra keyword arguments; only "kid" is used
        :return: The JWE in compact serialization
        :raises ParameterError: If "zip" has a value other than "DEF"
        """
        # Same payload handling as JWE_RSA.encrypt: decrypt() honours
        # zip=DEF, so encrypt must actually compress.
        _msg = as_bytes(self.msg)
        if "zip" in self:
            if self["zip"] == "DEF":
                _msg = zlib.compress(_msg)
            else:
                raise ParameterError("Zip has unknown value: %s" % self["zip"])

        _args = self._dict
        try:
            _args["kid"] = kwargs["kid"]
        except KeyError:
            pass

        jwe = JWEnc(**_args)

        # If no iv and cek are given generate them
        cek, iv = self._generate_key_and_iv(self["enc"], cek, iv)
        if isinstance(key, six.binary_type):
            kek = key
        else:
            kek = intarr2str(key)

        # The iv for this function must be 64 bit
        # Which is certainly different from the one above
        jek = aes_wrap_key(kek, cek)

        _enc = self["enc"]

        ctxt, tag, cek = self.enc_setup(_enc, _msg,
                                        jwe.b64_encode_header(),
                                        cek, iv=iv)
        return jwe.pack(parts=[jek, iv, ctxt, tag])

    def decrypt(self, token, key=None, cek=None):
        """Decrypt a JWE whose CEK was AES-key-wrapped.

        :param token: The JWE in compact serialization
        :param key: Key-encryption key used to unwrap the CEK
        :param cek: The CEK itself; when given, ``key`` is not needed
        :return: The decrypted (and, for zip=DEF, decompressed) message
        :raises MissingKey: If neither ``key`` nor ``cek`` is given
        :raises DecryptionFailed: If content decryption or tag
            verification fails
        """
        if not key and not cek:
            raise MissingKey("On of key or cek must be specified")

        jwe = JWEnc().unpack(token)

        if not cek:
            jek = jwe.encrypted_key()
            # The iv for this function must be 64 bit
            cek = aes_unwrap_key(key, jek)

        # _decrypt returns (msg, flag). The flag must be checked, otherwise a
        # forged or corrupted token would be accepted unauthenticated.
        msg, flag = self._decrypt(
            jwe.headers["enc"], cek, jwe.ciphertext(),
            jwe.b64_protected_header(),
            jwe.initialization_vector(), jwe.authentication_tag())
        if not flag:
            raise DecryptionFailed()

        # The token's own header decides compression, as in JWE_RSA.decrypt.
        if "zip" in jwe.headers and jwe.headers["zip"] == "DEF":
            msg = zlib.decompress(msg)

        return msg


class JWE_RSA(JWe):
    """JWE with RSA (RSA1_5 or RSA-OAEP) encryption of the CEK."""

    args = ["msg", "alg", "enc", "epk", "zip", "jku", "jwk", "x5u", "x5t",
            "x5c", "kid", "typ", "cty", "apu", "crit"]

    def encrypt(self, key, iv="", cek="", **kwargs):
        """
        Produces a JWE using RSA algorithms

        :param key: RSA key
        :param iv: initialization vector; generated when empty
        :param cek: content encryption key; generated when empty
        :return: A jwe in compact serialization
        :raises ParameterError: If "zip" has a value other than "DEF"
        :raises NotSupportedAlgorithm: If "alg" is not RSA-OAEP or RSA1_5
        """

        _msg = as_bytes(self.msg)
        if "zip" in self:
            if self["zip"] == "DEF":
                _msg = zlib.compress(_msg)
            else:
                raise ParameterError("Zip has unknown value: %s" % self["zip"])

        _enc = self["enc"]
        cek, iv = self._generate_key_and_iv(_enc, cek, iv)
        # The CEK is secret key material and must never be written to logs.

        _encrypt = RSAEncrypter(self.with_digest).encrypt

        _alg = self["alg"]
        if _alg == "RSA-OAEP":
            jwe_enc_key = _encrypt(cek, key, 'pkcs1_oaep_padding')
        elif _alg == "RSA1_5":
            jwe_enc_key = _encrypt(cek, key)
        else:
            raise NotSupportedAlgorithm(_alg)

        jwe = JWEnc(**self.headers())

        enc_header = jwe.b64_encode_header()

        ctxt, tag, _ = self.enc_setup(_enc, _msg, enc_header, cek, iv)
        return jwe.pack(parts=[jwe_enc_key, iv, ctxt, tag])

    def decrypt(self, token, key):
        """ Decrypts a JWT

        :param token: The JWT
        :param key: A key to use for decrypting
        :return: The decrypted message
        :raises NotSupportedAlgorithm: For an unsupported "alg" or "enc"
        :raises DecryptionFailed: If the CEK or the content cannot be
            decrypted/authenticated
        """
        jwe = JWEnc().unpack(token)
        # NOTE(review): this stores the encrypted key, not the JWEnc object
        # (factory() assigns a JWEnc to .jwt); kept for compatibility.
        self.jwt = jwe.encrypted_key()
        jek = jwe.encrypted_key()

        _decrypt = RSAEncrypter(self.with_digest).decrypt

        _alg = jwe.headers["alg"]
        if _alg == "RSA-OAEP":
            cek = _decrypt(jek, key, 'pkcs1_oaep_padding')
        elif _alg == "RSA1_5":
            cek = _decrypt(jek, key)
        else:
            raise NotSupportedAlgorithm(_alg)

        enc = jwe.headers["enc"]
        # Explicit check: an ``assert`` would be skipped under ``python -O``.
        if enc not in SUPPORTED["enc"]:
            raise NotSupportedAlgorithm(enc)

        msg, flag = self._decrypt(enc, cek, jwe.ciphertext(),
                                  jwe.b64_protected_header(),
                                  jwe.initialization_vector(),
                                  jwe.authentication_tag())
        if not flag:
            raise DecryptionFailed()

        if "zip" in jwe.headers and jwe.headers["zip"] == "DEF":
            msg = zlib.decompress(msg)

        return msg


class JWE_EC(JWe):
    """ECDH-ES key agreement setup (not dispatched to by JWE)."""

    def enc_setup(self, msg, auth_data, key=None, **kwargs):
        """Derive or wrap a CEK using ECDH-ES with the recipient's ``key``.

        :param msg: Unused
        :param auth_data: Unused
        :param key: Recipient EC key; its ``crv`` selects the curve
        :param kwargs: Optional "apu"/"apv" (base64url) and "epk"
        :return: Tuple (cek, encrypted_key, iv, params), where params holds
            the base64url "apu" and "apv"
        :raises Exception: For an unknown enc key length or unsupported alg
        """

        encrypted_key = ""
        # Generate the input parameters. Caller-supplied values arrive
        # base64url-encoded. Generated values are raw random bytes and must
        # NOT be base64-decoded.
        try:
            apu = b64d(kwargs["apu"])
        except KeyError:
            apu = Random.get_random_bytes(16)
        try:
            apv = b64d(kwargs["apv"])
        except KeyError:
            apv = Random.get_random_bytes(16)

        # Generate an ephemeral key pair
        curve = NISTEllipticCurve.by_name(key.crv)
        if "epk" in kwargs:
            eprivk = ECKey(kwargs["epk"])
        else:
            (eprivk, epk) = curve.key_pair()
        # NOTE(review): the ephemeral public key is not returned in params;
        # the recipient needs it as the "epk" header to derive the key.
        params = {
            "apu": b64e(apu),
            "apv": b64e(apv),
        }

        cek, iv = self._generate_key_and_iv(self.enc)
        if self.alg == "ECDH-ES":
            try:
                dk_len = KEYLEN[self.enc]
            except KeyError:
                raise Exception(
                    "Unknown key length for algorithm %s" % self.enc)

            cek = ecdh_derive_key(curve, eprivk, key, apu, apv, self.enc,
                                  dk_len)
        elif self.alg in ["ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]:
            _pre, _post = self.alg.split("+")
            klen = int(_post[1:4])
            kek = ecdh_derive_key(curve, eprivk, key, apu, apv, _post, klen)
            encrypted_key = aes_wrap_key(kek, cek)
        else:
            raise Exception("Unsupported algorithm %s" % self.alg)

        return cek, encrypted_key, iv, params


class JWE(JWx):
    """Encrypt/decrypt JWEs, dispatching on the "alg" header.

    RSA1_5 and RSA-OAEP are handled by JWE_RSA, and the AES key wrap family
    ("A...KW") by JWE_SYM.

    :param msg: The message
    :param alg: Algorithm
    :param enc: Encryption Method
    :param epk: Ephemeral Public Key
    :param zip: Compression Algorithm
    :param jku: a URI that refers to a resource for a set of JSON-encoded
        public keys, one of which corresponds to the key to which the JWE
        is encrypted
    :param jwk: A JSON Web Key corresponding to the key to which the JWE
        is encrypted
    :param x5u: a URI that refers to a resource for the X.509 public key
        certificate or certificate chain [RFC5280] corresponding to the key
        to which the JWE is encrypted.
    :param x5t: a base64url encoded SHA-1 thumbprint (a.k.a. digest) of the
        DER encoding of the X.509 certificate [RFC5280] corresponding to
        the key to which the JWE is encrypted.
    :param x5c: the X.509 public key certificate or certificate chain
        corresponding to the key to which the JWE is encrypted.
    :param kid: Key ID, a hint indicating which key was used to encrypt
        the JWE.
    :param typ: the type of this object, e.g. 'JWE' (JWEnc.is_jwe compares
        it case-insensitively)
    :param cty: Content Type
    :param apu: Agreement PartyUInfo
    :param crit: indicates which extensions that are being used and MUST
        be understood and processed.
    :return: A class instance
    """

    args = ["alg", "enc", "epk", "zip", "jku", "jwk", "x5u", "x5t",
            "x5c", "kid", "typ", "cty", "apu", "crit"]

    def encrypt(self, keys=None, cek="", iv="", **kwargs):
        """Encrypt ``self.msg`` with the first suitable key.

        :param keys: A set of possibly usable keys; if empty, the keys from
            ``self._get_keys()`` are used
        :param cek: Content master key
        :param iv: Initialization vector
        :param kwargs: Extra key word arguments
        :return: Encrypted message
        :raises NotSupportedAlgorithm: For an unsupported "alg"
        :raises NoSuitableEncryptionKey: If no key can be used
        """
        _alg = self["alg"]
        if _alg in ["RSA-OAEP", "RSA1_5"]:
            encrypter = JWE_RSA(self.msg, **self._dict)
        elif _alg.startswith("A") and _alg.endswith("KW"):
            encrypter = JWE_SYM(self.msg, **self._dict)
        else:
            logger.error("'%s' is not a supported algorithm", _alg)
            raise NotSupportedAlgorithm(_alg)

        if keys:
            keys = self._pick_keys(keys, use="enc")
        else:
            keys = self._pick_keys(self._get_keys(), use="enc")

        if not keys:
            logger.error(
                "Could not find any suitable encryption key for alg='%s'",
                _alg)
            raise NoSuitableEncryptionKey(_alg)

        if cek:
            kwargs["cek"] = cek
        if iv:
            kwargs["iv"] = iv

        # Only the first picked key is used; errors from the encrypter
        # propagate to the caller unchanged.
        for key in keys:
            _key = key.encryption_key(alg=_alg, private=True)

            if key.kid:
                encrypter["kid"] = key.kid

            token = encrypter.encrypt(_key, **kwargs)
            logger.debug("Encrypted message using key with kid=%s", key.kid)
            return token

        logger.error("Could not find any suitable encryption key")
        raise NoSuitableEncryptionKey()

    def decrypt(self, token, keys=None, alg=None):
        """Decrypt ``token``, trying each suitable key in turn.

        :param token: The JWE in compact serialization
        :param keys: Candidate keys; if empty, ``self._get_keys()`` is used
        :param alg: If given, the token's "alg" must equal it
        :return: The decrypted message
        :raises WrongEncryptionAlgorithm: If ``alg`` does not match
        :raises NotSupportedAlgorithm: For an unsupported "alg"
        :raises NoSuitableDecryptionKey: If no candidate key is found
        :raises DecryptionFailed: If no candidate key decrypts the token
        """
        jwe = JWEnc().unpack(token)

        _alg = jwe.headers["alg"]
        if alg and alg != _alg:
            raise WrongEncryptionAlgorithm()

        if _alg in ["RSA-OAEP", "RSA1_5"]:
            decrypter = JWE_RSA(**self._dict)
        elif _alg.startswith("A") and _alg.endswith("KW"):
            decrypter = JWE_SYM(self.msg, **self._dict)
        else:
            raise NotSupportedAlgorithm(_alg)

        if keys:
            keys = self._pick_keys(keys, use="enc", alg=_alg)
        else:
            keys = self._pick_keys(self._get_keys(), use="enc", alg=_alg)

        if not keys:
            raise NoSuitableDecryptionKey(_alg)

        _token = as_bytes(token)
        for key in keys:
            # NOTE(review): private=False here vs private=True in encrypt()
            # looks inverted; confirm how encryption_key uses the flag.
            _key = key.encryption_key(alg=_alg, private=False)
            try:
                msg = decrypter.decrypt(_token, _key)
            except (KeyError, DecryptionFailed):
                # Best effort: this key did not work, try the next one.
                logger.debug(
                    "Could not decrypt message using key with kid=%s", key.kid)
            else:
                logger.debug(
                    "Decrypted message using key with kid=%s", key.kid)
                return msg

        raise DecryptionFailed(
            "No available key that could decrypt the message")


def factory(token):
    """Return a JWE instance wrapping ``token`` if it looks like a supported
    JWE (see JWEnc.is_jwe), otherwise None."""
    _jwt = JWEnc().unpack(token)
    if _jwt.is_jwe():
        _jwe = JWE()
        _jwe.jwt = _jwt
        return _jwe
    else:
        return None