1# from future import standard_library
2# standard_library.install_aliases()
3try:
4    from builtins import object
5except ImportError:
6    pass
7
8import struct
9import io
10import logging
11import zlib
12import six
13
14from Crypto import Random
15from Crypto.Hash import SHA
16from Crypto.Util.number import bytes_to_long
17from Crypto.Util.number import long_to_bytes
18from Crypto.Cipher import PKCS1_v1_5
19from Crypto.Cipher import PKCS1_OAEP
20
21from jwkest import b64d, as_bytes
22from jwkest import b64e
23from jwkest import JWKESTException
24from jwkest import MissingKey
25from jwkest.aes_gcm import AES_GCM
26from jwkest.aes_key_wrap import aes_wrap_key
27from jwkest.aes_key_wrap import aes_unwrap_key
28from jwkest.ecc import NISTEllipticCurve
29from jwkest.extra import aes_cbc_hmac_encrypt
30from jwkest.extra import ecdh_derive_key
31from jwkest.extra import aes_cbc_hmac_decrypt
32from jwkest.jwk import intarr2str
33from jwkest.jwk import ECKey
34from jwkest.jws import JWx
35from jwkest.jwt import JWT, b64encode_item
36
logger = logging.getLogger(__name__)

__author__ = 'rohe0002'

# Module-level flags named after the two directions (ENC=1, DEC=0).
# They are not referenced anywhere else in this module.
ENC, DEC = 1, 0
43
44
class JWEException(JWKESTException):
    """Root of the JWE-specific exceptions defined in this module."""


class CannotDecode(JWEException):
    """A JWE could not be decoded (not raised within this module itself)."""


class NotSupportedAlgorithm(JWEException):
    """An "alg" or "enc" value is not supported by this module."""


class MethodNotSupported(JWEException):
    """A requested method is not supported (not raised within this module)."""


class ParameterError(JWEException):
    """A header parameter has an invalid value, e.g. an unknown "zip"."""


class NoSuitableEncryptionKey(JWEException):
    """None of the available keys can be used for encryption."""


class NoSuitableDecryptionKey(JWEException):
    """None of the available keys can be used for decryption."""


class DecryptionFailed(JWEException):
    """Decryption or integrity verification of a JWE failed."""


class WrongEncryptionAlgorithm(JWEException):
    """The token's "alg" differs from the algorithm the caller expected."""
79
80
81# ---------------------------------------------------------------------------
82# Base class
83
# Key length in bits for each JWE "enc" value; these are the byte sizes in
# ENCALGLEN1/ENCALGLEN2 multiplied by 8. JWE_EC uses this table as the
# derived-key length for direct ECDH-ES.
KEYLEN = {
    # AES-GCM family
    "A128GCM": 128,
    "A192GCM": 192,
    "A256GCM": 256,
    # AES-CBC + HMAC-SHA2 family
    "A128CBC-HS256": 256,
    "A192CBC-HS384": 384,
    "A256CBC-HS512": 512,
}
92
93
class Encrypter(object):
    """Abstract interface shared by the encryption back-ends in this module.

    Concrete sub-classes (e.g. ``RSAEncrypter``) override :meth:`encrypt`
    and :meth:`decrypt`.
    """

    def __init__(self, with_digest=False):
        # Flag consulted by sub-classes; RSAEncrypter appends and verifies
        # a SHA digest of the plaintext when it is set.
        self.with_digest = with_digest

    def encrypt(self, msg, key):
        """Encrypt *msg* with *key*; must be provided by a sub-class."""
        raise NotImplementedError()

    def decrypt(self, msg, key):
        """Decrypt *msg* with *key*; must be provided by a sub-class."""
        raise NotImplementedError()
