1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7from cryptography.exceptions import InvalidTag
8
9
# Operation selectors passed as the ``operation`` argument of _aead_setup().
# _aead_setup() compares against _ENCRYPT to derive the "enc" flag handed to
# EVP_CipherInit_ex (1 when encrypting, 0 otherwise), and against _DECRYPT to
# decide whether the expected tag has to be installed on the context.
_ENCRYPT = 1
_DECRYPT = 0
12
13
def _aead_cipher_name(cipher):
    """Return the cipher name, as ASCII bytes, for an AEAD object.

    ChaCha20Poly1305 maps to a fixed name. For AESCCM and AESGCM the name
    embeds the key size in bits, computed from the length of ``cipher._key``.
    The result is what this module passes to ``EVP_get_cipherbyname``.
    """
    from cryptography.hazmat.primitives.ciphers.aead import (
        AESCCM,
        AESGCM,
        ChaCha20Poly1305,
    )

    if isinstance(cipher, ChaCha20Poly1305):
        return b"chacha20-poly1305"

    if isinstance(cipher, AESCCM):
        template = "aes-{}-ccm"
    else:
        # Anything that is neither ChaCha20Poly1305 nor AESCCM is expected to
        # be AESGCM; the assert documents and checks that invariant.
        assert isinstance(cipher, AESGCM)
        template = "aes-{}-gcm"

    key_bits = len(cipher._key) * 8
    return template.format(key_bits).encode("ascii")
28
29
def _aead_setup(backend, cipher_name, key, nonce, tag, tag_len, operation):
    """Build and initialise an EVP cipher context for one AEAD operation.

    :param backend: object exposing ``_lib``, ``_ffi`` and ``openssl_assert``.
    :param cipher_name: bytes name looked up with ``EVP_get_cipherbyname``.
    :param key: key buffer; its length is set on the context before use.
    :param nonce: nonce buffer; its length is set as the IV length.
    :param tag: expected tag, only read when ``operation`` is ``_DECRYPT``.
    :param tag_len: tag length, only used when encrypting with a ``-ccm``
        cipher name.
    :param operation: ``_ENCRYPT`` or ``_DECRYPT``.
    :returns: the initialised context, released through
        ``EVP_CIPHER_CTX_free`` when it is garbage collected.

    Every OpenSSL return code is checked with ``backend.openssl_assert``.
    """
    lib = backend._lib
    ffi = backend._ffi
    null = ffi.NULL
    enc_flag = int(operation == _ENCRYPT)

    cipher_type = lib.EVP_get_cipherbyname(cipher_name)
    backend.openssl_assert(cipher_type != null)

    # Attach the free function as a destructor so the context is released
    # together with the Python-side handle.
    ctx = ffi.gc(lib.EVP_CIPHER_CTX_new(), lib.EVP_CIPHER_CTX_free)

    # First init pass: select the cipher and direction only. Key and nonce
    # are supplied in a second pass, after their lengths have been configured.
    status = lib.EVP_CipherInit_ex(
        ctx, cipher_type, null, null, null, enc_flag
    )
    backend.openssl_assert(status != 0)

    status = lib.EVP_CIPHER_CTX_set_key_length(ctx, len(key))
    backend.openssl_assert(status != 0)

    status = lib.EVP_CIPHER_CTX_ctrl(
        ctx, lib.EVP_CTRL_AEAD_SET_IVLEN, len(nonce), null
    )
    backend.openssl_assert(status != 0)

    if operation == _DECRYPT:
        # Decryption: hand over the expected tag itself.
        status = lib.EVP_CIPHER_CTX_ctrl(
            ctx, lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag
        )
        backend.openssl_assert(status != 0)
    elif cipher_name.endswith(b"-ccm"):
        # CCM encryption: only the tag length is set (NULL tag pointer).
        status = lib.EVP_CIPHER_CTX_ctrl(
            ctx, lib.EVP_CTRL_AEAD_SET_TAG, tag_len, null
        )
        backend.openssl_assert(status != 0)

    # Second init pass: provide the actual key and nonce bytes.
    nonce_buf = ffi.from_buffer(nonce)
    key_buf = ffi.from_buffer(key)
    status = lib.EVP_CipherInit_ex(
        ctx, null, null, key_buf, nonce_buf, enc_flag
    )
    backend.openssl_assert(status != 0)
    return ctx
76
77
def _set_length(backend, ctx, data_len):
    """Pass ``data_len`` to the cipher context without processing any bytes.

    ``EVP_CipherUpdate`` is invoked with NULL for both the output and the
    input pointer, so only the length value is conveyed. The result code is
    checked with ``backend.openssl_assert``.
    """
    null = backend._ffi.NULL
    written = backend._ffi.new("int *")
    status = backend._lib.EVP_CipherUpdate(ctx, null, written, null, data_len)
    backend.openssl_assert(status != 0)
84
85
def _process_aad(backend, ctx, associated_data):
    """Feed ``associated_data`` into the cipher context.

    The output pointer given to ``EVP_CipherUpdate`` is NULL, so the call
    yields no output buffer; the reported length is discarded. The result
    code is checked with ``backend.openssl_assert``.
    """
    ffi = backend._ffi
    unused_len = ffi.new("int *")
    status = backend._lib.EVP_CipherUpdate(
        ctx,
        ffi.NULL,
        unused_len,
        associated_data,
        len(associated_data),
    )
    backend.openssl_assert(status != 0)
