1import zlib
2from .oct_key import OctKey
3from ._cryptography_backends import JWE_ALG_ALGORITHMS, JWE_ENC_ALGORITHMS
4from ..rfc7516 import JWEAlgorithm, JWEZipAlgorithm, JsonWebEncryption
5
6
class DirectAlgorithm(JWEAlgorithm):
    """JWE "dir" key management: use a shared symmetric key directly as the CEK.

    No key wrapping takes place. The JWE Encrypted Key is therefore the empty
    octet sequence, and the shared key's length must match the content
    encryption algorithm's CEK size exactly.
    """

    # JWE "alg" header value handled by this class.
    name = 'dir'
    description = 'Direct use of a shared symmetric key'

    def prepare_key(self, raw_data):
        """Import ``raw_data`` as an :class:`OctKey` (symmetric key)."""
        return OctKey.import_key(raw_data)

    @staticmethod
    def _check_cek_size(enc_alg, cek):
        """Raise ``ValueError`` unless ``cek`` is exactly ``enc_alg.CEK_SIZE`` bits."""
        # CEK_SIZE is compared against len(cek) * 8, i.e. it is a bit count.
        if len(cek) * 8 != enc_alg.CEK_SIZE:
            raise ValueError('Invalid "cek" length')

    def wrap(self, enc_alg, headers, key):
        """Use the shared key as the CEK; there is no encrypted key.

        :param enc_alg: content encryption algorithm whose ``CEK_SIZE`` (bits)
            must equal the length of the shared key.
        :param headers: JWE header (not used by this algorithm).
        :param key: key object; ``key.get_op_key('encrypt')`` supplies the CEK.
        :return: ``{'ek': b'', 'cek': cek}``
        :raises ValueError: if the key length does not match ``CEK_SIZE``.
        """
        cek = key.get_op_key('encrypt')
        self._check_cek_size(enc_alg, cek)
        return {'ek': b'', 'cek': cek}

    def unwrap(self, enc_alg, ek, headers, key):
        """Return the shared key as the CEK after validating the inputs.

        :param enc_alg: content encryption algorithm whose ``CEK_SIZE`` (bits)
            must equal the length of the shared key.
        :param ek: JWE Encrypted Key; must be empty for "dir".
        :param headers: JWE header (not used by this algorithm).
        :param key: key object; ``key.get_op_key('decrypt')`` supplies the CEK.
        :return: the CEK.
        :raises ValueError: if ``ek`` is non-empty or the key length does not
            match ``CEK_SIZE``.
        """
        # RFC 7516 section 5.2: with Direct Encryption the recipient must verify
        # that the JWE Encrypted Key is an empty octet sequence.
        if ek:
            raise ValueError('Invalid "ek" for "dir" algorithm')
        cek = key.get_op_key('decrypt')
        self._check_cek_size(enc_alg, cek)
        return cek
25
26
class DeflateZipAlgorithm(JWEZipAlgorithm):
    """JWE "zip": "DEF" compression using raw DEFLATE (RFC 1951).

    Both directions operate on bare DEFLATE streams. There is no zlib
    header or Adler-32 trailer (RFC 1950) and no gzip wrapper.
    """

    # JWE "zip" header value handled by this class.
    name = 'DEF'
    description = 'DEFLATE'

    def compress(self, s):
        """Compress bytes data with the raw DEFLATE algorithm.

        :param s: bytes-like data to compress.
        :return: raw DEFLATE stream as ``bytes``.
        """
        # A negative wbits makes zlib emit a bare DEFLATE stream. This is the
        # same payload as zlib.compress(s) with its 2-byte zlib header and
        # 4-byte Adler-32 trailer removed, but it does not rely on that
        # framing layout.
        compressor = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
        return compressor.compress(s) + compressor.flush()

    def decompress(self, s):
        """Decompress raw DEFLATE bytes data.

        :param s: raw DEFLATE stream (no zlib/gzip framing).
        :return: decompressed ``bytes``.
        :raises zlib.error: if ``s`` is not a complete, valid DEFLATE stream.
        """
        # NOTE(review): the decompressed size is unbounded. If ``s`` can be
        # attacker-influenced (presumably JWE content — confirm against
        # callers), a tiny input can expand enormously. Consider capping the
        # output via zlib.decompressobj().decompress(s, max_length).
        return zlib.decompress(s, -zlib.MAX_WBITS)
40
41
def register_jwe_rfc7518():
    """Register every RFC 7518 JWE algorithm on :class:`JsonWebEncryption`.

    Registration order is: the "dir" key-management algorithm, the "DEF"
    compression algorithm, then each entry of ``JWE_ALG_ALGORITHMS``,
    then each entry of ``JWE_ENC_ALGORITHMS``.
    """
    register = JsonWebEncryption.register_algorithm

    # Algorithms defined in this module.
    register(DirectAlgorithm())
    register(DeflateZipAlgorithm())

    # Backend-provided key-management ("alg") and content ("enc") algorithms.
    for table in (JWE_ALG_ALGORITHMS, JWE_ENC_ALGORITHMS):
        for algorithm in table:
            register(algorithm)
51