107
108
class RSAEncrypter(Encrypter):
    """RSA encryption/decryption built on PyCrypto(dome) cipher objects.

    Two paddings are supported: ``"pkcs1_padding"`` (PKCS#1 v1.5, the
    default) and ``"pkcs1_oaep_padding"`` (OAEP).
    """

    def encrypt(self, msg, key, padding="pkcs1_padding"):
        """Encrypt *msg* with the RSA *key*.

        :param msg: Plain text bytes (in this module: a content encryption key)
        :param key: RSA key object accepted by ``PKCS1_v1_5.new`` /
            ``PKCS1_OAEP.new``
        :param padding: ``"pkcs1_padding"`` or ``"pkcs1_oaep_padding"``
        :return: The ciphertext
        :raises Exception: if *padding* is not supported
        """
        if padding == "pkcs1_padding":
            cipher = PKCS1_v1_5.new(key)
            if self.with_digest:
                # Append a SHA digest so decrypt() can verify integrity.
                msg += SHA.new(msg).digest()
        elif padding == "pkcs1_oaep_padding":
            cipher = PKCS1_OAEP.new(key)
        else:
            raise Exception("Unsupported padding")
        return cipher.encrypt(msg)

    def decrypt(self, ciphertext, key, padding="pkcs1_padding"):
        """Decrypt *ciphertext* with the RSA private *key*.

        For PKCS#1 v1.5 a random sentinel is handed to the cipher. If the
        cipher returns the sentinel, or the appended digest does not match
        when ``with_digest`` is set, ``DecryptionFailed`` is raised.

        NOTE(review): RFC 7516 section 11.5 advises against revealing
        RSA1_5 padding failures. Raising here may give an attacker an
        oracle; callers currently rely on the exception to try other keys.

        :param ciphertext: The RSA-encrypted bytes
        :param key: RSA private key object
        :param padding: ``"pkcs1_padding"`` or ``"pkcs1_oaep_padding"``
        :return: The decrypted bytes (with any digest stripped)
        :raises DecryptionFailed: on a detected decryption/digest failure
        :raises Exception: if *padding* is not supported
        """
        import hmac

        if padding == "pkcs1_padding":
            cipher = PKCS1_v1_5.new(key)
            dsize = SHA.digest_size if self.with_digest else 0
            sentinel = Random.new().read(32 + dsize)
            text = cipher.decrypt(ciphertext, sentinel)
            if dsize:
                _msg, _digest = text[:-dsize], text[-dsize:]
                # Constant-time comparison, so that timing does not reveal
                # how many leading digest bytes matched.
                if not hmac.compare_digest(SHA.new(_msg).digest(), _digest):
                    raise DecryptionFailed()
                text = _msg
            elif text == sentinel:
                raise DecryptionFailed()
        elif padding == "pkcs1_oaep_padding":
            cipher = PKCS1_OAEP.new(key)
            text = cipher.decrypt(ciphertext)
        else:
            raise Exception("Unsupported padding")

        return text
149
150
151# ---------------------------------------------------------------------------
152
153
def int2bigendian(n):
    """Return *n* as a list of four byte values (unsigned 32-bit, big-endian).

    ``struct.pack`` returns ``str`` on Python 2 and ``bytes`` on Python 3.
    Wrapping the result in ``bytearray`` yields integers on both versions;
    the previous ``ord(c)`` form raised ``TypeError`` on Python 3.

    :param n: Integer in the range 0 .. 2**32 - 1
    :return: List of four ints
    :raises struct.error: if *n* does not fit in an unsigned 32-bit integer
    """
    return list(bytearray(struct.pack('>I', n)))
156
157
def party_value(pv):
    """Encode an optional party-info value as a length-prefixed list.

    A falsy *pv* becomes four zero entries. Otherwise *pv* is
    base64url-encoded with ``b64e``, and the result is the encoded length
    (as returned by ``int2bigendian``) followed by the encoded value's items.
    """
    if not pv:
        return [0, 0, 0, 0]
    encoded = b64e(pv)
    return int2bigendian(len(encoded)) + list(encoded)
166
167
def _hash_input(cmk, enc, label, rond=1, length=128, hashsize=256,
                epu="", epv=""):
    """Assemble the input for one hash round of a key-derivation step.

    The result is a list of ints laid out as:

    1. *rond* as a 4-byte big-endian integer
    2. *cmk*
    3. *length* as a 4-byte big-endian integer
    4. the bytes of *enc*
    5. ``party_value(epu)`` and ``party_value(epv)``
    6. *label*

    This function is not referenced elsewhere in this module.

    :param cmk: Iterable of byte values (presumably the content master key)
    :param enc: Encryption algorithm name, text or bytes
    :param label: Iterable of byte values appended last
    :param rond: Round counter
    :param length: Length value written into the input
    :param hashsize: Unused; kept for backward compatibility
    :param epu: Optional party-U value
    :param epv: Optional party-V value
    :return: List of ints
    """
    def _uint32_be(value):
        # Real 4-byte big-endian encoding. The old ``[0, 0, 0, value]``
        # form produced a non-byte element once value exceeded 255.
        return list(bytearray(struct.pack('>I', value)))

    # Accept both text and bytes; the old ord()-per-item form failed for
    # bytes on Python 3.
    if isinstance(enc, six.text_type):
        enc = enc.encode("utf-8")

    r = _uint32_be(rond)
    r.extend(cmk)
    r.extend(_uint32_be(length))
    r.extend(bytearray(enc))
    r.extend(party_value(epu))
    r.extend(party_value(epv))
    r.extend(label)
    return r