92
93
def _process_data(backend, ctx, data):
    """Run ``data`` through the cipher context and return the output bytes.

    An output buffer the same size as ``data`` is allocated, and only the
    number of bytes reported by ``EVP_CipherUpdate`` is copied out of it.
    The result code is checked with ``backend.openssl_assert``.
    """
    ffi = backend._ffi
    size = len(data)
    written = ffi.new("int *")
    out = ffi.new("unsigned char[]", size)
    status = backend._lib.EVP_CipherUpdate(ctx, out, written, data, size)
    backend.openssl_assert(status != 0)
    # Slice the cffi buffer to get an independent bytes copy.
    return ffi.buffer(out, written[0])[:]
100
101
def _encrypt(backend, cipher, nonce, data, associated_data, tag_length):
    """Encrypt ``data`` with an AEAD cipher and return ciphertext plus tag.

    :param backend: object exposing ``_lib``, ``_ffi`` and ``openssl_assert``.
    :param cipher: AEAD object; its ``_key`` attribute supplies the key.
    :param nonce: nonce buffer for this operation.
    :param data: plaintext to encrypt.
    :param associated_data: data that is authenticated but not encrypted.
    :param tag_length: number of tag bytes to retrieve and append.
    :returns: the ciphertext immediately followed by the tag.
    """
    from cryptography.hazmat.primitives.ciphers.aead import AESCCM

    lib = backend._lib
    ffi = backend._ffi

    ctx = _aead_setup(
        backend,
        _aead_cipher_name(cipher),
        cipher._key,
        nonce,
        None,
        tag_length,
        _ENCRYPT,
    )

    # CCM must be told the payload length before anything is processed;
    # making the same call for any other AEAD results in an error.
    if isinstance(cipher, AESCCM):
        _set_length(backend, ctx, len(data))

    _process_aad(backend, ctx, associated_data)
    ciphertext = _process_data(backend, ctx, data)

    # Finalise; the final step is required to produce no additional bytes.
    final_len = ffi.new("int *")
    status = lib.EVP_CipherFinal_ex(ctx, ffi.NULL, final_len)
    backend.openssl_assert(status != 0)
    backend.openssl_assert(final_len[0] == 0)

    # Read the authentication tag back out of the context.
    tag_buf = ffi.new("unsigned char[]", tag_length)
    status = lib.EVP_CIPHER_CTX_ctrl(
        ctx, lib.EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf
    )
    backend.openssl_assert(status != 0)

    return ciphertext + ffi.buffer(tag_buf)[:]
128
129
def _decrypt(backend, cipher, nonce, data, associated_data, tag_length):
    """Verify and decrypt ``data``, which is ciphertext followed by its tag.

    :param backend: object exposing ``_lib``, ``_ffi``, ``openssl_assert``
        and ``_consume_errors``.
    :param cipher: AEAD object; its ``_key`` attribute supplies the key.
    :param nonce: nonce buffer for this operation.
    :param data: ciphertext with the ``tag_length``-byte tag appended.
    :param associated_data: data that was authenticated but not encrypted.
    :param tag_length: number of trailing bytes of ``data`` forming the tag.
    :returns: the decrypted bytes.
    :raises InvalidTag: if ``data`` is shorter than ``tag_length`` or the
        OpenSSL call that performs verification reports failure.
    """
    from cryptography.hazmat.primitives.ciphers.aead import AESCCM

    if len(data) < tag_length:
        raise InvalidTag

    # Split the trailing tag from the ciphertext proper.
    tag = data[-tag_length:]
    ciphertext = data[:-tag_length]

    ctx = _aead_setup(
        backend,
        _aead_cipher_name(cipher),
        cipher._key,
        nonce,
        tag,
        tag_length,
        _DECRYPT,
    )

    lib = backend._lib
    ffi = backend._ffi
    is_ccm = isinstance(cipher, AESCCM)

    # CCM must be told the payload length before anything is processed;
    # making the same call for any other AEAD results in an error.
    if is_ccm:
        _set_length(backend, ctx, len(ciphertext))

    _process_aad(backend, ctx, associated_data)

    if is_ccm:
        # CCM reports a tag mismatch from Update itself, so Final is
        # irrelevant and is not called on this path.
        written = ffi.new("int *")
        out = ffi.new("unsigned char[]", len(ciphertext))
        status = lib.EVP_CipherUpdate(
            ctx, out, written, ciphertext, len(ciphertext)
        )
        if status != 1:
            backend._consume_errors()
            raise InvalidTag
        return ffi.buffer(out, written[0])[:]

    # Other AEADs report a tag mismatch from Final.
    plaintext = _process_data(backend, ctx, ciphertext)
    final_len = ffi.new("int *")
    status = lib.EVP_CipherFinal_ex(ctx, ffi.NULL, final_len)
    if status == 0:
        backend._consume_errors()
        raise InvalidTag

    return plaintext
167