178
179
180# ---------------------------------------------------------------------------
181
def cipher_filter(cipher, inf, outf):
    """Stream *inf* through *cipher* into *outf* and return what was written.

    *cipher* must provide ``update(data)`` and ``final()``. *inf* is read
    until it returns an empty chunk, and *outf* must support ``getvalue()``.
    """
    chunk = inf.read()
    while chunk:
        outf.write(cipher.update(chunk))
        chunk = inf.read()
    outf.write(cipher.final())
    return outf.getvalue()
190
191
def aes_enc(key, txt):
    """Run *txt* through the streaming cipher object *key*.

    :param key: A cipher object exposing ``update(data)`` and ``final()``.
        It is handed straight to ``cipher_filter``; it is not raw key bytes.
    :param txt: Plain text. Text strings are UTF-8 encoded first.
    :return: The cipher output
    """
    if isinstance(txt, six.text_type):
        txt = txt.encode("utf-8")
    # Cipher data is binary, so byte buffers are required. io.StringIO
    # rejects bytes on Python 3 (and ``str`` on Python 2). The context
    # managers close both buffers even if the cipher raises.
    with io.BytesIO(txt) as pbuf, io.BytesIO() as cbuf:
        return cipher_filter(key, pbuf, cbuf)
199
200
def aes_dec(key, ciptxt):
    """Run the ciphertext *ciptxt* through the streaming cipher object *key*.

    :param key: A cipher object exposing ``update(data)`` and ``final()``.
        It is handed straight to ``cipher_filter``.
    :param ciptxt: Ciphertext bytes
    :return: The cipher output (plain text bytes)
    """
    # Binary data needs byte buffers; io.StringIO rejects bytes on Python 3.
    # The context managers close both buffers even if the cipher raises.
    with io.BytesIO(ciptxt) as cbuf, io.BytesIO() as pbuf:
        return cipher_filter(key, cbuf, pbuf)
208
209
def keysize(spec):
    """Return the key size (in bits) encoded in an algorithm name.

    ``HS<n>`` and ``CS<n>`` yield ``n``. Names starting with ``A`` yield the
    three characters after the ``A`` as an int (``A128KW`` -> 128). Any
    other name yields 0.
    """
    if spec.startswith(("HS", "CS")):
        return int(spec[2:])
    if spec.startswith("A"):
        return int(spec[1:4])
    return 0
218
219
# Maps a JWE CBC "enc" prefix to a cipher name of the form "aes_<bits>_cbc".
# Not referenced elsewhere in this module.
ENC2ALG = {
    "A128CBC": "aes_128_cbc",
    "A192CBC": "aes_192_cbc",
    "A256CBC": "aes_256_cbc",
}

# Key-management ("alg") and content-encryption ("enc") algorithms that
# JWEnc.is_jwe() and JWE_RSA.decrypt() accept.
SUPPORTED = {
    "alg": [
        "RSA1_5", "RSA-OAEP",
        "A128KW", "A192KW", "A256KW",
        "ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW",
    ],
    "enc": [
        "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512",
        # A128GCM/A192GCM are left out: JWe.enc_setup only encrypts
        # with A256GCM.
        "A256GCM",
    ],
}
230
231
def alg2keytype(alg):
    """Map a JWE "alg" value to the JWK key type it needs.

    Returns "RSA" for RSA* algorithms, "oct" for names starting with "A"
    (AES key wrap), "EC" for ECDH* and None for anything else. Prefixes are
    tested in that order.
    """
    for prefix, key_type in (("RSA", "RSA"), ("A", "oct"), ("ECDH", "EC")):
        if alg.startswith(prefix):
            return key_type
    return None
241
242
243# =============================================================================
244
# Content-encryption key sizes in bytes, used by JWe._generate_key_and_iv.
# AES-GCM algorithms (paired there with a 12-byte IV):
ENCALGLEN1 = {
    "A128GCM": 16,
    "A192GCM": 24,
    "A256GCM": 32,
}

# AES-CBC + HMAC algorithms (paired there with a 16-byte IV):
ENCALGLEN2 = {
    "A128CBC-HS256": 32,
    "A192CBC-HS384": 48,
    "A256CBC-HS512": 64,
}
256
257
class JWEnc(JWT):
    """A JWE in compact serialization, split into its five parts.

    The parts are, in order:

    0. protected header
    1. encrypted key
    2. initialization vector
    3. ciphertext
    4. authentication tag

    ``b64_*`` accessors read ``self.b64part`` (the encoded parts) and the
    others read ``self.part``. Both lists come from the ``JWT`` base class,
    presumably filled by ``unpack``.
    """

    def b64_protected_header(self):
        return self.b64part[0]

    def b64_encrypted_key(self):
        return self.b64part[1]

    def b64_initialization_vector(self):
        return self.b64part[2]

    def b64_ciphertext(self):
        return self.b64part[3]

    def b64_authentication_tag(self):
        return self.b64part[4]

    def protected_header(self):
        return self.part[0]

    def encrypted_key(self):
        return self.part[1]

    def initialization_vector(self):
        return self.part[2]

    def ciphertext(self):
        return self.part[3]

    def authentication_tag(self):
        return self.part[4]

    def b64_encode_header(self):
        """Return ``self.headers`` encoded with ``b64encode_item``."""
        return b64encode_item(self.headers)

    def is_jwe(self):
        """Return True if the headers look like those of a supported JWE.

        A ``typ`` header equal to "jwe" (case-insensitive) is accepted
        outright. Otherwise both ``alg`` and ``enc`` must be present and
        listed in ``SUPPORTED``.
        """
        if "typ" in self.headers and self.headers["typ"].lower() == "jwe":
            return True

        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which made every token look supported.
        if "alg" not in self.headers or "enc" not in self.headers:
            return False

        for typ in ["alg", "enc"]:
            if self.headers[typ] not in SUPPORTED[typ]:
                logger.debug("Not supported %s algorithm: %s",
                             typ, self.headers[typ])
                return False
        return True
309
310
class JWe(JWx):
    """Shared content-encryption helpers for the JWE key-management classes."""

    @staticmethod
    def _generate_key_and_iv(encalg, cek="", iv=""):
        """Return a ``(cek, iv)`` pair for content encryption with *encalg*.

        Values supplied by the caller are kept; missing ones are generated
        randomly. GCM algorithms get a key of ``ENCALGLEN1`` bytes and a
        12-byte IV. CBC-HMAC algorithms get a key of ``ENCALGLEN2`` bytes
        and a 16-byte IV.

        :raises NotSupportedAlgorithm: for an unknown *encalg*, unless both
            *cek* and *iv* were supplied
        """
        if cek and iv:
            return cek, iv

        if encalg in ENCALGLEN1:
            key_len, iv_len = ENCALGLEN1[encalg], 12
        elif encalg in ENCALGLEN2:
            key_len, iv_len = ENCALGLEN2[encalg], 16
        else:
            raise NotSupportedAlgorithm(
                "Unsupported encryption algorithm %s" % encalg)

        # Only generate what the caller did not supply.
        _key = cek or Random.get_random_bytes(key_len)
        _iv = iv or Random.get_random_bytes(iv_len)
        return _key, _iv

    def alg2keytype(self, alg):
        return alg2keytype(alg)

    def enc_setup(self, enc_alg, msg, auth_data, key=None, iv=""):
        """Encrypt JWE content.

        :param enc_alg: The JWE "enc" value specifying the encryption algorithm
        :param msg: The plain text message
        :param auth_data: Additional authenticated data
        :param key: Key (CEK); generated if not given
        :param iv: Initialization vector; generated if not given
        :return: Tuple (ciphertext, tag, cek)
        :raises NotSupportedAlgorithm: if *enc_alg* cannot be used here
        """
        key, iv = self._generate_key_and_iv(enc_alg, key, iv)

        if enc_alg == "A256GCM":
            gcm = AES_GCM(bytes_to_long(key))
            ctxt, tag = gcm.encrypt(bytes_to_long(iv), msg, auth_data)
            # Pad to the full 16-byte GCM tag. long_to_bytes() otherwise
            # drops leading zero bytes, giving a short tag about 1 time in
            # 256 that strict implementations reject.
            tag = long_to_bytes(tag, 16)
        elif enc_alg in ["A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512"]:
            ctxt, tag = aes_cbc_hmac_encrypt(key, iv, auth_data, msg)
        else:
            raise NotSupportedAlgorithm(enc_alg)

        return ctxt, tag, key

    @staticmethod
    def _decrypt(enc, key, ctxt, auth_data, iv, tag):
        """Decrypt JWE content.

        :param enc: The JWE "enc" value specifying the encryption algorithm
        :param key: Key (CEK)
        :param ctxt: Ciphertext
        :param auth_data: Additional authenticated data (AAD)
        :param iv: Initialization vector
        :param tag: Authentication tag
        :return: For GCM, ``(plain_text, True)`` on success and
            ``(None, False)`` on failure. For CBC-HMAC, the result of
            ``aes_cbc_hmac_decrypt`` unchanged (callers unpack it as
            ``(text, flag)`` too).
        :raises NotSupportedAlgorithm: for an unknown *enc*
        """
        if enc in ["A128GCM", "A192GCM", "A256GCM"]:
            gcm = AES_GCM(bytes_to_long(key))
            try:
                text = gcm.decrypt(bytes_to_long(iv), ctxt,
                                   bytes_to_long(tag), auth_data)
            except DecryptionFailed:
                # NOTE(review): only this module's DecryptionFailed is
                # caught; confirm AES_GCM.decrypt raises it on tag mismatch.
                return None, False
            return text, True
        elif enc in ["A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512"]:
            return aes_cbc_hmac_decrypt(key, iv, auth_data, ctxt, tag)
        else:
            raise NotSupportedAlgorithm(
                "Unsupported encryption algorithm %s" % enc)
384
385
class JWE_SYM(JWe):
    """JWE whose content encryption key is AES-key-wrapped with a shared key."""
    args = JWe.args[:]
    args.append("enc")

    def encrypt(self, key, iv="", cek="", **kwargs):
        """Produce a compact JWE with the CEK wrapped by *key*.

        :param key: Shared symmetric key: bytes, or a value ``intarr2str``
            converts (presumably a list of ints)
        :param iv: Initialization vector; generated if not given
        :param cek: Content encryption key; generated if not given
        :param kwargs: Extra keyword arguments; only ``kid`` is used, and it
            is copied into the header
        :return: The packed JWE
        :raises ParameterError: if "zip" has a value other than "DEF"
        """
        # as_bytes() accepts text and bytes (as JWE_RSA does); the previous
        # self.msg.encode() failed for a bytes message on Python 3.
        _msg = as_bytes(self.msg)
        # Honour "zip" like JWE_RSA. Previously the header could claim
        # DEF while the payload was left uncompressed.
        if "zip" in self:
            if self["zip"] == "DEF":
                _msg = zlib.compress(_msg)
            else:
                raise ParameterError("Zip has unknown value: %s" % self["zip"])

        _args = self._dict
        try:
            _args["kid"] = kwargs["kid"]
        except KeyError:
            pass

        jwe = JWEnc(**_args)

        # If no iv and cek are given generate them
        cek, iv = self._generate_key_and_iv(self["enc"], cek, iv)
        if isinstance(key, six.binary_type):
            kek = key
        else:
            kek = intarr2str(key)

        # The key-wrap IV is 64 bit and unrelated to the content IV above.
        jek = aes_wrap_key(kek, cek)

        _enc = self["enc"]

        ctxt, tag, cek = self.enc_setup(_enc, _msg, jwe.b64_encode_header(),
                                        cek, iv=iv)
        return jwe.pack(parts=[jek, iv, ctxt, tag])

    def decrypt(self, token, key=None, cek=None):
        """Decrypt a compact JWE whose CEK was AES-key-wrapped.

        :param token: The compact JWE
        :param key: Key-wrapping key, used to unwrap the CEK
        :param cek: The content encryption key, if already known
        :return: The decrypted message
        :raises MissingKey: if neither *key* nor *cek* is given
        :raises DecryptionFailed: if content decryption fails
        """
        if not key and not cek:
            raise MissingKey("One of key or cek must be specified")

        jwe = JWEnc().unpack(token)

        if not cek:
            jek = jwe.encrypted_key()
            # The iv for this function must be 64 bit
            cek = aes_unwrap_key(key, jek)

        # _decrypt() returns (text, flag). Previously the tuple itself was
        # returned (and handed to zlib), and failures went unnoticed.
        msg, flag = self._decrypt(
            jwe.headers["enc"], cek, jwe.ciphertext(),
            jwe.b64_protected_header(),
            jwe.initialization_vector(), jwe.authentication_tag())
        if flag is False:
            raise DecryptionFailed()

        # Decompression is driven by the token's own header, as in JWE_RSA.
        if "zip" in jwe.headers and jwe.headers["zip"] == "DEF":
            msg = zlib.decompress(msg)

        return msg
447
448
class JWE_RSA(JWe):
    """JWE whose content encryption key is RSA-encrypted (RSA1_5 / RSA-OAEP)."""
    args = ["msg", "alg", "enc", "epk", "zip", "jku", "jwk", "x5u", "x5t",
            "x5c", "kid", "typ", "cty", "apu", "crit"]

    def encrypt(self, key, iv="", cek="", **kwargs):
        """
        Produces a JWE using RSA algorithms

        :param key: RSA key
        :param iv: Initialization vector; generated if not given
        :param cek: Content encryption key; generated if not given
        :param kwargs: Ignored
        :return: A jwe
        :raises ParameterError: if "zip" has a value other than "DEF"
        :raises NotSupportedAlgorithm: if "alg" is not RSA-OAEP or RSA1_5
        """

        _msg = as_bytes(self.msg)
        if "zip" in self:
            if self["zip"] == "DEF":
                _msg = zlib.compress(_msg)
            else:
                raise ParameterError("Zip has unknown value: %s" % self["zip"])

        _enc = self["enc"]
        cek, iv = self._generate_key_and_iv(_enc, cek, iv)
        # Do not log the CEK or IV: the CEK is the secret content key.
        # Previously both were written to the debug log.

        _encrypt = RSAEncrypter(self.with_digest).encrypt

        _alg = self["alg"]
        if _alg == "RSA-OAEP":
            jwe_enc_key = _encrypt(cek, key, 'pkcs1_oaep_padding')
        elif _alg == "RSA1_5":
            jwe_enc_key = _encrypt(cek, key)
        else:
            raise NotSupportedAlgorithm(_alg)

        jwe = JWEnc(**self.headers())

        enc_header = jwe.b64_encode_header()

        # The returned CEK is the one passed in; don't shadow ``key``.
        ctxt, tag, _ = self.enc_setup(_enc, _msg, enc_header, cek, iv)
        return jwe.pack(parts=[jwe_enc_key, iv, ctxt, tag])

    def decrypt(self, token, key):
        """ Decrypts a JWT

        :param token: The JWT
        :param key: A key to use for decrypting
        :return: The decrypted message
        :raises NotSupportedAlgorithm: for an unsupported "alg" or "enc"
        :raises DecryptionFailed: if decryption fails
        """
        jwe = JWEnc().unpack(token)
        # NOTE(review): stores the encrypted key (not the parsed JWE) on
        # self.jwt; kept because callers may rely on it. Confirm intent.
        self.jwt = jwe.encrypted_key()
        jek = jwe.encrypted_key()

        _decrypt = RSAEncrypter(self.with_digest).decrypt

        _alg = jwe.headers["alg"]
        if _alg == "RSA-OAEP":
            cek = _decrypt(jek, key, 'pkcs1_oaep_padding')
        elif _alg == "RSA1_5":
            cek = _decrypt(jek, key)
        else:
            raise NotSupportedAlgorithm(_alg)

        enc = jwe.headers["enc"]
        # Explicit check instead of ``assert`` (stripped under ``python -O``).
        if enc not in SUPPORTED["enc"]:
            raise NotSupportedAlgorithm(enc)

        msg, flag = self._decrypt(enc, cek, jwe.ciphertext(),
                                  jwe.b64_protected_header(),
                                  jwe.initialization_vector(),
                                  jwe.authentication_tag())
        if flag is False:
            raise DecryptionFailed()

        if "zip" in jwe.headers and jwe.headers["zip"] == "DEF":
            msg = zlib.decompress(msg)

        return msg
531
532
class JWE_EC(JWe):
    """JWE key management with ECDH-ES (direct, or with AES key wrap)."""

    def enc_setup(self, msg, auth_data, key=None, **kwargs):
        """Derive, and for the +AxxxKW variants wrap, the content key.

        :param msg: The plain text message (not used by this method)
        :param auth_data: Additional authenticated data (not used here)
        :param key: The recipient's EC key; its ``crv`` selects the curve
        :param kwargs: Optional ``apu``/``apv`` (base64url-encoded), ``epk``
            (ephemeral key), ``cek`` and ``iv``
        :return: Tuple ``(cek, encrypted_key, iv, params)``. *encrypted_key*
            is "" for direct ECDH-ES, and *params* holds the base64url
            ``apu``/``apv`` values.
        """
        encrypted_key = ""

        # Agreement PartyUInfo/PartyVInfo: caller values arrive base64url
        # encoded. The fallback is 16 raw random bytes; it used to be
        # wrongly passed through b64d() as if it were encoded.
        try:
            apu = b64d(kwargs["apu"])
        except KeyError:
            apu = Random.get_random_bytes(16)
        try:
            apv = b64d(kwargs["apv"])
        except KeyError:
            apv = Random.get_random_bytes(16)

        # Generate an ephemeral key pair unless one is supplied.
        curve = NISTEllipticCurve.by_name(key.crv)
        if "epk" in kwargs:
            # NOTE(review): this is an ECKey, whereas curve.key_pair()
            # below yields the curve's own private-key value; confirm
            # ecdh_derive_key accepts both.
            eprivk = ECKey(kwargs["epk"])
        else:
            eprivk, _epk = curve.key_pair()
        # NOTE(review): the ephemeral public key is not included in
        # params, but a recipient needs it ("epk" header). Confirm.
        params = {
            "apu": b64e(apu),
            "apv": b64e(apv),
        }

        # Honour caller-supplied cek/iv, which were previously ignored.
        cek, iv = self._generate_key_and_iv(self.enc, kwargs.get("cek", ""),
                                            kwargs.get("iv", ""))
        if self.alg == "ECDH-ES":
            try:
                dk_len = KEYLEN[self.enc]
            except KeyError:
                raise Exception(
                    "Unknown key length for algorithm %s" % self.enc)

            cek = ecdh_derive_key(curve, eprivk, key, apu, apv, self.enc,
                                  dk_len)
        elif self.alg in ["ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]:
            _pre, _post = self.alg.split("+")
            klen = int(_post[1:4])
            kek = ecdh_derive_key(curve, eprivk, key, apu, apv, _post, klen)
            encrypted_key = aes_wrap_key(kek, cek)
        else:
            raise Exception("Unsupported algorithm %s" % self.alg)

        return cek, encrypted_key, iv, params
577
578
class JWE(JWx):
    """Front-end for producing and consuming JWEs.

    Work is dispatched on the "alg" header. RSA-OAEP and RSA1_5 go to
    ``JWE_RSA``; AES key-wrap algorithms (A*KW) go to ``JWE_SYM``.

    :param msg: The message
    :param alg: Key management algorithm
    :param enc: Content encryption algorithm
    :param epk: Ephemeral Public Key
    :param zip: Compression Algorithm
    :param jku: a URI that refers to a resource for a set of JSON-encoded
        public keys, one of which corresponds to the key to which the JWE
        was encrypted
    :param jwk: A JSON Web Key that corresponds to the key to which the
        JWE was encrypted
    :param x5u: a URI that refers to a resource for the X.509 public key
        certificate or certificate chain [RFC5280] corresponding to the key
        to which the JWE was encrypted.
    :param x5t: a base64url encoded SHA-1 thumbprint (a.k.a. digest) of the
        DER encoding of the X.509 certificate [RFC5280] corresponding to
        the key to which the JWE was encrypted.
    :param x5c: the X.509 public key certificate or certificate chain
        corresponding to the key to which the JWE was encrypted.
    :param kid: Key ID, a hint indicating which key was used to encrypt
        the JWE.
    :param typ: the type of this object, e.g. 'JWE'
    :param cty: Content Type
    :param apu: Agreement PartyUInfo
    :param crit: indicates which extensions are being used and MUST
        be understood and processed.
    :return: A class instance
    """
    args = ["alg", "enc", "epk", "zip", "jku", "jwk", "x5u", "x5t",
            "x5c", "kid", "typ", "cty", "apu", "crit"]

    def encrypt(self, keys=None, cek="", iv="", **kwargs):
        """Encrypt ``self.msg``.

        :param keys: A set of possibly usable keys; if not given, the
            object's own keys are used
        :param cek: Content master key
        :param iv: Initialization vector
        :param kwargs: Extra key word arguments passed to the encrypter
        :return: Encrypted message
        :raises NotSupportedAlgorithm: if "alg" is not supported
        :raises NoSuitableEncryptionKey: if no usable key is found
        """
        _alg = self["alg"]
        if _alg in ["RSA-OAEP", "RSA1_5"]:
            encrypter = JWE_RSA(self.msg, **self._dict)
        elif _alg.startswith("A") and _alg.endswith("KW"):
            encrypter = JWE_SYM(self.msg, **self._dict)
        else:
            logger.error("'{}' is not a supported algorithm".format(_alg))
            raise NotSupportedAlgorithm(_alg)

        if keys:
            keys = self._pick_keys(keys, use="enc")
        else:
            keys = self._pick_keys(self._get_keys(), use="enc")

        if not keys:
            logger.error(
                "Could not find any suitable encryption key for alg='{"
                "}'".format(_alg))
            raise NoSuitableEncryptionKey(_alg)

        if cek:
            kwargs["cek"] = cek
        if iv:
            kwargs["iv"] = iv

        # The first picked key is used; errors from the encrypter propagate.
        for key in keys:
            _key = key.encryption_key(alg=_alg, private=True)

            if key.kid:
                encrypter["kid"] = key.kid

            token = encrypter.encrypt(_key, **kwargs)
            logger.debug("Encrypted message using key with kid=%s", key.kid)
            return token

        logger.error("Could not find any suitable encryption key")
        raise NoSuitableEncryptionKey()

    def decrypt(self, token, keys=None, alg=None):
        """Decrypt *token*, trying each suitable key in turn.

        :param token: The compact JWE
        :param keys: Candidate keys; if not given, the object's own keys
        :param alg: If given, the "alg" the token must use
        :return: The decrypted message
        :raises WrongEncryptionAlgorithm: if the token's "alg" differs from *alg*
        :raises NotSupportedAlgorithm: if the token's "alg" is not supported
        :raises NoSuitableDecryptionKey: if no candidate key fits "alg"
        :raises DecryptionFailed: if no candidate key decrypts the token
        """
        jwe = JWEnc().unpack(token)

        _alg = jwe.headers["alg"]
        if alg and alg != _alg:
            raise WrongEncryptionAlgorithm()

        if _alg in ["RSA-OAEP", "RSA1_5"]:
            decrypter = JWE_RSA(**self._dict)
        elif _alg.startswith("A") and _alg.endswith("KW"):
            decrypter = JWE_SYM(self.msg, **self._dict)
        else:
            raise NotSupportedAlgorithm(_alg)

        if keys:
            keys = self._pick_keys(keys, use="enc", alg=_alg)
        else:
            keys = self._pick_keys(self._get_keys(), use="enc", alg=_alg)

        if not keys:
            raise NoSuitableDecryptionKey(_alg)

        for key in keys:
            _key = key.encryption_key(alg=_alg, private=False)
            try:
                msg = decrypter.decrypt(as_bytes(token), _key)
            except (KeyError, DecryptionFailed):
                # This key did not work; try the next one.
                continue
            logger.debug("Decrypted message using key with kid=%s", key.kid)
            return msg

        raise DecryptionFailed(
            "No available key that could decrypt the message")
704
705
def factory(token):
    """Return a ``JWE`` for *token* if it looks like a JWE, else None.

    The parsed ``JWEnc`` is attached to the returned object as ``jwt``.
    """
    parsed = JWEnc().unpack(token)
    if not parsed.is_jwe():
        return None
    jwe_obj = JWE()
    jwe_obj.jwt = parsed
    return jwe_obj
714